Graph Surgeon

Author: Ibrahim El Kaddouri

Repository: private for now; still under construction

Abstract

This project presents Graph Surgeon, a custom graph-level compiler optimisation pipeline for transformer-based language models, built on top of Apache TVM’s Relax intermediate representation. An Intermediate Representation (IR) is like a translation layer: when you compile C++ to machine code, the compiler first converts your code into an intermediate form that is easier to analyse and simplify, but is not yet CPU instructions. TVM has two IRs, Relax (graph-level) and TensorIR (loop-level); both are covered in detail below.

The goal of this project is to see whether targeted graph-level transformations, specifically QKV attention-head fusion applied to ONNX computation graphs, can produce measurable and explainable performance improvements. ONNX (Open Neural Network Exchange) is a standard file format for neural networks: a universal format that different frameworks (PyTorch, TensorFlow, etc.) can all export to and import from.

This project demonstrates that QKV fusion delivers an 11.8% speedup on CPU for DistilBERT. On GPU, the same pass has no effect or causes a slight regression. DLight, TVM’s rule-based TIR scheduling system (a set of hand-crafted rules for how to lay out loops, tiles and threads on the GPU; “rule-based” means it applies fixed strategies rather than searching for the best one), proved either neutral or harmful across all tested models.

Takeaway: A compiler pass is an optimisation in the context of a specific hardware target, a specific numerical precision and a specific IR structure. The same pass can be a meaningful win on one target and a regression on another. The hardware matters as much as the transformation.

1. Introduction

1.1 The Performance Problem in Modern LLMs

Transformer-based language models dominate NLP. Models like GPT-2, BERT and their distilled variants are deployed across text generation, question answering, code completion and more. Their computational cost is substantial: even DistilGPT-2, with 82 million parameters, requires hundreds of millions of floating-point operations per inference pass. At scale, this translates into energy, hardware cost and latency.

The conventional response is hardware: more and bigger GPUs, specialised accelerators, aggressive batching. But there is a complementary and often underexploited lever: the compiler. A sufficiently intelligent compiler, given the computation graph of a model (a diagram of every operation the network performs, where each node is one operation, such as a matrix multiply or an addition, and each edge is the tensor flowing between them), can eliminate redundant operations, fuse adjacent kernels, rewrite memory access patterns and restructure the graph in ways invisible to the model author, but which can dramatically reduce execution time.

Analogy: Why Compilers Matter. Think of following a recipe for a meal. You might see "chop 3 onions", "chop 3 garlic cloves" and "chop 3 peppers" as three separate steps somewhere in the recipe. A smart cook (obviously you) might notice that all three happen at the cutting board. You then set up the cutting board once and chop everything in one go, instead of setting up a new cutting board for each step in the recipe. The results are identical, just with less wasted time. That is what a compiler pass does to a neural network's computation graph.

1.2 Why Compiler Optimisations, Not Model Compression?

There is a rich literature on making models smaller and faster through model-level techniques: pruning (removing weights or entire neurons that contribute little to the output, like removing irrelevant words from a sentence; a pruned model has fewer parameters and runs faster, but may lose some accuracy), quantisation (replacing 32-bit floating-point weights with lower-precision representations such as float16 or int8, which reduces memory and increases throughput at the cost of some numerical precision), knowledge distillation (training a smaller “student” model to mimic the behaviour of a larger “teacher” model; DistilBERT and DistilGPT-2, used in this project, were created this way), and low-rank factorisation. These are valid, but they change the model.

Compiler optimisation is semantics-preserving. The output of the optimised model is bit-for-bit identical to the baseline (for structural passes) or numerically equivalent within a tolerance (for approximation passes). A compiler pass can be applied to any model without retraining, and it is orthogonal to compression: you can compress and compiler-optimise the same model.

Key distinction: Compression changes what the model is. Compilation changes how it runs.

2. Background

2.1 The Transformer Architecture

The transformer was introduced by Vaswani et al. (2017)[1]; the key innovation was the self-attention mechanism, which lets each token look at every other token in the input and decide how much to weight each one when building its own representation (see Jurafsky & Martin, 2026, §8.1)[12]. We’ll go step by step, following the ONNX graph of a DistilBERT transformer model, to explain how it works in detail.

Inputs
The model takes two integer tensors of shape [1, seq_len]. The first tensor, input_ids, is a sequence of vocabulary indices. Each integer maps to one token in the model’s vocabulary. The second tensor, attention_mask, is a binary tensor where 1 marks a real token and 0 marks padding.

Figure 1
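For concreteness, here is a minimal sketch of producing these two tensors with a Hugging Face tokenizer; the distilbert-base-uncased checkpoint and the max_length=128 padding are assumptions matching the static shapes used later in this project:

from transformers import AutoTokenizer

# Assumed checkpoint; any DistilBERT tokenizer produces the same two tensors.
tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
enc = tok("Graph compilers are fun.",
          padding="max_length", max_length=128,
          truncation=True, return_tensors="np")

print(enc["input_ids"].shape)       # (1, 128): vocabulary indices
print(enc["attention_mask"].shape)  # (1, 128): 1 = real token, 0 = padding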


Token embedding lookup
The model holds a learned weight matrix wte of shape [30522, 768], one 768-dimensional float vector per vocabulary token. A Gather op uses input_ids as row indices and retrieves the corresponding vectors, producing a float tensor of shape [1, seq_len, 768]. These vectors encode semantic meaning. They were trained to place conceptually similar tokens near each other in that 768-dimensional space.[12]

Figure 2


Position embedding lookup
Because transformers process all tokens simultaneously and have no built-in notion of order, position must be injected explicitly.[12] A second learned matrix, wpe, of shape [512, 768], stores one vector per possible position (0 through 511). DistilBERT’s context window is 512 tokens, a hard architectural limit: inputs longer than 512 tokens must be truncated or chunked before they reach the model, which is why seq_len ≤ 512. A Gather op with an indices attribute specifying the index sequence [0, 1, 2, ..., seq_len−1] retrieves the corresponding rows, giving another [1, seq_len, 768] tensor.[13]

Figure 3


Embedding addition
The token and position tensors are added element-wise to form the initial hidden state h₀ of shape [1, seq_len, 768]. Each position in this tensor blends what the token means with where it sits in the sequence.[12] This is the input to the first transformer block.

Figure 4
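A small numpy sketch of the three embedding steps above (token Gather, position Gather, element-wise add); the random matrices stand in for the learned wte and wpe weights:

import numpy as np

seq_len = 128
rng = np.random.default_rng(0)
wte = rng.standard_normal((30522, 768), dtype=np.float32)  # stand-in token embedding table
wpe = rng.standard_normal((512, 768), dtype=np.float32)    # stand-in position embedding table

input_ids = rng.integers(0, 30522, size=(1, seq_len))      # stand-in token ids

tok_emb = wte[input_ids]                   # Gather: [1, seq_len, 768]
pos_emb = wpe[np.arange(seq_len)][None]    # Gather with indices 0..seq_len-1: [1, seq_len, 768]
h0 = tok_emb + pos_emb                     # element-wise add: the initial hidden state

print(h0.shape)  # (1, 128, 768)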


Six transformer blocks
The following structure runs six times with separate learned weights at each layer.

Figure 5


Self Attention Mechanism
We will now describe the workings of one of those 6 blocks. As we will see, DistilBERT uses post-norm ordering: where the Jurafsky & Martin textbook (2026, §8.2, Fig. 8.7) describes the more common modern pre-norm architecture (LayerNorm before attention/FFN), DistilBERT follows the original Vaswani et al. (2017) post-norm design (LayerNorm after the residual addition), which was also used in BERT.

Q, K, V projections multiply the normalised hidden state against three separate weight matrices, each [768, 768], to produce Query, Key and Value tensors. These correspond to the three roles described in Jurafsky & Martin (2026, §8.1.1): the query is the current element being compared, the key is a token being compared against, and the value is the content that gets weighted and summed. Each is computed as \(q = xW^{Q}\), \(k = xW^{K}\), \(v = xW^{V}\) (Eq. 8.9).

Each result is reshaped from [1, seq_len, 768] to [1, seq_len, 12, 64], splitting the 768 dimensions across 12 attention heads of 64 dimensions each, then transposed to [1, 12, seq_len, 64] so each head’s data is contiguous in memory.

Figure 6
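The same reshape and transpose, sketched in numpy for a single stand-in Q tensor:

import numpy as np

seq_len = 128
q = np.zeros((1, seq_len, 768), dtype=np.float32)  # stand-in output of the Q projection

q_heads = q.reshape(1, seq_len, 12, 64)    # split 768 into 12 heads of 64 dims
q_heads = q_heads.transpose(0, 2, 1, 3)    # [1, 12, seq_len, 64]: each head contiguous

print(q_heads.shape)  # (1, 12, 128, 64)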


Scaled dot-product attention runs independently per head. Q · Kᵀ produces a [seq_len, seq_len] score matrix measuring how much each token attends to every other; this implements Eq. 8.11 from Jurafsky & Martin (2026, §8.1.1), \(\text{score}(x_i, x_j) = (q_i \cdot k_j)/\sqrt{d_k}\), whose full parallel form is \(QK^{T}/\sqrt{d_k}\) (Eq. 8.33). Dividing by √64 = 8 prevents the scores from growing too large.

Afterwards, the attention_mask is applied. Positions where the mask is 0 (i.e. padding tokens) receive a large negative value so that softmax drives those weights to zero. Because DistilBERT is a bidirectional encoder, every non-padding token can attend to every other non-padding token; there is no causal mask blocking future positions. (Unlike the causal, left-to-right decoder described in Jurafsky & Martin (2026, §8.1), DistilBERT is an encoder model in the BERT family: encoders let every token attend to every other token in the sequence, not just to preceding tokens, so the causal masking of future positions from §8.3, Fig. 8.10 is not used here. Jurafsky & Martin introduce bidirectional encoder models in Chapter 9.) Softmax is applied row-wise to normalise the scores into weights. Finally, multiplying by V yields a weighted sum of value vectors per position with output shape [1, 12, seq_len, 64].

Figure 7
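A per-head numpy sketch of these steps (scores, scaling, padding mask, row-wise softmax, weighted sum); treating the last 28 of 128 positions as padding is purely illustrative:

import numpy as np

seq_len, d_k = 128, 64
rng = np.random.default_rng(0)
Q = rng.standard_normal((12, seq_len, d_k)).astype(np.float32)  # 12 heads, batch dim dropped
K = rng.standard_normal((12, seq_len, d_k)).astype(np.float32)
V = rng.standard_normal((12, seq_len, d_k)).astype(np.float32)

attention_mask = np.ones(seq_len)
attention_mask[100:] = 0                                        # pretend the last 28 tokens are padding

scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)                # [12, seq_len, seq_len]
scores = np.where(attention_mask[None, None, :] == 1, scores, -1e9)  # mask padded key positions

weights = np.exp(scores - scores.max(axis=-1, keepdims=True))   # row-wise softmax
weights /= weights.sum(axis=-1, keepdims=True)

out = weights @ V                                               # [12, seq_len, 64]
print(out.shape)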


Output projection, residual, and Layer Norm 1. The 12 head outputs are concatenated back to [1, seq_len, 768] via reshape and transpose, then passed through a linear layer of shape [768, 768] that mixes information across heads.[12] A residual connection adds the hidden state from before the Q/K/V projections back to the attention output; the skip path helps gradients flow and means the attention sub-layer only needs to learn a correction to the input, not a full new representation (in the residual stream view of Elhage et al., 2021 and Jurafsky & Martin, 2026, §8.2, information flows up through the stream and each component adds to it). A normalisation step is then applied to the result: \(\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sigma + \varepsilon} + \beta\), where γ and β are learned parameters of size 768.[12]

Figure 8
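The layer-norm formula above, written out as a small numpy function; γ = 1 and β = 0 are used here only for illustration:

import numpy as np

def layer_norm(x, gamma, beta, eps=1e-12):
    # normalise each 768-dim vector to zero mean and unit std, then scale and shift
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return gamma * (x - mu) / (sigma + eps) + beta

x = np.random.randn(1, 128, 768).astype(np.float32)
y = layer_norm(x,
               gamma=np.ones(768, dtype=np.float32),
               beta=np.zeros(768, dtype=np.float32))
print(y.shape)  # (1, 128, 768)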


FFN, residual, and Layer Norm 2. A two-layer feed-forward network takes the post-LN1 state as input. This is the position-wise FFN described in Jurafsky & Martin (2026, §8.2, Eq. 8.21): it applies the same weights independently to each token position, and, as the textbook notes, the FFN hidden dimension is commonly larger than the model dimension (d=512, dff=2048 in the original transformer; here d=768, dff=3072, the same 4× factor). The first linear layer expands from 768 to 3072 dimensions, a GELU non-linearity is applied element-wise, and the second layer contracts back to 768. GELU (Gaussian Error Linear Unit) is a smooth, non-monotonic activation defined as GELU(x) = x · Φ(x), where Φ(x) is the standard normal CDF; unlike ReLU, which hard-gates at zero, GELU weights each input by the probability that it is positive under a Gaussian, producing a soft gating effect, and it is the default activation in BERT, GPT-2 and most modern transformers (Hendrycks & Gimpel, 2016). A second residual connection adds the post-LN1 state to the FFN output, and a second layer normalisation normalises the sum. This produces the final output of the block.

Figure 9
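A compact numpy sketch of this FFN sub-layer; scipy's normal CDF stands in for Φ, and the random matrices stand in for the learned 768→3072 and 3072→768 weights:

import numpy as np
from scipy.stats import norm   # Φ, the standard normal CDF

def gelu(x):
    return x * norm.cdf(x)     # GELU(x) = x · Φ(x)

rng = np.random.default_rng(0)
W1 = rng.standard_normal((768, 3072)).astype(np.float32)   # expand: 768 -> 3072
W2 = rng.standard_normal((3072, 768)).astype(np.float32)   # contract: 3072 -> 768

h = rng.standard_normal((1, 128, 768)).astype(np.float32)  # post-LN1 state
out = h + gelu(h @ W1) @ W2                                 # FFN plus second residual
print(out.shape)  # (1, 128, 768); layer norm 2 would follow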


Output
The ONNX graph ends at the last hidden state of shape [1, seq_len, 768]. At inference time, the hidden state goes through a masked LM head and softmax to obtain a probability distribution over all 30,522 tokens. Unlike the autoregressive (left-to-right) language modelling described in Jurafsky & Martin (2026, §8.5), DistilBERT uses masked language modelling (MLM): some input tokens are replaced with a special [MASK] token, and the model predicts what word belongs in each masked position (the textbook introduces this approach in Chapter 9 on BERT and bidirectional encoders). For masked language modelling, the predictions at [MASK] positions are read off to recover the hidden words. For downstream tasks such as classification, only the representation at a special [CLS] token position is used.

Figure 11


What the compiler sees The full graph for one forward pass contains hundreds of nodes. The dominant cost is the 6 × 3 = 18 QKV projection MatMuls and the 6 × 2 = 12 FFN MatMuls. Everything else (the Gathers, LayerNorms, Softmaxes, reshapes) is comparatively cheap. The compiler's job is to reduce the cost of those expensive nodes without changing the output.
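As a rough sanity check on that claim, here is a back-of-the-envelope FLOP count for just those MatMuls at this project's static shapes (batch 1, seq_len 128), counting each multiply-add as 2 FLOPs; it is only a sketch and ignores the output projections and all the cheap ops:

# Rough FLOP count for the dominant MatMuls only (2 FLOPs per multiply-add).
seq, d, d_ff, layers = 128, 768, 3072, 6

qkv = layers * 3 * (2 * seq * d * d)       # 18 QKV projection MatMuls
ffn = layers * 2 * (2 * seq * d * d_ff)    # 12 FFN MatMuls (expand + contract)

print(f"QKV: {qkv / 1e9:.2f} GFLOPs")            # ~2.72 GFLOPs
print(f"FFN: {ffn / 1e9:.2f} GFLOPs")            # ~7.25 GFLOPs
print(f"total: {(qkv + ffn) / 1e9:.2f} GFLOPs")  # ~9.97 GFLOPs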

2.2 What is Apache TVM?

Apache TVM is a machine learning compiler framework. It takes a neural network (from any training framework) and produces fast, deployable code for any hardware target.

What is an IRModule?

The official docs say an IRModule “encompasses the entirety of the ML models, incorporating the computational graph, tensor programs, and potential calls to external libraries.” That’s… not very helpful.

When a compiler (any compiler, C++, Java, ...) transforms code, it can't jump directly from "human-readable source" to "machine code" in one step. There's too much to do. Instead, it converts through one or more Intermediate Representations (IRs): simplified, structured forms of the program that are easier to analyse and transform.

A classic example: GCC compiles C code to GIMPLE (a simplified form of C), then to RTL (Register Transfer Language, closer to assembly), then to machine code. Each IR is progressively lower-level.

TVM has two IRs:

  • Relax: high-level, describes the computation graph (which operations, in what order, on what-shaped tensors)
  • TensorIR (TIR): low-level, describes exactly how to implement each operation (loops, memory buffers, thread assignment)

An IRModule is simply a container, like a Python module (a .py file), that holds all the functions needed to run a model, written in TVM’s IRs. That’s it. The IRModule is a data structure that describes the model. Passes transform this data structure. When you call relax.build(), TVM reads the IRModule and generates actual executable code (LLVM bitcode or CUDA PTX).

Grand Scheme Of Things

ONNX import: Gemm node → R.matmul
A Relax op, high-level, no implementation yet

After LegalizeOps: R.matmul → T.prim_func
A TIR loop nest, now you have an actual implementation

After codegen: TIR → LLVM IR → native x86/CUDA PTX

Graph Surgeon's passes happen at the first stage: they restructure the Relax graph before any TIR exists. This is why they're called "graph-level" passes.

2.3 Relax: The Graph Level

Relax is TVM’s high-level IR for representing the computation graph of a model. It is a functional, dataflow-oriented language. Here’s what a simple two-layer MLP looks like in Relax:

Relax

@R.function
def main(
    x:      R.Tensor((1, 784),   dtype="float32"),
    weight: R.Tensor((784, 256), dtype="float32"),
    bias:   R.Tensor((256,),     dtype="float32"),
) -> R.Tensor((1, 256), dtype="float32"):
    with R.dataflow():              # "safe zone" for optimisation
        lv0 = R.matmul(x, weight)
        lv1 = R.add(lv0, bias)
        gv  = R.nn.relu(lv1)
        R.output(gv)
    return gv

Notice that Relax does not say how to implement R.matmul. No loops, no thread assignments, no tiling. The dataflow block (with R.dataflow()) is a hint to the compiler that this region is purely functional, no side effects, so it can safely reorder, fuse or eliminate operations.

In Relax, each line in a dataflow block is a binding: a name assigned to the result of one operation. lv0 = R.matmul(x, weight) means: "compute the matmul, name the result lv0."

Compiler passes traverse these bindings. The QKV fusion pass, for example, uses ExprMutator, a TVM API that walks through every binding in the dataflow block. When it finds three consecutive matmul bindings that all take the same input variable, it knows it has found a Q, K, V triple and can merge them.

This is fundamentally similar to how static analysis tools work in software engineering: they traverse an AST (abstract syntax tree) looking for patterns and transforming them. Relax's dataflow graph is TVM's equivalent of an AST for ML models.

The @R.function decorator does the same trick as @T.prim_func below. It never executes the function as Python. Instead, TVM captures it into a tree of IR objects. The line lv0 = R.matmul(x, weight) becomes a relax.VarBinding whose value is a relax.Call node:

# What Python "sees" at runtime after the decorator runs:
VarBinding(
    var   = DataflowVar("lv0"),
    value = Call(op="relax.matmul", args=[Var("x"), Var("weight")])
)

Pattern-matching in a pass is literally just checking this tree. In qkv_fusion.py, detecting a matmul binding is:

def _match_matmul(self, expr):
    # check if node is Call to "relax.matmul" op
    if not (isinstance(expr, relax.Call) and
            expr.op == tvm.ir.Op.get("relax.matmul")):
        return None
    x, w = expr.args[0], expr.args[1]
    ...

And rewriting the graph is rebuilding the binding list with new Call nodes in place of the old ones.

2.4 TensorIR: The Kernel Level

Below Relax sits TensorIR (TIR), TVM’s low-level IR for individual tensor programs. Where Relax says what to compute, TIR specifies how: the loop nests, memory layout, thread bindings and tiling decisions.

TensorIR

@T.prim_func
def matmul(A: T.Buffer((128, 768), "float32"),
           B: T.Buffer((768, 768), "float32"),
           C: T.Buffer((128, 768), "float32")):
    for i, j, k in T.grid(128, 768, 768):
        with T.block("matmul"):
            # SSR = Spatial, Spatial, Reduction
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

Like the @R.function decorator, @T.prim_func tells TVM not to run this as Python, but to parse it into IR objects. Each buffer is declared with an explicit shape and dtype, not as a hint, but as a hard constraint that lets later passes compute memory strides and choose vectorisation widths without guessing. C is the output buffer, written in place; there is no return statement, which is the convention for all prim_func kernels.[16]

T.grid(128, 768, 768) is syntactic sugar for three nested loops, defining 128 × 768 × 768 = 75,497,472 iterations. The T.block wraps the actual computation into a self-contained unit, and the "SSR" annotation on the axes is the crucial part. It tells the compiler that vi and vj are spatial (independent, safe to parallelise and tile freely) while vk is a reduction (all k iterations write to the same accumulator, so it cannot be freely reordered without an explicit parallel reduction pattern). The T.init() block cleanly separates the accumulator reset from the accumulation step, letting the compiler hoist or fuse the initialisation correctly when the schedule is transformed.[16]

@T.prim_func declaring a primitive tensor function. A primitive tensor function is the smallest schedulable grain of work in TVM: one matmul, one convolution, one fused conv+relu. Marking a function with this decorator tells TVM to parse it into IR, type-check it, and make it available for transformation and code generation.[14]

T.axis.remap("SSR", [i, j, k]) This shorthand declares three block axes in one line, equivalent to writing: [16]

vi = T.axis.spatial(128, i)   # S independent, safe to parallelise
vj = T.axis.spatial(768, j)   # S independent, safe to parallelise
vk = T.axis.reduce (768, k)   # R accumulates into C[vi,vj]

Each declaration conveys three things: the binding of the block variable to its outer loop variable, the range of valid indices and the axis type. The spatial axes can be freely tiled, vectorised and mapped to GPU threads without data races. The reduce axis vk contributes to the same accumulator for all k values and can only be parallelised with an explicit parallel reduction pattern such as a tree-reduce.[16]

T.init() The with T.init() sub-block formally declares the initialisation action for the reduction.[16] It makes reduction semantics unambiguous and allows the compiler to hoist this initialisation correctly when the loop is later split or vectorised: it only needs to fire once per (vi, vj) pair, not on every k iteration.
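To make the spatial/reduction distinction concrete, here is a minimal sketch of scheduling the matmul above with TVM's tir.Schedule API. The schedule primitives (split, parallel, vectorize) are real TVM APIs; the split factor of 8 is an arbitrary illustration value, not a tuned choice:

import tvm
from tvm.script import tir as T

@T.prim_func
def matmul(A: T.Buffer((128, 768), "float32"),
           B: T.Buffer((768, 768), "float32"),
           C: T.Buffer((128, 768), "float32")):
    for i, j, k in T.grid(128, 768, 768):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

sch = tvm.tir.Schedule(matmul)
block = sch.get_block("matmul")
i, j, k = sch.get_loops(block)

# Spatial axes can be tiled, parallelised and vectorised freely...
j_outer, j_inner = sch.split(j, factors=[None, 8])
sch.parallel(i)           # spread rows across CPU threads
sch.vectorize(j_inner)    # 8-wide SIMD over columns

# ...the reduction axis k is left untouched: reordering it would need
# an explicit parallel reduction pattern.
sch.mod.show()            # print the transformed TIR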

2.5 The Compilation Pipeline

TVM’s compilation is a sequence of transformation passes applied to the IRModule. The stages, in order:

  1. ONNX Import: from_onnx() → Relax IRModule
  2. Graph Surgeon passes: QKV fusion, FuseTransposeMatmul
  3. Normalize + FoldConstant + DCE: cleanup passes that fold constants and remove dead code
  4. LegalizeOps: Relax ops → TIR implementations
  5. FuseOps + FuseTIR: adjacent TIR functions fused into single kernels
  6. relax.build(): TIR → LLVM / CUDA PTX → executable
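For orientation, a hedged sketch of what this pipeline looks like as Python. The pass names are built-in Relax passes; the ONNX file name, the shape dictionary and the commented-out Graph Surgeon pass are placeholders, since the project's own code is not public yet:

import onnx
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

# 1. ONNX import (file name and shape dict are placeholders)
onnx_model = onnx.load("distilbert.onnx")
mod = from_onnx(onnx_model,
                shape_dict={"input_ids": [1, 128], "attention_mask": [1, 128]})

# 2. Graph Surgeon passes would run here (e.g. the QKV fusion pass of section 3.2)
# mod = QKVFusionPass()(mod)   # hypothetical name, project code not public

# 3-5. Cleanup, legalisation and fusion (built-in Relax passes)
seq = tvm.transform.Sequential([
    relax.transform.Normalize(),
    relax.transform.FoldConstant(),
    relax.transform.DeadCodeElimination(),
    relax.transform.LegalizeOps(),
    relax.transform.FuseOps(),
    relax.transform.FuseTIR(),
])
mod = seq(mod)

# 6. Build for CPU and wrap in the Relax virtual machine
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())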

2.6 The ONNX Format

ONNX (Open Neural Network Exchange)[6] is a standard, framework-agnostic file format for neural networks.

💡 Why go through ONNX at all?

TVM can import from PyTorch directly via torch.export, but ONNX provides a more stable, explicit graph. ONNX explicitly lays out every operation as a typed node in a directed acyclic graph, making it easier to pattern-match and restructure.

The project uses ONNX opset 18 with static shapes (batch=1, seq=128). Static shapes are important for Graph Surgeon: the QKV fusion pass needs to know the exact dimensions of weight matrices at compile time to perform the concatenation. With dynamic shapes, you’d need to defer this to runtime.
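For reference, a sketch of how such a static-shape ONNX export could be produced with the standard PyTorch exporter (the checkpoint name and output path are assumptions; the opset and fixed 1×128 shapes match the settings described above):

import torch
from transformers import AutoModel

# Assumed checkpoint; any DistilBERT variant with the same graph structure works.
model = AutoModel.from_pretrained("distilbert-base-uncased")
model.eval()

input_ids = torch.ones(1, 128, dtype=torch.long)
attention_mask = torch.ones(1, 128, dtype=torch.long)

torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "distilbert.onnx",
    opset_version=18,
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    # no dynamic_axes: every shape is baked in as batch=1, seq=128
)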

3. Technical Approach

3.1 Target Models

Model | Type | Params | Layers | Why it's interesting
DistilGPT-2 | LLM (decoder) | 82M | 6 | Causal language model; QKV is not implemented as separate MatMuls in ONNX
DistilBERT | LLM (encoder) | 66M | 6 | Bidirectional encoder; QKV is implemented as separate MatMuls in ONNX

Encoder (BERT-style): Every token can attend to every other token ("bidirectional" attention). Great for tasks where you process a whole sentence and need to understand it (classification, NER, question answering). Processes the whole input at once.

Decoder (GPT-style): Each token can only attend to previous tokens ("causal", left-to-right attention), enforced by masking future positions in the attention score matrix. Designed for generation: you predict one token at a time, left to right. GPT-2 and ChatGPT use this architecture.

3.2 The QKV Fusion Pass

The most important pass in this project. Here’s what happens step by step:

1. Graph traversal: The pass uses TVM's ExprMutator API to walk through each DataflowBlock in the main Relax function, visiting bindings one at a time.
2. Pattern matching: For each group of three matmul+bias-add subgraphs that share a common input variable, the pass identifies a Q, K, V triple.
3. Weight concatenation: The three weight matrices [768×768], [768×768], [768×768] are concatenated via numpy.concatenate into a single [768×2304] matrix. Biases too. This happens once, at compile time, not at every inference.
4. Graph rewrite: The three original bindings are replaced with one fused matmul (768×2304), followed by three relax.op.strided_slice calls to extract Q, K, V from the result.
5. Constant folding: A follow-up FoldConstant() pass evaluates the concatenated weights, materialising them into the graph. No runtime overhead from the concatenation.
✂️ QKV Fusion: before and after

Before fusion, Relax graph nodes:

lv_q  = R.matmul(x, W_Q)
lv_q2 = R.add(lv_q, b_Q)
lv_k  = R.matmul(x, W_K)
lv_k2 = R.add(lv_k, b_K)
lv_v  = R.matmul(x, W_V)
lv_v2 = R.add(lv_v, b_V)

6 ops, 3 reads of X from memory.

Before fusion, kernel launches (GPU):

kernel_matmul_Q(X, W_Q)   reads X: ~393 KB
kernel_matmul_K(X, W_K)   reads X: ~393 KB again
kernel_matmul_V(X, W_V)   reads X: ~393 KB again
Total X reads: ~1.2 MB

After fusion, Relax graph nodes:

lv_qkv  = R.matmul(x, W_QKV)
lv_qkv2 = R.add(lv_qkv, b_QKV)
Q = R.strided_slice(lv_qkv2, [0,0,0],  [b,s,d])
K = R.strided_slice(lv_qkv2, [0,0,d],  [b,s,2d])
V = R.strided_slice(lv_qkv2, [0,0,2d], [b,s,3d])

5 ops, 1 read of X (slices are near-free).

After fusion, kernel launches (GPU):

kernel_matmul_QKV(X, W_QKV)   reads X once: ~393 KB (W_QKV is 768×2304, 3× wider)
slice_Q, slice_K, slice_V
Total X reads: ~393 KB, saving roughly two extra reads of X (~786 KB) per attention layer.
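A tiny numpy check of why the rewrite is semantics-preserving: concatenating the weights and slicing the fused output reproduces Q, K and V within floating-point tolerance. The shapes match the pass; the random data is purely illustrative:

import numpy as np

d = 768
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 128, d)).astype(np.float32)
W_q, W_k, W_v = (rng.standard_normal((d, d)).astype(np.float32) for _ in range(3))
b_q, b_k, b_v = (rng.standard_normal(d).astype(np.float32) for _ in range(3))

# Before fusion: three matmuls, three reads of x
Q, K, V = x @ W_q + b_q, x @ W_k + b_k, x @ W_v + b_v

# After fusion: weights concatenated once (at compile time), one wide matmul at runtime
W_qkv = np.concatenate([W_q, W_k, W_v], axis=1)   # [768, 2304]
b_qkv = np.concatenate([b_q, b_k, b_v])           # [2304]
fused = x @ W_qkv + b_qkv                         # a single read of x

# strided_slice equivalents
Q2, K2, V2 = fused[..., :d], fused[..., d:2 * d], fused[..., 2 * d:]

for before, after in [(Q, Q2), (K, K2), (V, V2)]:
    assert np.allclose(before, after, atol=1e-3)
print("fused and unfused outputs match")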

3.3 FuseTransposeMatmul Pass

FuseTransposeMatmul is a built-in TVM pass. In ONNX-exported transformers, some operations produce explicit Transpose nodes before a Matmul. This pass folds those transpose nodes into the matmul’s transA/transB flags, eliminating a memcpy on CPU, where explicit transpositions have real cost. On GPU, cuBLAS handles CUBLAS_OP_T natively as a flag, so this pass eliminates a Relax node that was never a real cost to begin with.

4. Results

  • GPU: NVIDIA Turing (sm_75), CUDA.
  • CPU: x86-64.
  • Benchmarks: 200 warmup + 200 timed iterations.
  • Statistical significance: Welch t-test at p < 0.01.

4.1 GPU Results

Model | Pipeline | Mean (ms) | Std (ms) | Speedup | p<0.01?
DistilGPT-2 | o2_baseline | 27.03 | 1.550 | - | -
DistilGPT-2 | llm_fuse_transpose | 27.54 | 0.005 | 0.981× |
DistilGPT-2 | llm_qkv | 27.54 | 0.001 | 0.982× |
DistilGPT-2 | llm_dlight | 27.71 | 0.053 | 0.976× |
DistilBERT | o2_baseline | 14.78 | 0.004 | - | -
DistilBERT | llm_fuse_transpose | 14.79 | 0.005 | 0.999× |
DistilBERT | llm_qkv | 14.92 | 0.004 | 0.990× |
DistilBERT | llm_dlight | 14.77 | 0.041 | 1.001× |
Findings: GPU LLMs
Findings: GPU LLMs

No graph-level pass improved transformer inference on GPU. QKV fusion, transpose elimination, and DLight TIR scheduling all resulted in no change or slight regressions.

4.2 CPU Results

Model | Pipeline | Mean (ms) | Speedup | p<0.01?
DistilGPT-2 | o2_baseline | 20,282 | - | -
DistilGPT-2 | llm_qkv | 19,657 | 1.032× |
DistilGPT-2 | llm_all | 19,732 | 1.028× |
DistilBERT | o2_baseline | 9,293 | - | -
DistilBERT | llm_fuse_transpose | 8,998 | 1.033× |
DistilBERT | llm_qkv | 8,309 | 1.118× |
⭐ Findings: CPU LLMs

QKV fusion delivers +11.8% on CPU for DistilBERT. The single largest improvement in the entire project. The pass was designed for CPU cache behaviour, and it works exactly where it was designed to work. FuseTransposeMatmul also helps (+3.3%), confirming that explicit transpose nodes carry real cost on x86, unlike GPU, where cuBLAS handles transposed inputs natively.

Important Caveat

The CPU runs fell back to -mcpu=generic due to an LLVM 22.1.2 compatibility issue with -mcpu=native. This means no AVX2 and no FMA: the CPU results are without SIMD vectorisation. The 11.8% QKV speedup would likely be even larger with AVX2 enabled, as the fused matmul's wider inner loop can better exploit SIMD vector units (which process 8 floats in parallel).

5. Conclusion and Future Work

By concatenating \(W_{Q}, W_{K}, W_{V}\) into \(W_{QKV}\) and doing one wide matmul, we load the input tensor X from memory once instead of three times. On CPU, where memory bandwidth is tight and cache is small, that's a real win, hence the 11.8%. On GPU, cuBLAS is already so good at scheduling individual matmuls (and the L2 cache is large enough to keep X hot across launches) that the saving evaporates. The slight regression we see on GPU is probably because the 768×2304 matmul is a less "canonical" shape than three 768×768 ones, and cuBLAS's internal heuristics are tuned for the latter.

The real bottleneck is the attention score matrix. The QKV projections are compute-bound matmuls: the hardware is doing useful arithmetic almost every cycle. The expensive part that graph-level fusion can't touch is the attention computation itself, where Q · Kᵀ produces a [seq_len, seq_len] score matrix. The correct solution to this is FlashAttention (Dao et al., 2022),[17] which fuses Q · Kᵀ, scaling, masking, softmax and softmax · V into a single tiled CUDA kernel.

References

  1. Vaswani, A. et al. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems, 30. arxiv.org/abs/1706.03762
  2. Chen, T. et al. (2018). TVM: An Automated End-to-End Optimizing Compiler for Deep Learning. arxiv.org/abs/1802.04799
  3. Apache TVM Documentation. Overview. tvm.apache.org/docs/get_started/overview.html
  4. Apache TVM Documentation. IRModule. tvm.apache.org/docs/get_started/tutorials/ir_module.html
  5. Apache TVM Documentation. Relax: Graph Abstraction for ML Models. tvm.apache.org/docs/deep_dive/relax/abstraction.html
  6. ONNX Project. ONNX: Open Neural Network Exchange. onnx.ai/onnx/intro/concepts.html
  7. Shao, J. et al. (2022). Tensor Program Optimization with Probabilistic Programs. NeurIPS 2022. proceedings.neurips.cc
  8. Apache TVM Documentation. MetaSchedule: Search-Based Auto-Tuning. tvm.apache.org/docs/.../meta_schedule.html
  9. Apache TVM Documentation. Optimize Large Language Model. tvm.apache.org/docs/.../optimize_llm.html
  10. Li, Y. et al. (2021). A Short Study on Compressing Decoder-Based Language Models. arxiv.org/abs/2110.08460
  11. Apache TVM Documentation. Customize Optimization. tvm.apache.org/docs/.../customize_opt.html
  12. Jurafsky, D. & Martin, J. H. Speech and Language Processing. Draft of January 6, 2026. web.stanford.edu/~jurafsky/slp3/ It's a very good book!
  13. ONNX Operator Reference. Gather. onnx.ai/onnx/operators/onnx__Gather.html
  14. Apache TVM Documentation. Tensor Program Abstraction. tvm.apache.org/docs/deep_dive/tensor_ir/abstraction.html
  15. Apache TVM Documentation. TensorIR Overview. tvm.apache.org/docs/deep_dive/tensor_ir/index.html
  16. Apache TVM Documentation. Understand TensorIR Abstraction. tvm.apache.org/docs/deep_dive/tensor_ir/learning.html
  17. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Advances in Neural Information Processing Systems, 35. arxiv.org/abs/2205.14135
  18. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. International Conference on Learning Representations (ICLR 2024). arxiv.org/abs/2307.08691