Author: Ibrahim El Kaddouri
Repository: still private, under construction
Abstract
This project presents Graph Surgeon, a custom graph-level compiler optimisation pipeline for transformer-based language models, built on top of Apache TVM's Relax intermediate representation. An Intermediate Representation (IR) is a translation layer: when you compile C++ to machine code, the compiler first converts your code into an intermediate form that is easier to analyse and simplify, but is not yet CPU instructions. TVM has two IRs, Relax (graph-level) and TensorIR (loop-level); both are covered below.
The goal of this project is to see whether targeted graph-level transformations, specifically QKV attention-head fusion applied to ONNX (Open Neural Network Exchange: a standard file format for neural networks that frameworks such as PyTorch and TensorFlow can all export to and import from) computation graphs, can produce measurable and explainable performance improvements.
This project demonstrates that QKV fusion delivers an 11.8% speedup on CPU for DistilBERT. On GPU, the same pass has no effect or causes a slight regression. DLight (TVM's rule-based TIR scheduling system: a set of hand-crafted rules for how to lay out loops, tiles and threads on the GPU, where "rule-based" means it applies fixed strategies rather than searching for the best one) proved either neutral or harmful across all tested models.
1. Introduction
1.1 The Performance Problem in Modern LLMs
Transformer-based language models dominate NLP. Models like GPT-2, BERT and their distilled variants are deployed across text generation, question answering, code completion and more. Their computational cost is substantial: even DistilGPT-2, with 82 million parameters, requires hundreds of millions of floating-point operations per inference pass. At scale, this translates into energy, hardware cost and latency.
The conventional response is hardware: more and bigger GPUs, specialised accelerators, aggressive batching. But there is a complementary and often underexploited lever: the compiler. A sufficiently intelligent compiler, given the computation graph of a model (a diagram of every operation a neural network performs, where each node is one operation, such as a matrix multiply or an addition, and each edge is the tensor flowing between them), can eliminate redundant operations, fuse adjacent kernels, rewrite memory access patterns and restructure the graph in ways invisible to the model author, but which can dramatically reduce execution time.
1.2 Why Compiler Optimisations, Not Model Compression?
There is a rich literature on making models smaller and faster through model-level techniques: pruning (removing weights, or entire neurons, that contribute little to the output, like removing irrelevant words from a sentence; a pruned model has fewer parameters and runs faster, but may lose some accuracy), quantisation (replacing 32-bit floating-point weights with lower-precision representations such as float16 or int8, which reduces memory and increases throughput at the cost of some numerical precision), knowledge distillation (training a smaller "student" model to mimic the behaviour of a larger "teacher" model; DistilBERT and DistilGPT-2, used in this project, were created this way), and low-rank factorisation. These are valid, but they change the model.
Compiler optimisation is semantics-preserving. The output of the optimised model is bit-for-bit identical to the baseline (for structural passes) or numerically equivalent within a tolerance (for approximation passes). A compiler pass can be applied to any model without retraining, and it is orthogonal to compression: you can compress and compiler-optimise the same model.
2. Background
2.1 The Transformer Architecture
The transformer was introduced by Vaswani et al. (2017)[1], and the key innovation was the self-attention mechanism: a mechanism that lets each token look at every other token in the input and decide how much to weight each one when building its own representation (see Jurafsky & Martin, 2026, §8.1). We'll go step by step through the ONNX graph of a DistilBERT model to explain how it works in detail.
Inputs
The model takes two integer tensors of shape [1, seq_len]. The first tensor,
input_ids, is a sequence of vocabulary indices. Each integer maps to one token
in the model’s vocabulary. The second tensor, attention_mask, is a binary tensor
where 1 marks a real token and 0 marks padding.
Token embedding lookup
The model holds a learned weight matrix wte of shape
[30522, 768], one 768-dimensional float vector per vocabulary token.
A Gather op uses input_ids as row indices and retrieves the
corresponding vectors, producing a float tensor of shape [1, seq_len, 768].
These vectors encode semantic meaning. They were trained to place
conceptually similar tokens near each other in that 768-dimensional space.[12]
Position embedding lookup
Because transformers process all tokens simultaneously and have no built-in
notion of order, position must be injected explicitly.[12]
A second learned matrix, wpe, of shape [512, 768], stores one vector per
possible position (0 through 511). DistilBERT's context window is 512 tokens,
a hard architectural limit: inputs longer than 512 tokens must be truncated or
chunked before they reach the model, which is why seq_len ≤ 512.
A Gather op with an indices attribute specifying the index sequence
[0, 1, 2, ..., seq_len−1] retrieves the corresponding
rows, giving another [1, seq_len, 768] tensor.[13]
Embedding addition
The token and position tensors are added element-wise to
form the initial hidden state h₀ of shape [1, seq_len, 768]. Each position in
this tensor blends what the token means with where it sits in the sequence.[12]
This is the input to the first transformer block.
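To make the three embedding steps concrete, here is a small numpy sketch: an ONNX Gather is just integer row indexing. Random weights stand in for the learned wte and wpe tables, and the token ids are illustrative, not real vocabulary entries.

import numpy as np

rng = np.random.default_rng(0)
wte = rng.standard_normal((30522, 768)).astype(np.float32)  # token embedding table
wpe = rng.standard_normal((512, 768)).astype(np.float32)    # position embedding table

input_ids = np.array([[101, 7592, 2088, 102]])  # [1, seq_len], illustrative ids
seq_len = input_ids.shape[1]

tok = wte[input_ids]                        # Gather by row index: [1, seq_len, 768]
pos = wpe[np.arange(seq_len)][None, :, :]   # Gather rows 0..seq_len-1
h0 = tok + pos                              # element-wise embedding addition
assert h0.shape == (1, seq_len, 768)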
Six transformer blocks
The following structure runs six times with separate
learned weights at each layer.
Self Attention Mechanism
We will now describe the workings of one of those 6 blocks.
As we will see, DistilBERT uses post-norm ordering. The Jurafsky & Martin
textbook (2026, §8.2, Fig. 8.7) describes the more common modern pre-norm
architecture (LayerNorm before attention/FFN); DistilBERT instead follows the
original Vaswani et al. (2017) post-norm design (LayerNorm after the residual
addition), which was also used in BERT.
Q, K, V projections multiply the normalised hidden state against three separate
weight matrices, each [768, 768], to produce Query, Key and Value tensors.
These correspond to the three roles described in Jurafsky & Martin
(2026, §8.1.1): the query is the current element being compared, the key is a
token being compared against, and the value is the content that gets weighted
and summed. Each is computed as q = xWQ, k = xWK, v = xWV (Eq. 8.9).
Each result is reshaped from [1, seq_len, 768] to [1, seq_len, 12, 64], splitting
the 768 dimensions across 12 attention heads of 64 dimensions each, then transposed
to [1, 12, seq_len, 64] so each head’s data is contiguous in memory.
Scaled dot-product attention runs independently per head. Q · Kᵀ produces a
[seq_len, seq_len] score matrix measuring how much each token attends to every
other. This implements Eq. 8.11 from Jurafsky & Martin (2026, §8.1.1),
score(xᵢ, xⱼ) = (qᵢ · kⱼ) / √dₖ; the full parallel form is QKᵀ/√dₖ (Eq. 8.33).
Dividing by √64 = 8 prevents the scores from growing too large.
Afterwards, the attention_mask is applied. Positions where the mask is 0 (i.e.
padding tokens) receive a large negative value so that softmax drives those
weights to zero. Because DistilBERT is a bidirectional encoder, every
non-padding token can attend to every other non-padding token; there is no
causal mask blocking future positions. Unlike the causal (left-to-right)
decoder described in Jurafsky & Martin (2026, §8.1), DistilBERT is an encoder
model in the BERT family: encoders let every token attend to every other token
in the sequence, not just to preceding tokens, so the causal masking of future
positions (§8.3, Fig. 8.10) is not used here. Jurafsky & Martin introduce
bidirectional encoder models in Chapter 9.
Softmax is applied row-wise to normalise the scores into weights. Finally,
multiplying by V yields a weighted sum of value vectors per position with output
shape [1, 12, seq_len, 64].
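To make the shapes concrete, here is a hedged numpy sketch of the whole multi-head computation just described: random weights as stand-ins, a mask of all ones. This is illustrative, not the project's code.

import numpy as np

rng = np.random.default_rng(0)
seq_len, n_heads, d_head = 128, 12, 64
x = rng.standard_normal((1, seq_len, 768)).astype(np.float32)
W_q, W_k, W_v = (rng.standard_normal((768, 768)).astype(np.float32) for _ in range(3))
mask = np.ones((1, seq_len))  # attention_mask: 1 = real token, 0 = padding

def split_heads(t):
    # [1, seq, 768] -> [1, 12, seq, 64], so each head's data is contiguous
    return t.reshape(1, seq_len, n_heads, d_head).transpose(0, 2, 1, 3)

q, k, v = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)              # [1, 12, seq, seq]
scores = scores + np.where(mask == 0, -1e9, 0.0)[:, None, None, :]  # mask padding keys
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights = weights / weights.sum(-1, keepdims=True)                  # row-wise softmax
out = weights @ v                                                   # [1, 12, seq, 64]
assert out.shape == (1, n_heads, seq_len, d_head)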
Output projection, residual, and Layer Norm 1. The 12 head outputs are
concatenated back to [1, seq_len, 768] via reshape and transpose, then
passed through a linear layer of shape [768, 768] that mixes information
across heads.[12] A residual connection adds the hidden state from before the
Q/K/V projections back to the attention output. (The residual connection
provides a skip path for gradients and means the attention sub-layer only needs
to learn a correction to the input, not a full new representation. In the
residual stream view (Elhage et al., 2021; Jurafsky & Martin, 2026, §8.2),
information flows up through the stream and each component adds to it.)
A normalisation step is then applied to the result. The formula is
\(\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sigma + \varepsilon} + \beta\)
where γ and β are learned parameters of size 768.[12]
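A direct numpy transcription of that formula, with γ, β and ε as stand-in values, might look like:

import numpy as np

def layer_norm(x, gamma, beta, eps=1e-12):
    # normalise over the last (feature) axis, independently per token position
    mu = x.mean(-1, keepdims=True)
    sigma = x.std(-1, keepdims=True)
    return gamma * (x - mu) / (sigma + eps) + beta

x = np.random.default_rng(0).standard_normal((1, 128, 768)).astype(np.float32)
gamma, beta = np.ones(768, np.float32), np.zeros(768, np.float32)
y = layer_norm(x, gamma, beta)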
FFN, residual, and Layer Norm 2. A two-layer feed-forward network takes the post-LN1 state as input. This is the position-wise FFN described in Jurafsky & Martin (2026, §8.2, Eq. 8.21): it applies the same weights independently to each token position, and the textbook notes that the FFN hidden dimension is commonly larger than the model dimension (e.g. d=512, dff=2048 in the original transformer; here d=768, dff=3072, the same 4× factor). The first linear layer expands from 768 to 3072 dimensions, a GELU non-linearity is applied element-wise, and the second layer contracts back to 768. (GELU, the Gaussian Error Linear Unit, is a smooth, non-monotonic activation defined as GELU(x) = x · Φ(x), where Φ(x) is the standard normal CDF. Unlike ReLU, which hard-gates at zero, GELU weights each input by the probability that it is positive under a Gaussian, producing a soft gating effect. It is the default activation in BERT, GPT-2 and most modern transformers; Hendrycks & Gimpel, 2016.) A second residual connection adds the post-LN1 state to the FFN output, and a second layer normalisation normalises the sum. This produces the final output of the block.
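A sketch of the FFN sub-layer in numpy, with random weights as stand-ins; SciPy's erf gives the exact Φ.

import numpy as np
from scipy.special import erf

def gelu(x):
    # exact GELU: x * Phi(x), with Phi the standard normal CDF
    return x * 0.5 * (1.0 + erf(x / np.sqrt(2.0)))

rng = np.random.default_rng(0)
W1 = rng.standard_normal((768, 3072)).astype(np.float32)  # expand: 768 -> 3072
W2 = rng.standard_normal((3072, 768)).astype(np.float32)  # contract: 3072 -> 768
b1, b2 = np.zeros(3072, np.float32), np.zeros(768, np.float32)

def ffn(x):  # x: [1, seq_len, 768], applied position-wise
    return gelu(x @ W1 + b1) @ W2 + b2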
Output
The ONNX graph ends at the last hidden state of shape [1, seq_len, 768]. At
inference time, the hidden state goes through a masked LM head and softmax to
obtain a probability distribution over all 30,522 tokens. (Unlike the
autoregressive, left-to-right language modelling described in Jurafsky & Martin
(2026, §8.5), DistilBERT uses masked language modelling (MLM): some input
tokens are replaced with a special [MASK] token, and the model predicts what
word belongs in each masked position. The textbook introduces this approach in
Chapter 9 on BERT and bidirectional encoders.) For masked language modelling,
the predictions at [MASK] positions are read off to recover the hidden words.
For downstream tasks such as classification, only the representation at a
special [CLS] token position is used.
2.2 What is Apache TVM?
Apache TVM is a machine learning compiler framework. It takes a neural network (from any training framework) and produces fast, deployable code for a wide range of hardware targets.
What is an IRModule?
The official docs say an IRModule “encompasses the entirety of the ML models, incorporating the computational graph, tensor programs, and potential calls to external libraries.” That’s… not very helpful.
When a compiler (any compiler, C++, Java, ...) transforms code, it can't jump directly from "human-readable source" to "machine code" in one step. There's too much to do. Instead, it converts through one or more Intermediate Representations (IRs): simplified, structured forms of the program that are easier to analyse and transform.
A classic example: GCC compiles C code to GIMPLE (a simplified form of C), then to RTL (Register Transfer Language, closer to assembly), then to machine code. Each IR is progressively lower-level.
TVM has two IRs:
- Relax: high-level; describes the computation graph (which operations, in what order, on what-shaped tensors)
- TensorIR (TIR): low-level; describes exactly how to implement each operation (loops, memory buffers, thread assignment)
An IRModule is simply a container, like a Python module
(a .py file), that holds all the functions needed to run a model,
written in TVM’s IRs. That’s it. The IRModule is a data structure that
describes the model. Passes transform this data structure. When you call
relax.build(), TVM reads the IRModule and generates actual executable
code (LLVM bitcode or CUDA PTX).
The lowering proceeds in stages:
- ONNX import: Gemm node → R.matmul. A Relax op: high-level, no implementation yet.
- After LegalizeOps: R.matmul → T.prim_func. A TIR loop nest: now there is an actual implementation.
- After codegen: TIR → LLVM IR → native x86 / CUDA PTX.
Graph Surgeon's passes happen at the first stage: they restructure the Relax graph before any TIR exists. This is why they're called "graph-level" passes.
2.3 Relax: The Graph Level
Relax is TVM’s high-level IR for representing the computation graph of a model. It is a functional, dataflow-oriented language. Here’s what a simple two-layer MLP looks like in Relax:
Relax
@R.function
def main(
x: R.Tensor((1, 784), dtype="float32"),
weight: R.Tensor((784, 256), dtype="float32"),
bias: R.Tensor((256,), dtype="float32"),
) -> R.Tensor((1, 256), dtype="float32"):
with R.dataflow(): # "safe zone" for optimisation
lv0 = R.matmul(x, weight)
lv1 = R.add(lv0, bias)
gv = R.nn.relu(lv1)
R.output(gv)
return gv
Notice that Relax does not say how to implement R.matmul:
no loops, no thread assignments, no tiling. The dataflow block
(with R.dataflow()) is a hint to the compiler that this region is
purely functional, with no side effects, so it can safely reorder, fuse or
eliminate operations.
In Relax, each line in a dataflow block is a binding:
a name assigned to the result of one operation.
lv0 = R.matmul(x, weight) means: "compute the matmul, name the
result lv0."
Compiler passes traverse these bindings. The QKV fusion pass, for example,
uses ExprMutator, a TVM API that walks through every binding in the
dataflow block. When it finds three consecutive matmul bindings that all take the
same input variable, it knows it has found a Q, K, V triple and can merge them.
This is fundamentally similar to how static analysis tools work in software engineering: they traverse an AST (abstract syntax tree) looking for patterns and transform them. Relax's dataflow graph is TVM's equivalent of an AST for ML models.
The @R.function decorator does the same trick as @T.prim_func below.
It never executes the function as Python. Instead, TVM captures it into
a tree of IR objects. The line lv0 = R.matmul(x, weight) becomes a
relax.VarBinding whose value is a relax.Call node:
# What Python "sees" at runtime after the decorator runs:
VarBinding(
var = DataflowVar("lv0"),
value = Call(op="relax.matmul", args=[Var("x"), Var("weight")])
)
Pattern-matching in a pass is literally just checking this tree. In
qkv_fusion.py, detecting a matmul binding is:
# method on the pass class; assumes `import tvm` and `from tvm import relax`
def _match_matmul(self, expr):
    # check whether this node is a Call to the "relax.matmul" op
    if not (isinstance(expr, relax.Call) and
            expr.op == tvm.ir.Op.get("relax.matmul")):
        return None
    x, w = expr.args[0], expr.args[1]
    ...
And rewriting the graph is rebuilding the binding list with new Call nodes in
place of the old ones.
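As an illustration of what the rebuilt bindings can look like, here is a hedged sketch using Relax op constructors. The helper name is hypothetical and the exact strided_slice argument conventions vary between TVM versions; this models the rewrite, it is not the project's actual pass.

from tvm import relax

def build_fused_qkv(x, w_qkv_const, hidden=768):
    # one wide matmul against the pre-concatenated [768, 3*hidden] constant
    qkv = relax.op.matmul(x, w_qkv_const)
    # strided_slice along the last axis recovers Q, K, V
    q = relax.op.strided_slice(qkv, axes=[2], begin=[0], end=[hidden])
    k = relax.op.strided_slice(qkv, axes=[2], begin=[hidden], end=[2 * hidden])
    v = relax.op.strided_slice(qkv, axes=[2], begin=[2 * hidden], end=[3 * hidden])
    return q, k, v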
2.4 TensorIR: The Kernel Level
Below Relax sits TensorIR (TIR), TVM’s low-level IR for individual tensor programs. Where Relax says what to compute, TIR specifies how: the loop nests, memory layout, thread bindings and tiling decisions.
TensorIR
@T.prim_func
def matmul(A: T.Buffer((128, 768), "float32"),
B: T.Buffer((768, 768), "float32"),
C: T.Buffer((128, 768), "float32")):
for i, j, k in T.grid(128, 768, 768):
with T.block("matmul"):
# SSR = Spatial, Spatial, Reduction
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
Like the @R.function decorator, @T.prim_func tells TVM not to run this as
Python, but to parse it into IR objects. Each buffer is declared with an
explicit shape and dtype, not as a hint but as a hard constraint that lets
later passes compute memory strides and choose vectorisation widths without
guessing. C is the output buffer, written in place; there is no return
statement, which is the convention for all prim_func kernels.[16]
T.grid(128, 768, 768) is syntactic sugar for three nested loops, defining
128 × 768 × 768 = 75,497,472 iterations. The T.block wraps the actual
computation into a self-contained unit, and the "SSR" annotation on the axes
is the crucial part. It tells the compiler that vi and vj are spatial
(independent, safe to parallelise and tile freely) while vk is a reduction
(all k iterations write to the same accumulator, so it cannot be freely
reordered without an explicit parallel reduction pattern). The T.init() block
cleanly separates the accumulator reset from the accumulation step, letting the
compiler hoist or fuse the initialisation correctly when the schedule is
transformed.[16]
@T.prim_func declares a primitive tensor function, the smallest schedulable
grain of work in TVM: one matmul, one convolution, one fused conv+relu.
Marking a function with this decorator tells TVM to parse it into IR,
type-check it, and make it available for transformation and code generation.[14]
T.axis.remap("SSR", [i, j, k])
This shorthand declares three block axes in one line, equivalent to writing:[16]
vi = T.axis.spatial(128, i) # S independent, safe to parallelise
vj = T.axis.spatial(768, j) # S independent, safe to parallelise
vk = T.axis.reduce (768, k) # R accumulates into C[vi,vj]
Each declaration conveys three things: the binding of the block variable to
its outer loop variable, the range of valid indices and the axis type. The
spatial axes can be freely tiled, vectorised and mapped to GPU threads without
data races. The reduce axis vk contributes to the same accumulator
for all k values and can only be parallelised with an explicit
parallel reduction pattern such as a tree-reduce.[16]
T.init()
The with T.init() sub-block formally declares the initialisation
action for the reduction.[16] It makes reduction semantics unambiguous and
allows the compiler to hoist this initialisation correctly when the loop is
later split or vectorised: it only needs to fire once per (vi, vj) pair, not
on every k iteration.
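To see why the axis kinds matter in practice, here is a small sketch using TVM's schedule API, assuming the matmul prim_func above has been collected into an IRModule named Module (a hypothetical name):

import tvm

sch = tvm.tir.Schedule(Module)  # `Module` holds the matmul prim_func above
block = sch.get_block("matmul", func_name="matmul")
i, j, k = sch.get_loops(block)
sch.parallel(i)                           # legal: vi is spatial
j_outer, j_inner = sch.split(j, factors=[None, 8])
sch.vectorize(j_inner)                    # legal: vj is spatial
# sch.parallel(k) would be rejected: vk is a reduction axis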
2.5 The Compilation Pipeline
TVM’s compilation is a sequence of transformation passes applied to the IRModule. In order: import the model into Relax, run graph-level passes (such as Graph Surgeon’s), legalise Relax ops into TIR prim_funcs, schedule the TIR, and generate native code.
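A minimal end-to-end sketch of that pipeline in Python, assuming a recent TVM build with the Relax ONNX frontend. Module paths and pass names vary between TVM versions, and "distilbert.onnx" is a placeholder file name.

import onnx
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

onnx_model = onnx.load("distilbert.onnx")        # placeholder file name
mod = from_onnx(onnx_model)                      # ONNX -> Relax IRModule
# graph-level passes (Graph Surgeon's passes would slot in here)
mod = relax.transform.FoldConstant()(mod)
mod = relax.transform.LegalizeOps()(mod)         # Relax ops -> TIR prim_funcs
ex = relax.build(mod, target="llvm")             # TIR -> native code
vm = relax.VirtualMachine(ex, tvm.cpu())         # ready to run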
2.6 The ONNX Format
ONNX (Open Neural Network Exchange)[6] is a standard, framework-agnostic file format for neural networks.
TVM can import from PyTorch directly via torch.export,
but ONNX provides a more stable, explicit graph. ONNX explicitly lays
out every operation as a typed node in a directed acyclic graph,
making it easier to pattern-match and restructure.
The project uses ONNX opset 18 with static shapes (batch=1, seq=128). Static shapes are important for Graph Surgeon: the QKV fusion pass needs to know the exact dimensions of weight matrices at compile time to perform the concatenation. With dynamic shapes, you’d need to defer this to runtime.
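For reference, a hedged sketch of how such a static-shape export can be produced with torch.onnx.export; the model and file names are illustrative, not necessarily the project's exact export script.

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("distilbert-base-uncased").eval()
input_ids = torch.ones(1, 128, dtype=torch.long)
attention_mask = torch.ones(1, 128, dtype=torch.long)

torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "distilbert.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    opset_version=18,
    # no dynamic_axes argument: every shape is baked in as static
)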
3. Technical Approach
3.1 Target Models
| Model | Type | Params | Layers | Why it's interesting |
|---|---|---|---|---|
| DistilGPT-2 | LLM (decoder) | 82M | 6 | Causal language model; QKV is not implemented as separate MatMuls in ONNX |
| DistilBERT | LLM (encoder) | 66M | 6 | Bidirectional encoder; QKV is implemented as separate MatMuls in ONNX |
Encoder (BERT-style): every token can attend to every other token ("bidirectional" attention). Great for tasks where you process a whole sentence and need to understand it (classification, NER, question answering). Processes the whole input at once.
Decoder (GPT-style): each token can only attend to previous tokens ("causal", left-to-right attention), enforced by masking future positions in the attention score matrix. Designed for generation: you predict one token at a time, left to right. GPT-2 and ChatGPT use this architecture.
3.2 The QKV Fusion Pass
This is the most important pass in the project. Here's what happens, step by step:
- Detect: the pass uses the ExprMutator API to walk through each DataflowBlock in the main Relax function, visiting bindings one at a time and looking for three consecutive matmuls that share the same input variable.
- Concatenate: the three [768, 768] weight matrices are joined with numpy.concatenate into a single [768, 2304] matrix. Biases too. This happens once, at compile time, not at every inference.
- Rewrite: the three matmuls are replaced by one wide matmul, followed by relax.op.strided_slice calls to extract Q, K, V from the result.
- Fold: a FoldConstant() pass evaluates the concatenated weights, materialising them into the graph. No runtime overhead from the concatenation.
The bias adds collapse the same way:
# before: three separate bias adds
lv_q2 = R.add(lv_q, b_Q)
lv_k2 = R.add(lv_k, b_K)
lv_v2 = R.add(lv_v, b_V)
# after: one fused bias add
lv_qkv2 = R.add(lv_qkv, b_QKV)
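The pass's core claim, that one wide matmul plus slices is numerically equivalent to three separate matmuls, is easy to check in numpy. Shapes follow the text, with the batch dimension dropped for brevity; this models the transformation, it is not the pass itself.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((128, 768)).astype(np.float32)   # [seq_len, hidden]
W_q, W_k, W_v = (rng.standard_normal((768, 768)).astype(np.float32) for _ in range(3))

# unfused: three matmuls, each of which re-reads x from memory
q, k, v = x @ W_q, x @ W_k, x @ W_v

# fused: concatenate once at compile time, run one wide matmul, then slice
W_qkv = np.concatenate([W_q, W_k, W_v], axis=1)          # [768, 2304]
qkv = x @ W_qkv
q2, k2, v2 = qkv[:, :768], qkv[:, 768:1536], qkv[:, 1536:]

assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)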
3.3 FuseTransposeMatmul Pass
FuseTransposeMatmul is a built-in TVM pass. In ONNX-exported
transformers, some operations produce explicit Transpose nodes before
a Matmul. This pass folds those transpose nodes into the matmul’s
transA/transB flags, eliminating a memcpy on CPU, where
explicit transpositions have real cost. On GPU, cuBLAS handles
CUBLAS_OP_T natively as a flag, so this pass eliminates
a Relax node that was never a real cost to begin with.
4. Results
- GPU: NVIDIA Turing (sm_75), CUDA.
- CPU: x86-64.
- Benchmarks: 200 warmup + 200 timed iterations.
- Statistical significance: Welch t-test at p < 0.01.
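As an illustration, the significance check described above can be reproduced with SciPy's Welch t-test on two timing samples; the data here is synthetic, shaped like the DistilBERT GPU rows below.

import numpy as np
from scipy.stats import ttest_ind

rng = np.random.default_rng(0)
baseline_ms = rng.normal(14.78, 0.004, 200)    # 200 timed iterations, synthetic
candidate_ms = rng.normal(14.92, 0.004, 200)

t, p = ttest_ind(baseline_ms, candidate_ms, equal_var=False)  # Welch's t-test
print(f"t = {t:.1f}, significant at p < 0.01: {p < 0.01}")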
4.1 GPU Results
| Model | Pipeline | Mean (ms) | Std (ms) | Speedup | p<0.01? |
|---|---|---|---|---|---|
| DistilGPT-2 | o2_baseline | 27.03 | 1.550 | - | - |
| | llm_fuse_transpose | 27.54 | 0.005 | 0.981× | ✗ |
| | llm_qkv | 27.54 | 0.001 | 0.982× | ✗ |
| | llm_dlight | 27.71 | 0.053 | 0.976× | ✓ |
| DistilBERT | o2_baseline | 14.78 | 0.004 | - | - |
| | llm_fuse_transpose | 14.79 | 0.005 | 0.999× | ✓ |
| | llm_qkv | 14.92 | 0.004 | 0.990× | ✓ |
| | llm_dlight | 14.77 | 0.041 | 1.001× | ✗ |
No graph-level pass improved transformer inference on GPU. QKV fusion, transpose elimination, and DLight TIR scheduling all resulted in no change or slight regressions.
4.2 CPU Results
| Model | Pipeline | Mean (ms) | Speedup | p<0.01? |
|---|---|---|---|---|
| DistilGPT-2 | o2_baseline | 20,282 | - | - |
| | llm_qkv | 19,657 | 1.032× | ✓ |
| | llm_all | 19,732 | 1.028× | ✓ |
| DistilBERT | o2_baseline | 9,293 | - | - |
| | llm_fuse_transpose | 8,998 | 1.033× | ✓ |
| | llm_qkv | 8,309 | 1.118× | ✓ |
QKV fusion delivers +11.8% on CPU for DistilBERT, the single largest improvement in the entire project. The pass was designed for CPU cache behaviour, and it works exactly where it was designed to work. FuseTransposeMatmul also helps (+3.3%), confirming that explicit transpose nodes carry real cost on x86, unlike on GPU, where cuBLAS handles transposed inputs natively.
The CPU runs fell back to -mcpu=generic due to an LLVM 22.1.2 compatibility issue with -mcpu=native. This means no AVX2, no FMA, the CPU results are without SIMD vectorisation. The 11.8% QKV speedup would likely be even larger with AVX2 enabled, as the fused matmul's wider inner loop can better exploit SIMD vector units (which process 8 floats in parallel).
5. Conclusion and Future Work
By concatenating \(W_{Q}, W_{K}, W_{V}\) into \(W_{QKV}\) and doing one wide matmul, the input tensor X is loaded from memory once instead of three times. On CPU, where memory bandwidth is tight and caches are small, that is a real win, hence the 11.8%. On GPU, cuBLAS is already so good at scheduling individual matmuls (and the L2 cache is large enough to keep X hot across kernel launches) that the saving evaporates. The slight regression we see on GPU is probably because the 768×2304 matmul is a less "canonical" shape than three 768×768 ones, and cuBLAS's internal heuristics are tuned for the latter.
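A back-of-envelope estimate of that saving, ignoring caches and weight traffic, so purely illustrative:

# float32, seq_len = 128, hidden = 768, six transformer blocks
seq_len, hidden, bytes_per_float, n_layers = 128, 768, 4, 6
x_bytes = seq_len * hidden * bytes_per_float        # X is ~384 KiB per layer
saved = n_layers * 2 * x_bytes                      # X read 1x instead of 3x
print(f"~{saved / 1024:.0f} KiB less input traffic per forward pass")  # ~4608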
The real bottleneck is the attention score matrix. The QKV projections are
compute-bound matmuls: hardware is doing useful arithmetic almost every cycle.
The expensive part that graph-level fusion can't touch is the attention
computation itself, where Q · Kᵀ produces a [seq_len, seq_len] score matrix
that grows quadratically with sequence length.
The correct solution to this is FlashAttention (Dao et al., 2022),[17]
which fuses Q · Kᵀ, scaling, masking, softmax, and the final multiplication by V into a single tiled CUDA kernel.
References
- Vaswani, A. et al. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems, 30. arxiv.org/abs/1706.03762
- Chen, T. et al. (2018). TVM: An Automated End-to-End Optimizing Compiler for Deep Learning. arxiv.org/abs/1802.04799
- Apache TVM Documentation. Overview. tvm.apache.org/docs/get_started/overview.html
- Apache TVM Documentation. IRModule. tvm.apache.org/docs/get_started/tutorials/ir_module.html
- Apache TVM Documentation. Relax: Graph Abstraction for ML Models. tvm.apache.org/docs/deep_dive/relax/abstraction.html
- ONNX Project. ONNX: Open Neural Network Exchange. onnx.ai/onnx/intro/concepts.html
- Shao, J. et al. (2022). Tensor Program Optimization with Probabilistic Programs. NeurIPS 2022. proceedings.neurips.cc
- Apache TVM Documentation. MetaSchedule: Search-Based Auto-Tuning. tvm.apache.org/docs/.../meta_schedule.html
- Apache TVM Documentation. Optimize Large Language Model. tvm.apache.org/docs/.../optimize_llm.html
- Li, Y. et al. (2021). A Short Study on Compressing Decoder-Based Language Models. arxiv.org/abs/2110.08460
- Apache TVM Documentation. Customize Optimization. tvm.apache.org/docs/.../customize_opt.html
- Jurafsky, D. & Martin, J. H. Speech and Language Processing. Draft of January 6, 2026. web.stanford.edu/~jurafsky/slp3/
- ONNX Operator Reference. Gather. onnx.ai/onnx/operators/onnx__Gather.html
- Apache TVM Documentation. Tensor Program Abstraction. tvm.apache.org/docs/deep_dive/tensor_ir/abstraction.html
- Apache TVM Documentation. TensorIR Overview. tvm.apache.org/docs/deep_dive/tensor_ir/index.html
- Apache TVM Documentation. Understand TensorIR Abstraction. tvm.apache.org/docs/deep_dive/tensor_ir/learning.html
- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Advances in Neural Information Processing Systems, 35. arxiv.org/abs/2205.14135
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. International Conference on Learning Representations (ICLR 2024). arxiv.org/abs/2307.08691