For the past two days I've been staring at GPU memory profile plots. There isn't much info out there on how to read GPU memory profiles, so I wrote a script that captures one in a single line and wrote a blog post. So what do all those strange shapes mean? Quick tutorial 🧵
First, to profile any block of code, just drop this Python file into your project and use code like below to get the profiling data: github.com/sytelus/nanugp…
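If you don't have the script handy, here is a minimal sketch of the same idea using PyTorch's built-in memory snapshot hooks, `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._dump_snapshot`. This is not the script linked above: the helper name `profile_gpu_memory` and the default file name are made up for illustration, and the underscore-prefixed calls are private APIs that may change between PyTorch releases.

```python
import contextlib


@contextlib.contextmanager
def profile_gpu_memory(snapshot_path="gpu_mem_snapshot.pickle", max_entries=100_000):
    """Record CUDA allocator events for the wrapped block and dump a snapshot.

    Hypothetical helper (not the author's script). It becomes a no-op when
    PyTorch or a CUDA device is not available.
    """
    try:
        import torch
        recording = torch.cuda.is_available()
    except ImportError:
        torch, recording = None, False

    if recording:
        # Start recording allocation/free events along with stack traces.
        torch.cuda.memory._record_memory_history(max_entries=max_entries)
    try:
        yield recording
    finally:
        if recording:
            # Write the recorded history to disk, then stop recording.
            torch.cuda.memory._dump_snapshot(snapshot_path)
            torch.cuda.memory._record_memory_history(enabled=None)


# Usage sketch (train_step is a placeholder for your own code):
# with profile_gpu_memory("three_steps.pickle"):
#     for _ in range(3):
#         train_step()
```

The resulting pickle file can be dragged into pytorch.org/memory_viz to get a timeline plot of the kind discussed in the rest of this thread.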
You will get a visualization like the one below. You can immediately see 3 repeating patterns for the 3 training steps. That's a good start. The stripes indicate memory ops happening layer by layer. The constant band at the bottom is model params + optimizer states. But what are all those other shapes?
Let's take the first training step and break it down. It starts with a half triangle, which is the forward pass. That ends with the creation of the logits, which is the giant red parallelogram. The sharp spike at the start is temp memory for softmax. The logits are huge because the vocab is 151k.
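To get a feel for why the logits dominate, here is a back-of-the-envelope sketch. Only the 151k vocab comes from this thread; the batch size, sequence length and dtypes are assumptions picked for illustration, so plug in your own numbers.

```python
def logits_bytes(num_tokens: int, vocab_size: int, bytes_per_element: int) -> int:
    """Memory for a dense [num_tokens, vocab_size] logits tensor."""
    return num_tokens * vocab_size * bytes_per_element


vocab_size = 151_000   # from the thread (rounded)
batch_size = 8         # assumption, for illustration only
seq_len = 2048         # assumption, for illustration only

for dtype_name, width in (("bf16", 2), ("fp32", 4)):
    size = logits_bytes(batch_size * seq_len, vocab_size, width)
    print(f"{dtype_name} logits: {size / 2**30:.2f} GiB")
```

Every token gets one score per vocab entry, so the tensor grows linearly with the vocab size, and softmax typically needs at least one more temporary of the same shape on top of it.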
Next we call zero_grad(), which clears the old .grad allocations. At the end of zero_grad() there is a little spike and then a huge cliff. The spike is where .backward() begins at softmax, and the cliff is there because the logits tensor is no longer needed. This produces the LM head grads (green band).
In the picture above, notice the wave-like pattern on the right. This is because I've enabled activation checkpointing. Each layer recomputes its activations, which causes the little spikes, and its input tensor gets freed as .backward() progresses layer by layer.
Finally, my favorite part: optimizer.step(). It is just awesomely beautiful, like a temple with a flagpole. It sits on two foundations. The orange stripe is actually the embedding layer grad produced at the end of .backward().
In the picture above, the green stripe is the storage needed to hold the sqrt of grad^2 for Adam. Fused Adam computes the updates layer by layer, which causes the stripes. At the top it applies those updates in one shot. The spike is the temp buffer needed for the division before applying the update.
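To see where the sqrt and the division come from, here is a scalar, unfused sketch of the textbook Adam update in plain Python. It is not PyTorch's fused kernel, which works on whole tensors at once; the function name `adam_step` and the default hyperparameters are illustrative assumptions, and the comments map each intermediate value to the plot using the reading given above.

```python
import math


def adam_step(params, grads, exp_avg, exp_avg_sq, step,
              lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam step over lists of floats.

    exp_avg and exp_avg_sq are updated in place (the persistent optimizer
    states); the updated parameters are returned as a new list.
    """
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    new_params = []
    for i, (p, g) in enumerate(zip(params, grads)):
        # Persistent states: running averages of grad and grad^2.
        exp_avg[i] = beta1 * exp_avg[i] + (1 - beta1) * g
        exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * g * g
        # Temporary: sqrt of the bias-corrected grad^2 average.
        denom = math.sqrt(exp_avg_sq[i] / bias2) + eps
        # Temporary: the division that yields the update to apply.
        update = lr * (exp_avg[i] / bias1) / denom
        new_params.append(p - update)
    return new_params


# Tiny example: one parameter, first step.
m, v = [0.0], [0.0]
print(adam_step([1.0], [2.0], m, v, step=1, lr=0.1))
```

In a tensor implementation, `denom` and `update` are full-size tensors per parameter group, which is why they show up as extra memory on top of the persistent states.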
There is a lot more to say about these beautiful visualizations. Please let me know of any errors or other cool things we should know. More on my blog: How to Get and Interpret GPU Memory Profiling shital.com/blog/gpu-memor…