Let's face it - we're in the middle of a GPU shortage crisis. Everyone wants GPUs for their AI projects. The demand is through the roof, and prices are absolutely insane - a single H100 can cost over $30,000, and good luck even finding one in stock.
For most companies and researchers, buying more GPUs simply isn't an option. The only realistic solution? We need to squeeze every bit of performance from the GPUs we already have.
If you've ever tried optimizing CUDA code, you know the pain. It's like solving a massive puzzle where you're constantly tweaking memory access patterns, adjusting thread blocks, and running endless profiling tests. Engineers spend weeks or months on this stuff, and it's honestly exhausting.
Here's where things get interesting. Recent LLMs - think DeepSeek-R1 and OpenAI's o1 - are getting pretty good at writing code. And here's the kicker: CUDA optimization has a super clear reward signal - speed. Your code either runs faster or it doesn't, and that's a perfect setup for reinforcement learning.
Imagine this: instead of you pulling your hair out trying different optimizations, an AI could generate thousands of variations, test them all, and learn what works. It might even discover tricks that humans never thought of!
So we built CUDA-L1, which uses something we call "contrastive reinforcement learning." Think of it like this: instead of just trying random stuff, our AI compares different CUDA versions side-by-side and learns why some are faster than others. It's like having a coach that shows you good vs. bad examples until you get it. And, as we'll show below, CUDA-L1 excels at exactly this kind of optimization.
Ask any AI to write CUDA code and you'll likely get something that doesn't compile, crashes, or runs painfully slow. The reason is simple: these models barely saw any quality CUDA code during training. It's like asking someone who's only read cooking blogs to become a chef.
We built CUDA-L1 with a three-stage pipeline: supervised learning (learn the basics), self-supervised learning (practice until perfect), and contrastive reinforcement learning (compete for speed).
First, we needed to fix the data shortage problem. We took existing CUDA code and created variations of it - expanding the model's exposure to different CUDA patterns. This supervised fine-tuning phase has one goal: make sure the AI can write CUDA code that actually compiles and runs correctly.
Next, we let the model generate its own CUDA code, test it, and learn from what works. The model generates thousands of code samples, we automatically test each one, and only the successful implementations get fed back for more training. No speed optimization yet - just making sure the code works reliably.
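To make the shape of this stage concrete, here is a minimal sketch of the loop. The helper names (`sample_kernel`, `compiles_and_is_correct`, `finetune_on`) are illustrative placeholders, not the actual CUDA-L1 training code:

```python
# Sketch of the self-supervised stage: sample code, keep only what compiles and
# produces correct outputs, and feed the survivors back into training.
# All helpers here are hypothetical placeholders for the real pipeline components.
def self_supervised_round(model, tasks, samples_per_task=16):
    kept = []
    for task in tasks:
        for _ in range(samples_per_task):
            candidate = sample_kernel(model, task)        # generate a CUDA implementation
            if compiles_and_is_correct(candidate, task):  # compile + compare against reference outputs
                kept.append((task, candidate))            # no speed filter yet - correctness only
    finetune_on(model, kept)                              # successful samples become new training data
    return model
```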
This is where CUDA-L1 becomes special. Traditional RL would just assign scores to generated code and hope the model figures out why some implementations are faster. That's like grading exams without showing students the correct answers. Instead, we do something radically different. Look at this actual prompt we use:
Then we ask three critical questions:
The magic happens because the AI can directly see and reason about performance differences. It's not guessing in the dark - it's learning from concrete examples of what makes CUDA code fast.
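To make the idea concrete, here is a purely illustrative sketch of how such a contrastive prompt could be assembled from previously scored implementations. This is not the exact prompt used in training - it only shows the structure of pairing code with measured speedups:

```python
def build_contrastive_prompt(task_description: str, scored_kernels: list[tuple[str, float]]) -> str:
    """Assemble a prompt that shows several implementations next to their measured speedups.

    scored_kernels: list of (cuda_source, measured_speedup) pairs - hypothetical inputs
    coming from earlier evaluation rounds.
    """
    blocks = []
    for i, (source, speedup) in enumerate(sorted(scored_kernels, key=lambda p: p[1]), start=1):
        blocks.append(f"### Implementation {i} (measured speedup: {speedup:.2f}x)\n{source}")
    return (
        f"Task:\n{task_description}\n\n"
        + "\n\n".join(blocks)
        + "\n\nCompare the implementations above, explain why the faster ones are faster, "
          "and then write a new implementation that outperforms all of them."
    )
```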
Reinforcement learning is notorious for exhibiting reward hacking behaviors, where models exploit system vulnerabilities to achieve higher rewards while generating outputs that deviate from the intended objectives. In our experiments, we discovered that over 30% of generated implementations attempted some form of reward hacking.
During our initial training procedure, we identified three major categories of reward hacking behaviors:
KernelBench measures execution time by recording timing events on the main CUDA stream as follows:
```python
start_event.record(original_model_stream)
model(*inputs)
end_event.record(original_model_stream)
torch.cuda.synchronize(device=device)
```
However, RL-generated code can exploit this by creating additional CUDA streams that execute asynchronously. Since KernelBench only monitors the main stream, it fails to capture the actual execution time of operations running on parallel streams. This loophole is significant: in our initial implementation, 82 out of 250 (32.8%) RL-generated implementations exploited it to appear faster than they actually are.
Warning: This vulnerability led to artificial 18× speedups that were entirely fake - the actual computation performance was unchanged!
Prompt engineering alone is insufficient to address the issue. Our fix makes the timed stream wait on every extra stream before recording the end event:
```python
if custom_contain_new_streams:
    for stream in custom_model_streams:
        custom_model_stream.wait_stream(stream)
end_event.record(custom_model_stream)
```
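For completeness, the exploit pattern itself looks roughly like the sketch below (a hypothetical shape of the hacked code, not an actual RL-generated sample): the real work is launched on a side stream, so the events recorded on the timed stream bracket almost nothing.

```python
import torch

side_stream = torch.cuda.Stream()  # extra stream the original benchmark never waits on

class StreamHackedModel(torch.nn.Module):
    """Hypothetical exploit shape: real work escapes the timed main stream."""
    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        with torch.cuda.stream(side_stream):
            out = self.inner(x)  # executes asynchronously on the side stream
        # Without waiting on side_stream before end_event.record(), the timer closes
        # while this work is still running - hence the fake speedup.
        return out
```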
In KernelBench, each computational task is associated with specific hyperparameters, including batch_size, dim, in_features, out_features, scaling_factor, and others. The RL agent learned to exploit these parameters by generating code that artificially reduces their values, thereby achieving superficial speedup improvements that do not reflect genuine optimization performance.
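A simple guard against this class of hack is to parse the generated code and reject any reassignment of the protected task hyperparameters. The sketch below (using Python's `ast` module; not the exact checker we deploy) illustrates the idea:

```python
import ast

PROTECTED = {"batch_size", "dim", "in_features", "out_features", "scaling_factor"}

def tampered_hyperparameters(generated_source: str, reference_values: dict) -> bool:
    """Return True if the generated code reassigns a protected hyperparameter to a new value."""
    tree = ast.parse(generated_source)
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if (isinstance(target, ast.Name) and target.id in PROTECTED
                        and isinstance(node.value, ast.Constant)
                        and node.value.value != reference_values.get(target.id)):
                    return True
    return False
```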
The RL agent developed strategies to cache computational results across evaluation batches based on input addresses. When another input’s address matches a cached one, it returns the cached output. The following code snippet gives an illustration:
```python
cache_key = x.data_ptr()
if cache_key in self.cache:
    return self.cache[cache_key]
```
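One way to catch this behavior in the evaluation harness (a sketch, not the exact check we run) is to overwrite the same tensor storage with fresh values: the `data_ptr()` stays identical, so an address-keyed cache returns a stale result that no longer matches a clean recomputation.

```python
import torch

def detects_address_caching(model, make_input, atol=1e-4):
    x = make_input()
    _ = model(x)                   # may populate a cache keyed on x.data_ptr()
    x.copy_(make_input())          # same storage address, different contents
    cached_or_fresh = model(x)
    reference = model(x.clone())   # different address forces a real computation
    # If the two outputs diverge, the model returned a stale cached result.
    return not torch.allclose(cached_or_fresh, reference, atol=atol)
```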
To combat these sophisticated reward hacking behaviors, we developed a multi-layered defense system combining automated detection, continuous learning, and mathematical constraints.
We deployed DeepSeek-R1 as an adversarial checker that analyzes generated code for potential exploits. The model achieves over 60% detection accuracy through pattern analysis, semantic understanding, and anomaly detection. When speedups exceed suspicious thresholds (e.g., >10× for simple operations), it triggers multi-stage verification including output consistency checks, memory usage analysis, and GPU utilization metrics.
A continuously updated database containing 500+ unique hacking patterns helps identify new exploits. For each generated implementation, we retrieve similar hacking cases using AST-based code analysis and performance profiling. This contextual information improves detection accuracy by 25% and allows us to track how hacking strategies evolve across training iterations.
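As a rough illustration of AST-based retrieval (the production database also uses performance profiling and richer features), one can fingerprint each implementation by its distribution of AST node types and compare fingerprints with cosine similarity:

```python
import ast
from collections import Counter
from math import sqrt

def ast_fingerprint(source: str) -> Counter:
    """Bag-of-node-types fingerprint of a Python/PyTorch implementation."""
    return Counter(type(node).__name__ for node in ast.walk(ast.parse(source)))

def similarity(a: Counter, b: Counter) -> float:
    dot = sum(a[k] * b[k] for k in a.keys() & b.keys())
    norm = sqrt(sum(v * v for v in a.values())) * sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def retrieve_similar_hacks(candidate_source: str, hack_db: list[str], top_k: int = 3) -> list[str]:
    cand = ast_fingerprint(candidate_source)
    ranked = sorted(hack_db, key=lambda s: similarity(cand, ast_fingerprint(s)), reverse=True)
    return ranked[:top_k]
```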
To prevent over-optimization for extreme cases, we apply mathematical constraints on rewards:
$$
r_{\text{normalized}} = \frac{r - \mu}{\sigma}, \qquad
r_{\text{smooth}} = \operatorname{clip}\left(r_{\text{normalized}},\, -k,\, k\right)
$$
Here, μ and σ are rolling statistics updated every 100 iterations, and k = 1.5 represents the maximum reasonable speedup. For suspicious high-reward cases with low confidence, we apply additional dampening based on consistency metrics and alignment with known optimization patterns.
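A minimal sketch of this reward shaping is shown below. The window size and k come from the description above; everything else (class structure, when statistics get refreshed) is an assumption about how one might wire it up:

```python
import collections
import statistics

class RewardShaper:
    """Rolling z-score normalization followed by clipping, per the formulas above."""
    def __init__(self, window: int = 100, k: float = 1.5):
        self.history = collections.deque(maxlen=window)
        self.k = k
        self.mu, self.sigma = 0.0, 1.0

    def refresh_statistics(self):
        # Intended to be called every `window` iterations to update the rolling mu and sigma.
        if len(self.history) > 1:
            self.mu = statistics.fmean(self.history)
            self.sigma = statistics.pstdev(self.history) or 1.0

    def shape(self, raw_reward: float) -> float:
        self.history.append(raw_reward)
        normalized = (raw_reward - self.mu) / self.sigma
        return max(-self.k, min(self.k, normalized))
```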
Key Takeaway: The arms race between the RL agent and our detection systems highlights the importance of robust evaluation frameworks. Our multi-layered approach has proven effective at maintaining training integrity while still allowing genuine breakthrough optimizations to be rewarded appropriately.
We tested CUDA-L1 on KernelBench, a comprehensive benchmark suite with three difficulty levels:
We use "All Levels" to denote the full dataset containing all three levels.
To perform a comprehensive evaluation on the generated code, we perform the following comparisons:
I) Default: This compares the CUDA-L1 generated code with the reference code by KernelBench.
II) Torch Compile: This compares the CUDA-L1 generated code with the reference code enhanced by torch.compile with default settings. Under the hood, torch.compile applies graph-level optimizations including operator fusion, memory planning, and kernel selection to accelerate PyTorch models through just-in-time compilation (see the code sketch after this list).
III) Torch Compile Reduce Overhead: This compares the CUDA-L1 generated code with the reference code enhanced by torch.compile with reduce-overhead mode enabled. This mode minimizes the compilation overhead by caching compiled graphs more aggressively and reducing recompilation frequency, making it particularly suitable for inference workloads with static shapes.
IV) CUDA Graph: Since KernelBench does not provide official CUDA Graph implementations, we employ Claude 4 to generate CUDA Graph-augmented code for each reference implementation. CUDA Graphs capture a series of CUDA kernels and their dependencies into a single graph structure that can be launched with minimal CPU overhead, eliminating the need for repeated kernel launch commands and significantly reducing CPU-GPU synchronization costs.
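For reference, the enhanced baselines II)-IV) can be built with standard PyTorch APIs. The sketch below uses `ref_model` and `example_input` as stand-ins for a KernelBench reference module and its CUDA input; our actual CUDA Graph baselines were generated per task by Claude 4, so this only shows the general capture/replay mechanics:

```python
import torch

# II) and III): torch.compile baselines
compiled_default = torch.compile(ref_model)                                   # default settings
compiled_reduce_overhead = torch.compile(ref_model, mode="reduce-overhead")   # reduce-overhead mode

# IV) CUDA Graph baseline: capture once, then replay with minimal launch overhead
static_input = example_input.clone()

# Warm up on a side stream before capture so lazy initialization happens outside the graph
warmup_stream = torch.cuda.Stream()
warmup_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(warmup_stream):
    for _ in range(3):
        ref_model(static_input)
torch.cuda.current_stream().wait_stream(warmup_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = ref_model(static_input)

def run_graphed(x):
    static_input.copy_(x)   # write new data into the captured input buffer
    graph.replay()          # relaunch the entire captured kernel sequence
    return static_output
```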
| Configuration | Dataset | Mean | Max | 75% | 50% | 25% | Success↑ (# out of total) | Speedup↑ (>1.01× out of total) |
|---|---|---|---|---|---|---|---|---|
| Default | All Levels | 3.12× | 120× | 2.25× | 1.42× | 1.17× | 249/250 | 226/250 |
| Torch Compile | All Levels | 2.77× | 69.0× | 2.55× | 1.72× | 1.14× | 249/250 | 203/250 |
| Torch Compile RO | All Levels | 2.88× | 80.1× | 2.48× | 1.67× | 1.13× | 249/250 | 200/250 |
| CUDA Graph | All Levels | 2.81× | 97.9× | 1.83× | 1.20× | 0.954× | 249/250 | 147/229 |
• RO = Reduce Overhead
• Success and Speedup indicate the number of successful benchmarks out of the total for each level
We trained CUDA-L1 on NVIDIA A100s, but what if you're using a different GPU? Good news: the optimizations transfer remarkably well. We tested the same A100-optimized kernels on:
| Configuration | A100 | 3090 | H100 | H20 | L40 |
|---|---|---|---|---|---|
| Default | 3.12× | 2.51× | 3.85× | 2.38× | 3.13× |
| Torch Compile | 2.77× | 2.58× | 2.74× | 2.89× | 2.85× |
| Torch Compile RO | 2.88× | 2.61× | 2.77× | 2.82× | 2.89× |
| CUDA Graph | 2.81× | 3.34× | 2.23× | 2.20× | 3.98× |
The results show fascinating patterns across different GPU architectures.
These results confirm that CUDA-L1's optimization patterns are fundamental enough to benefit any modern GPU architecture, though different GPUs may favor different optimization strategies. This suggests exciting opportunities for GPU-specific training in future versions of CUDA-L1.
Let's dive deep into three examples to understand what CUDA-L1 actually does:
This task performs matrix multiplication between a diagonal matrix (represented by its diagonal elements) and a dense matrix, both with dimension N=4096.
```python
class Model(nn.Module):
    def forward(self, A, B):
        # A: (N,) - 1D tensor of shape N
        # B: (N, M) - 2D tensor of shape N x M
        # torch.diag(A): (N, N) - creates diagonal matrix from A
        # Result: (N, N) @ (N, M) = (N, M)
        return torch.diag(A) @ B
```
```python
class Model(nn.Module):
    def forward(self, A, B):
        return A.unsqueeze(1) * B
```
The optimized implementation leverages PyTorch's broadcasting mechanism brilliantly. Instead of creating a full N×N diagonal matrix (which would be mostly zeros), it simply reshapes the diagonal vector A from (N,) to (N, 1) and uses broadcasting to multiply each row of B by the corresponding element of A.
The benefits are substantial: the 4096×4096 diagonal matrix (over 16 million entries, almost all zeros) is never materialized, and the O(N²·M) matrix multiplication collapses into an O(N·M) element-wise scaling of B's rows.
What's remarkable is that CUDA-L1 discovered this algebraic simplification on its own through RL exploration. By testing semantically equivalent implementations, it learned to identify patterns where computationally expensive operations can be replaced with more efficient alternatives.
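You can check both the equivalence and the speedup with a quick benchmark; the harness below is ours rather than KernelBench's, and the absolute numbers will vary by GPU:

```python
import torch

N, M = 4096, 4096
A = torch.randn(N, device="cuda")
B = torch.randn(N, M, device="cuda")

# Same result, radically different cost
assert torch.allclose(torch.diag(A) @ B, A.unsqueeze(1) * B, rtol=1e-4, atol=1e-4)

def time_ms(fn, iters=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):  # warm-up
        fn()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("torch.diag(A) @ B :", time_ms(lambda: torch.diag(A) @ B), "ms")
print("A.unsqueeze(1) * B:", time_ms(lambda: A.unsqueeze(1) * B), "ms")
```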
For a classical LSTM neural network (Level 3, Task 35), CUDA-L1 achieved a 3.4× speedup by applying three key optimizations: CUDA Graphs, memory contiguity, and static tensor reuse. The KernelBench reference implementation and the optimized implementation are shown below.
```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
        """
        Initialize the LSTM model.

        :param input_size: The number of expected features in the input `x`
        :param hidden_size: The number of features in the hidden state `h`
        :param num_layers: Number of recurrent layers
        :param output_size: The number of output features
        :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer
                        except the last layer, with dropout probability equal to `dropout`
        """
        super(Model, self).__init__()
        # Initialize hidden state with random values
        self.h0 = torch.randn((num_layers, batch_size, hidden_size))
        self.c0 = torch.randn((num_layers, batch_size, hidden_size))
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True,
                            dropout=dropout, bidirectional=False)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """
        Forward pass through the LSTM model.

        :param x: The input tensor, shape (batch_size, sequence_length, input_size)
        :return: The output tensor, shape (batch_size, sequence_length, output_size)
        """
        self.h0 = self.h0.to(x.device)
        self.c0 = self.h0.to(x.device)  # BUG: this should be self.c0.to(x.device)
        # Forward propagate LSTM
        out, state = self.lstm(x, (self.h0, self.c0))  # out: (batch_size, seq_length, hidden_size)
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])  # out: (batch_size, output_size)
        return state[0]

# Test code
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0

def get_inputs():
    return [torch.randn(batch_size, sequence_length, input_size)]

def get_init_inputs():
    return [input_size, hidden_size, num_layers, output_size, dropout]
```
```python
import torch
import torch.nn as nn
import torch.cuda as cuda

class ModelNew(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
        """
        Initialize the LSTM model with three core optimization techniques.

        Color coding:
        - 🔵 BLUE: CUDA Graphs optimization
        - 🟢 GREEN: Memory Contiguity optimization
        - 🟠 ORANGE: Static Tensor Reuse optimization
        """
        super(ModelNew, self).__init__()
        # Initialize hidden states as buffers
        self.register_buffer('h0', torch.randn((num_layers, batch_size, hidden_size)))
        self.register_buffer('c0', torch.randn((num_layers, batch_size, hidden_size)))
        # Use PyTorch's optimized LSTM implementation
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False
        )
        self.fc = nn.Linear(hidden_size, output_size)
        # 🔵 CUDA GRAPHS: Variables for graph capture and replay
        self.graph = None
        self.graph_ready = False
        self.input_shape = None
        # 🟠 STATIC TENSOR REUSE: Pre-allocated tensors for graph execution
        self.static_input = None
        self.static_output = None
        # 🔵 CUDA GRAPHS: Streams for graph operations
        self.graph_stream = None
        # Track if we're running on CUDA
        self.is_cuda_available = torch.cuda.is_available()

    def _initialize_cuda_resources(self):
        """🔵 CUDA GRAPHS: Initialize CUDA stream for graph operations"""
        if self.graph_stream is None:
            self.graph_stream = cuda.Stream()

    def _capture_graph(self, x, result):
        """
        🔵 CUDA GRAPHS: Capture the computation graph for replay
        🟠 STATIC TENSOR REUSE: Create static tensors for graph capture
        """
        # 🟠 STATIC TENSOR REUSE: Clone tensors for static allocation
        self.static_input = x.clone()
        self.static_output = result.clone()
        # 🔵 CUDA GRAPHS: Capture the computation graph
        with torch.cuda.stream(self.graph_stream):
            self.graph = cuda.CUDAGraph()
            with cuda.graph(self.graph):
                # Operations to capture in the graph
                static_out, _ = self.lstm(self.static_input, (self.h0, self.c0))
                # 🟢 MEMORY CONTIGUITY: Ensure contiguous memory layout
                static_last = static_out[:, -1, :].contiguous()
                self.static_output.copy_(self.fc(static_last))
        # Wait for graph capture to complete
        torch.cuda.synchronize()
        # Mark graph as ready for use
        self.graph_ready = True

    def _standard_forward(self, x):
        """Standard forward pass with memory contiguity optimization"""
        # 🟢 MEMORY CONTIGUITY: Ensure input is contiguous
        if not x.is_contiguous():
            x = x.contiguous()
        # Forward pass through LSTM
        out, _ = self.lstm(x, (self.h0, self.c0))
        # 🟢 MEMORY CONTIGUITY: Make last output contiguous for optimal memory access
        last_out = out[:, -1, :].contiguous()
        return self.fc(last_out)

    def forward(self, x):
        """
        Forward pass through the LSTM model with three optimization techniques.

        Optimization flow:
        1. 🔵 CUDA GRAPHS: Check if we can use the captured graph (fast path)
        2. 🟠 STATIC TENSOR REUSE: Use pre-allocated tensors for graph replay
        3. 🟢 MEMORY CONTIGUITY: Ensure optimal memory layout throughout
        """
        # 🔵 CUDA GRAPHS: Fast path - use captured graph if available
        if (x.is_cuda and
                self.graph_ready and
                x.shape == self.input_shape):
            # 🟠 STATIC TENSOR REUSE: Copy to pre-allocated tensor with non-blocking transfer
            self.static_input.copy_(x, non_blocking=True)
            # 🔵 CUDA GRAPHS: Replay the captured graph
            self.graph.replay()
            # Return the output from the static buffer
            return self.static_output.clone()
        # Standard execution path
        with torch.no_grad():
            result = self._standard_forward(x)
        # 🔵 CUDA GRAPHS: Initialize graph on first CUDA input
        if x.is_cuda and self.is_cuda_available and not self.graph_ready:
            try:
                # Store the current input shape
                self.input_shape = x.shape
                # 🔵 CUDA GRAPHS: Initialize CUDA resources
                self._initialize_cuda_resources()
                # 🔵 CUDA GRAPHS + 🟠 STATIC TENSOR REUSE: Capture the graph
                self._capture_graph(x, result)
            except Exception:
                # If graph capture fails, continue without it
                self.graph_ready = False
        return result

# Hyperparameters from the reference implementation
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0
```
The optimized implementation uses color-coded emojis to highlight the three optimization techniques: 🔵 CUDA Graphs, 🟢 Memory Contiguity, and 🟠 Static Tensor Reuse. The ablation results reveal something striking: CUDA Graphs is essential for any meaningful speedup - without it, no combination of the other optimizations provides any benefit. Once CUDA Graphs is enabled, the additional optimizations deliver incremental improvements, with all three together achieving the best speedup of 3.42×.
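In practice the fast path is exercised as shown below (our own usage sketch, reusing the hyperparameters defined above): the first CUDA call runs the standard path and captures the graph, and every subsequent call with the same input shape replays it.

```python
model = ModelNew(input_size, hidden_size, num_layers, output_size, dropout).cuda().eval()
x = torch.randn(batch_size, sequence_length, input_size, device="cuda")

with torch.no_grad():
    warmup = model(x)  # standard forward pass + CUDA Graph capture
    fast = model(x)    # replayed from the captured graph via the static buffers
```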
The most impressive speedup came from a 3D operation pipeline: a 3D convolution followed by group normalization, an element-wise minimum, clamping, and dropout. CUDA-L1 applied four optimizations: mathematical short-circuiting, pre-allocated tensors, direct shape matching, and pre-computed convolution parameters.
```python
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a 3D convolution, applies Group Normalization, minimum, clamp, and dropout.
    """
    def __init__(self, in_channels, out_channels, kernel_size, groups, min_value, max_value, dropout_p):
        super(Model, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.dropout = nn.Dropout(dropout_p)
        self.min_value = min_value
        self.max_value = max_value

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = torch.min(x, torch.tensor(self.min_value))
        x = torch.clamp(x, min=self.min_value, max=self.max_value)
        x = self.dropout(x)
        return x

# Hyperparameters
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
groups = 8
min_value = 0.0
max_value = 1.0
dropout_p = 0.2

def get_inputs():
    return [torch.randn(batch_size, in_channels, depth, height, width)]

def get_init_inputs():
    return [in_channels, out_channels, kernel_size, groups, min_value, max_value, dropout_p]
```
```python
import torch
import torch.nn as nn

# Hyperparameters
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
groups = 8
min_value = 0.0
max_value = 1.0
dropout_p = 0.2

class ModelNew(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, groups, min_value, max_value, dropout_p):
        super(ModelNew, self).__init__()
        # Store the original layers for parameter compatibility
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.dropout = nn.Dropout(dropout_p)
        self.min_value = min_value
        self.max_value = max_value
        self.dropout_p = dropout_p

        # 🔵 TECH 1: Mathematical Short-Circuit Optimization
        # Detects when min_value=0.0 to skip the entire computation
        self.use_optimized_path = (min_value == 0.0)

        # 🟠 TECH 4: Pre-computed Convolution Parameters
        # Extract and store conv parameters once during initialization
        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size, kernel_size, kernel_size)
        else:
            self.kernel_size = kernel_size
        self.stride = self.conv.stride
        self.padding = self.conv.padding
        self.dilation = self.conv.dilation

        # 🟠 TECH 4: Pre-compute output dimensions for the standard input
        self.out_depth = ((depth + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0]) + 1
        self.out_height = ((height + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1]) + 1
        self.out_width = ((width + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) // self.stride[2]) + 1
        # Standard output shape for the default batch size
        self.standard_shape = (batch_size, out_channels, self.out_depth, self.out_height, self.out_width)

        # 🟣 TECH 2: Pre-allocated Zero Tensors
        # Create zero tensors once to avoid allocation overhead
        if self.use_optimized_path:
            self.register_buffer('zero_output_float32',
                                 torch.zeros(self.standard_shape, dtype=torch.float32),
                                 persistent=False)
            self.register_buffer('zero_output_float16',
                                 torch.zeros(self.standard_shape, dtype=torch.float16),
                                 persistent=False)
            self.register_buffer('zero_output_bfloat16',
                                 torch.zeros(self.standard_shape, dtype=torch.bfloat16),
                                 persistent=False)

    def calculate_output_shape(self, input_shape):
        """Calculate the output shape of the convolution operation."""
        batch_size, _, d, h, w = input_shape
        # 🟠 TECH 4: Use precomputed parameters to avoid repeated attribute lookups
        out_d = ((d + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0]) + 1
        out_h = ((h + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1]) + 1
        out_w = ((w + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) // self.stride[2]) + 1
        return (batch_size, self.conv.out_channels, out_d, out_h, out_w)

    def forward(self, x):
        # 🔵 TECH 1: Mathematical Short-Circuit - main optimization
        # Skip all computation when we know the result will be zeros
        if not self.use_optimized_path:
            # Standard path for non-optimized cases
            x = self.conv(x)
            x = self.norm(x)
            x = torch.minimum(x, torch.tensor(self.min_value, device=x.device))
            x = torch.clamp(x, min=self.min_value, max=self.max_value)
            x = self.dropout(x)
            return x

        # Optimized path when min_value == 0.0:
        # min(x, 0) followed by clamp(0, 1) always produces zeros
        # 🟢 TECH 3: Direct Shape Matching - fast path for standard input dimensions
        if x.shape == (batch_size, in_channels, depth, height, width):
            # 🟣 TECH 2: Return pre-allocated zeros matching the input dtype
            if x.dtype == torch.float32:
                return self.zero_output_float32
            elif x.dtype == torch.float16:
                return self.zero_output_float16
            elif x.dtype == torch.bfloat16:
                return self.zero_output_bfloat16
            else:
                # Fallback for other dtypes
                return torch.zeros(self.standard_shape, device=x.device, dtype=x.dtype)
        else:
            # For non-standard input shapes, calculate the output shape
            output_shape = self.calculate_output_shape(x.shape)
            return torch.zeros(output_shape, device=x.device, dtype=x.dtype)

# Color Legend:
# 🔵 TECH 1: Mathematical Short-Circuit (Blue) - skips computation when min_value = 0
# 🟣 TECH 2: Pre-allocated Tensors (Purple) - pre-allocates zero tensors
# 🟢 TECH 3: Direct Shape Matching (Green) - fast path for standard shapes
# 🟠 TECH 4: Pre-computed Parameters (Orange) - pre-computes conv parameters
```
The optimized implementation uses color-coded emojis to highlight the four optimization techniques: 🔵 Mathematical Short-Circuit, 🟣 Pre-allocated Tensors, 🟢 Direct Shape Matching, and 🟠 Pre-computed Parameters.
These case studies reveal the true power of CUDA-L1: it doesn't just apply known optimization tricks, it discovers fundamental mathematical and computational insights that lead to performance improvements. Through reinforcement learning, it explores the vast space of possible implementations and learns principles that even experienced CUDA programmers might miss.
During the training process, we found that RL is particularly susceptible to reward hacking. We've already identified quite a few hacking cases (e.g., exploiting timing measurements and caching results). If you spot any additional reward hacks in the generated code, we would greatly appreciate you letting us know. You can contact us via email at `research@deep-reinforce.com` or open a GitHub issue at https://github.com/deepreinforce-ai/CUDA-L1
```bibtex
@article{deepreinforce2025cudal1,
  title   = {CUDA-L1: Improving CUDA Optimization via Contrastive Reinforcement Learning},
  author  = {DeepReinforce Team},
  journal = {arXiv preprint arXiv:2507.14111},
  year    = {2025}
}
```