ROCm in Practice: of Convolutions and Feedforwards
My ramblings from training several types of models on AMD hardware w/PyTorch ROCm
Introduction
So, somehow, I managed to run my machine learning work entirely on AMD hardware for… about two months, something I previously thought impossible. While I’m far from done with my project, I thought I’d drop some alpha here for those still intimidated by ROCm.
As the old saying goes, nobody ever got fired for buying IBM. NVIDIA is still the de facto provider, and they damn well deserve their status. Their focus on GPGPU started well before AI: there was PhysX (conceptually cool, but it flopped because it was too niche), they were the first to offer GPU compute in Blender Cycles, and so on, while AMD lagged behind. Then machine learning came along, they bet on that too, and the bet paid dividends.
But I think that the latest technological revolution, which is estimated to hit $4.8 trillion in value by 2033 [1], shouldn’t be monopolized by one company. If only one person can sell shovels, who really controls all the gold?
Journey
I started off with a friend’s Radeon RX 7900 XTX, the most direct competitor to NVIDIA’s RTX 4090. The former has an estimated peak of 122 BF16 TFLOPS, while the latter is rated at 165.2 TFLOPS [2], which means that, in theory, the AMD GPU should reach roughly 74% of the NVIDIA card’s peak performance.
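The peak-performance estimate is just the ratio of the two spec-sheet numbers. A trivial sketch of that arithmetic, using only the figures quoted above (these are theoretical peaks; real kernels rarely reach them on either vendor):

```python
# Peak BF16 throughput figures as quoted in the text (spec-sheet values, not measurements)
RX_7900_XTX_TFLOPS = 122.0
RTX_4090_TFLOPS = 165.2

def relative_peak(amd_tflops: float, nvidia_tflops: float) -> float:
    """Fraction of the NVIDIA card's peak that the AMD card could reach in theory."""
    return amd_tflops / nvidia_tflops

ratio = relative_peak(RX_7900_XTX_TFLOPS, RTX_4090_TFLOPS)
print(f"7900 XTX theoretical peak: {ratio:.1%} of the 4090")
# A purely compute-bound job would take about 1 / ratio times as long on the AMD card
print(f"compute-bound time multiplier: {1 / ratio:.2f}x")
```

Memory bandwidth, kernel quality and software overhead all move real numbers away from this ratio, which is what the rest of the post is about.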
But would the software stack be up to the task? For some historical context, consider this old Blender benchmark (emphasis mine):
As noted, while the NVIDIA OptiX Cycles back-end is the fastest for NVIDIA RTX GPUs, even the NVIDIA CUDA back-end with these current-generation GPUs still outperforms the AMD Radeon RX 6000 series with the current HIP back-end. Even using a GeForce RTX 3060 Ti was faster than the RX 6800 XT with Blender's well known "BMW" scene.
Why were NVIDIA’s lower-tier GPUs beating AMD’s more expensive chips? Because the OptiX back-end used the green team’s dedicated ray-tracing cores, while the red team’s back-end was still based on plain HIP, which only used the general compute cores; taking advantage of AMD’s ray-tracing hardware required HIP-RT.
I’m bringing this up to show you that AMD’s software quirks, specifically in GPGPU, did not start with AI; they have been a consistent pattern for quite a while now.
Do This, not That
So, with that out of the way, let me tell you what NOT to do and why. For context: Ubuntu 24.04, ROCm 6.3, PyTorch 2.x
Do NOT try to set up a full ROCm + Python + ML libs training environment by yourself; use the Docker images. You can get away with a DIY setup on CUDA because it’s such a mature stack that it tolerates slightly messy installs, but ROCm is sensitive to version mismatches.
I tried that at first because I thought I was slick: what I got was segfaults with torch 2.6.0 and extremely slow training until I downgraded to 2.4, and even then it was kind of weird.
So, I went and tried one of the official Docker images, as recommended.
The good thing? After that, training with PyTorch was mostly (no, entirely) plug and play! Torch on ROCm does not have a separate .rocm() device; it simply reuses .cuda() and the associated functionality, so you can run your code without changing anything. Having heard mostly doom stories about ROCm, I was pleasantly surprised; it was rock-solid stable, too!
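To illustrate the "no .rocm()" point: the same "cuda" device string works on both stacks, and torch.version.hip tells you which build you are on. A minimal sketch; describe_backend is a hypothetical helper of mine, not a PyTorch API, and the tiny Linear model is only there to show that unchanged CUDA-style code runs:

```python
import torch

def describe_backend() -> str:
    """Report which GPU stack this PyTorch build was compiled against (hypothetical helper)."""
    if torch.version.hip is not None:   # populated on ROCm builds, None otherwise
        return f"ROCm/HIP {torch.version.hip}"
    if torch.version.cuda is not None:  # populated on CUDA builds, None otherwise
        return f"CUDA {torch.version.cuda}"
    return "CPU-only build"

# On ROCm the GPU is still addressed as "cuda", so existing CUDA code runs unchanged
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(describe_backend(), "| device:", device)
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))

model = torch.nn.Linear(16, 4).to(device)  # equivalent to model.cuda() when a GPU is present
out = model(torch.randn(8, 16, device=device))
```

Falling back to CPU keeps the snippet runnable anywhere; on a ROCm box it should report a HIP version and the Radeon/Instinct device name.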
But while training my first model, a modified FastSpeech2 containing both Transformers and lots of convolutions, I noticed extremely slow training. After yapping about it on Xitter, I brought out the Torch profiler and a sample training script to see what was happening under the hood.
The profiler trace showed that I was hitting naive ops for some reason, and those were the source of the astronomical slowdown. For some more context, my model uses variable sequence lengths with zero padding, so the shapes differ from batch to batch.
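A minimal sketch of that kind of profiling run, assuming a toy Conv1d stack and random variable-length batches in place of the real FastSpeech2 model and data (profile_steps is a hypothetical helper). My understanding is that ROCm builds record GPU kernels under ProfilerActivity.CUDA, the same as on NVIDIA; treat that as an assumption to verify on your version:

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function

def profile_steps(model: nn.Module, batches, device: torch.device):
    """Profile a few training steps and return the profiler object (hypothetical helper)."""
    activities = [ProfilerActivity.CPU]
    if device.type == "cuda":  # ROCm GPUs also show up as "cuda"
        activities.append(ProfilerActivity.CUDA)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    with profile(activities=activities, record_shapes=True) as prof:
        for x in batches:
            with record_function("train_step"):
                loss = model(x.to(device)).pow(2).mean()  # dummy loss for illustration
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
        if device.type == "cuda":
            torch.cuda.synchronize()  # make sure queued GPU work is captured
    return prof

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Toy stand-in for a conv-heavy model (assumption, not the author's architecture)
model = nn.Sequential(
    nn.Conv1d(80, 256, kernel_size=5, padding=2),
    nn.ReLU(),
    nn.Conv1d(256, 80, kernel_size=5, padding=2),
).to(device)
# Variable sequence lengths, as with zero-padded TTS batches
batches = [torch.randn(4, 80, frames) for frames in (120, 157, 203)]

prof = profile_steps(model, batches, device)
sort_key = "self_cuda_time_total" if device.type == "cuda" else "self_cpu_time_total"
# Grouping by input shape helps spot ops that are slow only for particular shapes
print(prof.key_averages(group_by_input_shape=True).table(sort_by=sort_key, row_limit=10))
```

Anything unexpectedly dominating the self-time column is worth a closer look; prof.export_chrome_trace("trace.json") gives a timeline view if the table isn't enough.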
Then I remembered that PyTorch has a tuning option that benchmarks several kernel implementations for each new input shape and caches the fastest one, which you’re supposed to turn on only if your input shapes are static [3].
But ROCm has some sort of quirk that causes, well… what I just showed you. So I set torch.backends.cudnn.benchmark = False right after importing torch, and the slowdown went away!
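Roughly what that looks like, together with the kind of zero-padding collate that produces a new shape every batch. collate_variable_length is a hypothetical helper, and my understanding that on ROCm builds this flag steers MIOpen’s kernel search rather than cuDNN’s is an assumption worth checking against your PyTorch version:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Set right after importing torch. With benchmark=True the backend re-runs its
# kernel search for every new input shape, and padded batches change shape constantly.
torch.backends.cudnn.benchmark = False

def collate_variable_length(batch):
    """Zero-pad a list of (frames, channels) tensors into one (B, T_max, C) batch."""
    lengths = torch.tensor([item.shape[0] for item in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0.0)
    # True where a frame is real data, False where it is padding
    mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]
    return padded, lengths, mask

batch = [torch.randn(5, 80), torch.randn(2, 80), torch.randn(4, 80)]
padded, lengths, mask = collate_variable_length(batch)
print(padded.shape, lengths.tolist())
```

A function like this would typically be passed as collate_fn to a DataLoader; every batch then has its own T_max, which is exactly the "static shapes" assumption that benchmark mode relies on being violated.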
Even after that, training ran at roughly 50% of the speed I remember getting on an RTX 4090, which I dismissed at the time as general ROCm inefficiency… did I forget to tell you my model also has a GRU (a type of RNN) in it? That’ll be important later on.
Transformer time!
For reasons beyond the scope of this post, I decided to roll my own encoder-decoder model for TTS, with a non-autoregressive encoder and an autoregressive decoder, both pure Transformers.
When I started training that one, I noticed the speed was much closer to (dare I say, maybe on par with) an equivalent NVIDIA GPU. Since Transformers are pretty much what everyone is doing these days, I imagine the ROCm team optimized mostly for what I call effective transformer-flops.
I also tried Flash Attention: the Triton kernel worked without issue, but it only gave me some VRAM savings, so I didn’t find it worth it (note that the models I’m working with here are below 100M parameters).
RNN… time?
I tried replacing my decoder with Tacotron2-style Location Sensitive Attention, which turns it from a Transformer into an LSTM, because I wanted a good model without having to pretrain it on a massive dataset at scale.
The training slowed to a crawl. It took 3 hours and 48 minutes for the LSTM decoder model (a single 1024-dim LSTM layer) to do 10k steps, while the pure Transformer model (6 layers, 256 dim, 4x FFN expansion) could do that in 35 minutes.
Now, RNNs are sequential [4], so they’re normally slower to train than Transformers, but not 6.5x slower! I remembered training Tacotron 2 on NVIDIA hardware, and I knew that wasn’t normal.
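The 6.5x figure comes straight from the timings above; a quick check, converting everything to minutes and steps per second:

```python
STEPS = 10_000
lstm_minutes = 3 * 60 + 48   # LSTM decoder: 3 h 48 min for 10k steps
transformer_minutes = 35     # pure Transformer: 35 min for 10k steps

slowdown = lstm_minutes / transformer_minutes
lstm_steps_per_sec = STEPS / (lstm_minutes * 60)
transformer_steps_per_sec = STEPS / (transformer_minutes * 60)

print(f"slowdown: {slowdown:.2f}x")
print(f"LSTM: {lstm_steps_per_sec:.2f} steps/s, Transformer: {transformer_steps_per_sec:.2f} steps/s")
```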
Once again, why? Here’s the full Xitter convo about it, but in a nutshell: CUDA PyTorch ships with a dedicated cuDNN kernel for RNNs, while ROCm does everything naively. This means that my 3080 Ti at home, an NVIDIA gaming GPU two generations old, can beat AMD’s current-gen datacenter offering at training LSTMs until the MI300X’s superior memory bandwidth kicks in at a certain scale [5], almost a repeat of the Blender benchmark situation above.
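If you want to check this on your own hardware, here is a rough micro-benchmark sketch. It times forward+backward through nn.LSTM against a small nn.TransformerEncoder; the layer sizes mirror the ones quoted above, but the head count, batch size and sequence length are my assumptions, and time_train_steps / compare_lstm_vs_transformer are hypothetical helpers. It also does not reproduce a Tacotron-style decoder, which steps an LSTM cell once per frame and is slower still:

```python
import time
import torch
import torch.nn as nn

def time_train_steps(module: nn.Module, x: torch.Tensor, steps: int = 10, warmup: int = 3) -> float:
    """Average wall-clock seconds per forward+backward pass (hypothetical helper)."""
    def step():
        module.zero_grad(set_to_none=True)
        out = module(x)
        if isinstance(out, tuple):  # nn.LSTM returns (output, (h_n, c_n))
            out = out[0]
        out.float().mean().backward()

    for _ in range(warmup):  # let kernel selection and caching settle
        step()
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        step()
    if x.is_cuda:
        torch.cuda.synchronize()  # GPU work is async; wait before stopping the clock
    return (time.perf_counter() - start) / steps

def compare_lstm_vs_transformer(device: torch.device, batch: int = 16, frames: int = 400) -> dict:
    """Assumed sizes: 1x 1024-dim LSTM vs 6x 256-dim Transformer layers with 4x FFN, 4 heads."""
    lstm = nn.LSTM(1024, 1024, batch_first=True).to(device)
    layer = nn.TransformerEncoderLayer(256, nhead=4, dim_feedforward=1024, batch_first=True)
    encoder = nn.TransformerEncoder(layer, num_layers=6).to(device)
    return {
        "lstm_s": time_train_steps(lstm, torch.randn(batch, frames, 1024, device=device)),
        "transformer_s": time_train_steps(encoder, torch.randn(batch, frames, 256, device=device)),
    }
```

Calling compare_lstm_vs_transformer(torch.device("cuda")) on each machine and comparing the lstm_s / transformer_s ratio, rather than the absolute numbers, gives a rough apples-to-apples view across vendors.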
Journey Arc 2 - MI300X
I found a company named Hot Aisle that offered compute, so I emailed them, since I already had ROCm experience and wanted an upgrade.
They hooked me up with a single MI300X on a VM. Setting up the environment there was easy. My environment of choice is Docker + JupyterLab + PyTorch.
After pulling the Docker image of my choice, I ran this command inside a screen session:
sudo docker run -it \
--network=host \
--shm-size=8G \
--device=/dev/kfd \
--device=/dev/dri \
--group-add=video \
-v $HOME/notebooks:/workspace \
rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.6.0 \
bash
Notice the --network=host flag; it makes the container share the host’s network stack. Without it, I noticed that download speeds inside the container were very low.
Much like the 7900 XTX, the ROCm 6.4.0 image with PyTorch 2.6.0 was plug-and-play. I tested FlexAttention with torch.compile and dynamic=True, and it worked very nicely. Unlike most people, I use ALiBi for attention: I personally prefer it because it’s conceptually simple and allows extrapolation without tricks and hacks, and because I observed a slight performance increase even over RoPE + CoCA (my previous go-to). Since ALiBi is just a bias added to the attention scores, FlexAttention, which lets you express that kind of score modification while still getting a fused kernel, is very nice to have.
Surprisingly, there’s not much else for me to say here. For a second, I thought that the bad RNN performance might be exclusive to consumer GPUs or older ROCm versions, but I tried it on this one and it still didn’t do well.
Oh yeah, I forgot to mention that I also tried RWKV training with the included Triton kernels. Another pleasant surprise was that all the Triton kernels I tried were, much like PyTorch, very plug-and-play; no complaints on that front. Nice one, AMD.
GAN time!
Because I didn’t like DAC’s 8x1024 codebook, I made a mel spectrogram quantizer.
It compresses 88-channel spectrograms to one index per frame from a single 1000-entry codebook, using finite scalar quantization (FSQ). The generator is a convolutional ResNet with 3 layers in both the encoder and the decoder, and the discriminator is a 1D PatchGAN; both have 12M parameters.
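For readers who haven’t met FSQ: instead of learning a codebook, each latent dimension is squashed into a bounded range and rounded to a small fixed number of levels, and the implicit codebook is the product of those level counts. The post doesn’t say which levels were used; 8 × 5 × 5 × 5 = 1000 (the setting the FSQ paper suggests for roughly 2^10 codes) is my assumption. The sketch below is a simplified plain-Python variant that maps each dimension to levels 0 to L-1 rather than the paper’s zero-centered grid, and it omits the straight-through gradient estimator needed for training:

```python
import math
from typing import List, Sequence

LEVELS = (8, 5, 5, 5)  # assumption: 8*5*5*5 = 1000 implicit codes per frame

def codebook_size(levels: Sequence[int] = LEVELS) -> int:
    return math.prod(levels)

def quantize(z: Sequence[float], levels: Sequence[int] = LEVELS) -> List[int]:
    """Bound each latent dim with tanh, then round it to one of n levels (0..n-1)."""
    codes = []
    for value, n in zip(z, levels):
        unit = (math.tanh(value) + 1.0) / 2.0  # squash into [0, 1]
        codes.append(round(unit * (n - 1)))
    return codes

def codes_to_index(codes: Sequence[int], levels: Sequence[int] = LEVELS) -> int:
    """Mixed-radix encoding: one integer index per frame."""
    index, base = 0, 1
    for code, n in zip(codes, levels):
        index += code * base
        base *= n
    return index

def index_to_codes(index: int, levels: Sequence[int] = LEVELS) -> List[int]:
    """Inverse of codes_to_index."""
    codes = []
    for n in levels:
        codes.append(index % n)
        index //= n
    return codes

frame_latent = [0.3, -1.2, 2.0, 0.0]  # hypothetical 4-dim latent for one frame
codes = quantize(frame_latent)
print(codes, codes_to_index(codes), codebook_size())
```

In a real model the rounding would sit between a projection down to len(LEVELS) dimensions and a projection back up, with gradients passed straight through the round; the resulting integer index is what a downstream model such as the TTS Transformer would predict per frame.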
No complaints about training convnets and GANs on ROCm, although they do feel a bit slower than they should be?
So, with the MI300X I trained an 85M-parameter Transformer for the TTS task, predicting mel indices from phonemes. That model didn’t turn out very well: LibriTTS-R has a ton of long silences, and the model, being small and autoregressive, overfits to them. After filtering, only 461 hours of data were left. I’m currently scaling up.
I will also note that the MI300X’s VRAM (192 GB vs. 80 GB on the H100) is very nice to have for fitting large models and batch sizes.
Conclusion
I went in with low expectations, but found that AMD GPUs are definitely serviceable not only for inference (as others have shown), but also for training and original research, albeit with a few quirks to watch out for.
My final ratings:
Transformer training? ✅ Great
PyTorch support? ✅ Great
ConvNet & GAN training? ✔ Good
Triton support? ✅ Great
RNN training? ❌ Unusable outside of toys
SSM training (RWKV, Mamba et al.)? ✅ Great
Inference? I didn’t test.
Now there are still kinks to iron out, but AMD is definitely on the path to becoming a serious contender to NVIDIA in the field of AI — which is good, because gold and shovels.
Hopefully one day I’ll have something more than just a blog post to show for all my effort. I hate being unemployed.
[4] I don’t consider the RWKV/Mamba class of models to be RNNs, as they don’t rely on sequential training and BPTT; they’re autoregressive stateful convolutions.