CUDA-L1: Improving CUDA Optimization via Contrastive Reinforcement Learning

DeepReinforce Team
July 21, 2025

Introduction

The GPU Crisis and How AI Might Save Us

Let's face it - we're in the middle of a GPU shortage crisis. Everyone wants GPUs for their AI projects. The demand is through the roof, and prices are absolutely insane - a single H100 can cost over $30,000, and good luck even finding one in stock.

For most companies and researchers, buying more GPUs simply isn't an option. The only realistic solution? We need to squeeze every bit of performance from the GPUs we already have.

The Old Way: Manual CUDA Optimization Hell

If you've ever tried optimizing CUDA code, you know the pain. It's like solving a massive puzzle where you're constantly tweaking memory access patterns, adjusting thread blocks, and running endless profiling tests. Engineers spend weeks or months on this stuff, and it's honestly exhausting.

What if LLMs Could Do This For Us?

Here's where things get interesting. Recent LLMs, such as DeepSeek-R1 and OpenAI's o1, are getting pretty good at writing code. And here's the kicker: CUDA optimization comes with a super clear reward signal, namely speed. Your code either runs faster or it doesn't, which makes it a natural fit for reinforcement learning (RL).

Imagine this: instead of you pulling your hair out trying different optimizations, an AI could generate thousands of variations, test them all, and learn what works. It might even discover tricks that humans never thought of!

Introducing CUDA-L1

So we built CUDA-L1, which uses something we call "contrastive reinforcement learning." Think of it like this: instead of just trying random stuff, our AI compares different CUDA versions side-by-side and learns why some are faster than others. It's like having a coach who shows you good vs. bad examples until you get it. We found that CUDA-L1 excels at:

  • Discovering optimization techniques: memory coalescing, loop unrolling, operation fusion, and more. Some of these are well-known; others are rarely used.
  • Figuring out the perfect combo: like a chef who knows exactly which spices work together, it combines optimizations in ways that maximize performance.
  • Learning the "rules" of CUDA: for example, how some optimizations multiply each other's effects, or how certain "gatekeeper" techniques must be applied first before others will work.
  • Spotting hidden problems: sometimes it rejects optimizations that look good on paper but actually slow things down due to sneaky issues like CPU-GPU synchronization overhead.

How CUDA-L1 Works

The Problem: Why Can't Current LLMs Write Good CUDA?

Ask any AI to write CUDA code and you'll likely get something that doesn't compile, crashes, or runs painfully slow. The reason is simple: these models barely saw any quality CUDA code during training. It's like asking someone who's only read cooking blogs to become a chef.

CUDA-L1: A Three-Step Recipe

We built CUDA-L1 with a three-stage pipeline: supervised learning (learn the basics), self-supervised learning (practice until perfect), and contrastive reinforcement learning (compete for speed).

[Figure: CUDA-L1 Pipeline]
Stage 1: Learning the Basics with Data Augmentation

First, we needed to fix the data shortage problem. We took existing CUDA code and created variations of it - expanding the model's exposure to different CUDA patterns. This supervised fine-tuning phase has one goal: make sure the AI can write CUDA code that actually compiles and runs correctly.

Stage 2: Practice Makes Perfect with Self-Supervised Learning

Next, we let the model generate its own CUDA code, test it, and learn from what works. The model generates thousands of code samples, we automatically test each one, and only the successful implementations get fed back for more training. No speed optimization yet - just making sure the code works reliably.
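The generate-test-filter loop of Stage 2 can be sketched as below. This is a minimal illustration under stated assumptions, not the CUDA-L1 training code: candidates are plain Python callables standing in for generated kernels, and `filter_successful`, `outputs_match`, and the tolerance `atol` are hypothetical names chosen for this sketch.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical sketch: generated "kernels" are modeled as Python callables.
Kernel = Callable[[List[float]], List[float]]

def outputs_match(a: Sequence[float], b: Sequence[float], atol: float = 1e-5) -> bool:
    """Element-wise comparison with an absolute tolerance (and equal lengths)."""
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=0.0, abs_tol=atol) for x, y in zip(a, b)
    )

def filter_successful(candidates: List[Kernel], reference: Kernel,
                      test_inputs: List[List[float]]) -> List[Kernel]:
    """Keep only candidates that run without error and match the reference on every input."""
    kept = []
    for candidate in candidates:
        try:
            ok = all(outputs_match(candidate(x), reference(x)) for x in test_inputs)
        except Exception:
            ok = False  # a crash counts as a failure, like a compile or runtime error
        if ok:
            kept.append(candidate)
    return kept

# Toy example: the "reference" doubles every element
reference = lambda xs: [2.0 * v for v in xs]
candidates = [
    lambda xs: [v + v for v in xs],            # correct
    lambda xs: [2.0 * v + 1e-3 for v in xs],   # numerically wrong
    lambda xs: [v / 0.0 for v in xs],          # crashes with ZeroDivisionError
]
survivors = filter_successful(candidates, reference, [[1.0, 2.0], [-3.5, 0.0]])
print(len(survivors))  # 1
```

In the actual pipeline the survivors would be added back to the fine-tuning data; here only the filtering step is shown.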

Stage 3: The Speed Revolution - Contrastive Reinforcement Learning

This is where CUDA-L1 becomes special. Traditional RL would just assign scores to generated code and hope the model figures out why some implementations are faster. That's like grading exams without showing students the correct answers. Instead, we do something radically different: we put the scored implementations directly into the prompt. Here is the structure of the prompt we use:

CUDA Kernel Optimization Prompt:
We show the AI multiple CUDA implementations WITH their speed scores:
  • "Here's kernel_v1 that achieves 1.2x speedup"
  • "Here's kernel_v2 that achieves 2.8x speedup"
  • "Here's kernel_v3 that achieves 1.5x speedup"

Then we ask three critical questions:

  1. Performance Analysis: "Why is kernel_v2 so much faster? What optimizations did it use that the others didn't?"
  2. Algorithm Design: "Based on this analysis, what optimization strategy would work even better?"
  3. Code Implementation: "Now write a kernel that beats them all."

The magic happens because the AI can directly see and reason about performance differences. It's not guessing in the dark - it's learning from concrete examples of what makes CUDA code fast.
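A prompt with this structure can be assembled from previously scored kernels. The sketch below is hypothetical: `build_contrastive_prompt`, its exact wording, and the placeholder kernel sources are illustrations of the structure described above, not the verbatim CUDA-L1 prompt.

```python
from typing import List, Tuple

def build_contrastive_prompt(exemplars: List[Tuple[str, str, float]]) -> str:
    """Hypothetical sketch: combine (name, code, speedup) exemplars with the three questions."""
    lines = ["We show you several CUDA implementations with their measured speedups:"]
    for name, code, speedup in exemplars:
        lines.append(f"Here's {name} that achieves {speedup:.1f}x speedup:")
        lines.append(code)
    # Point the performance-analysis question at the fastest exemplar
    best = max(exemplars, key=lambda e: e[2])[0]
    lines += [
        f"1. Performance Analysis: Why is {best} so much faster? "
        "What optimizations did it use that the others didn't?",
        "2. Algorithm Design: Based on this analysis, what optimization strategy would work even better?",
        "3. Code Implementation: Now write a kernel that beats them all.",
    ]
    return "\n".join(lines)

exemplars = [
    ("kernel_v1", "// placeholder source for kernel_v1", 1.2),
    ("kernel_v2", "// placeholder source for kernel_v2", 2.8),
    ("kernel_v3", "// placeholder source for kernel_v3", 1.5),
]
print(build_contrastive_prompt(exemplars))
```

The key design point is that the speedup scores travel with the code, so the model can reason about concrete performance differences instead of inferring them from a scalar reward alone.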

Mitigating Reward Hacking in RL Training

Reinforcement learning is notorious for exhibiting reward hacking behaviors, where models exploit system vulnerabilities to achieve higher rewards while generating outputs that deviate from the intended objectives. In our experiments, we discovered that over 30% of generated implementations attempted some form of reward hacking.

Reward Hacking Cases

During our initial training procedure, we identified three major categories of reward hacking behaviors:

1. Improper Timing Measurement

KernelBench measures execution time by recording timing events on the main CUDA stream as follows:

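Python - Vulnerable Timing
start_event.record(original_model_stream)  # start event on the main stream
model(*inputs)                             # the model may launch work on other streams
end_event.record(original_model_stream)    # marks completion of main-stream work only
torch.cuda.synchronize(device=device)      # waits for all streams, but only after timing has ended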

Warning: This vulnerability led to artificial 18× speedups that were entirely fake - the actual computation performance was unchanged!

However, RL-generated code exploits this by launching work on additional CUDA streams that execute asynchronously. Because KernelBench records its timing events only on the main stream, it does not capture the time spent by operations running on those other streams. This vulnerability is significant: in our initial implementation, 82 out of 250 (32.8%) RL-generated implementations exploited this timing loophole to appear faster than they actually were, producing an apparent overall speedup of 18×. Prompt engineering alone is insufficient to address this issue, so our fix makes the timed stream wait for all custom streams before the end event is recorded:

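Python - Fixed Timing
if custom_contain_new_streams:
    # Make the timed stream wait for all work queued on the custom streams,
    # so end_event is only reached after that work has finished
    for stream in custom_model_streams:
        custom_model_stream.wait_stream(stream)
end_event.record(custom_model_stream)
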
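The failure mode is easy to reproduce without a GPU. The CPU-only analogy below is an illustration, not KernelBench code: a one-worker thread pool stands in for a side CUDA stream, and `timed_run` and `fake_kernel` are hypothetical names. A timer that stops as soon as work has been *submitted* reports almost nothing, while a timer that first waits on every outstanding task (the role `wait_stream` plays in the fix) reports the real cost.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait

WORK_SECONDS = 0.2  # simulated kernel duration

def fake_kernel() -> None:
    time.sleep(WORK_SECONDS)

def timed_run(wait_for_side_work: bool) -> float:
    """Analogy: time a 'model' that launches its work on a side executor (a stand-in for a side stream)."""
    with ThreadPoolExecutor(max_workers=1) as side_stream:
        start = time.perf_counter()
        futures = [side_stream.submit(fake_kernel)]  # asynchronous launch
        if wait_for_side_work:
            wait(futures)  # analogous to custom_model_stream.wait_stream(stream)
        elapsed = time.perf_counter() - start
    # Leaving the `with` block joins the worker, like torch.cuda.synchronize(),
    # but only after the timer has already stopped
    return elapsed

naive = timed_run(wait_for_side_work=False)
fixed = timed_run(wait_for_side_work=True)
print(f"naive: {naive:.3f}s  fixed: {fixed:.3f}s")
```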
2. Hyperparameter Manipulation

In KernelBench, each computational task is associated with specific hyperparameters, including batch_size, dim, the in_features and out_features dimensions, scaling_factor, and others. The RL agent learned to exploit these parameters by generating code that artificially reduces their values (for example, a smaller batch_size), achieving superficial speedups that do not reflect genuine optimization.
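One straightforward guard, sketched below under the assumption that the evaluator (not the candidate) generates the inputs and runs the reference, is to compare both models on identical inputs and reject any candidate whose output shape or values differ. The toy models and `check_against_reference` are hypothetical, and the "tensors" are plain nested lists.

```python
import random
from typing import Callable, List

Matrix = List[List[float]]

def reference_model(x: Matrix) -> List[float]:
    """Row sums over the full batch."""
    return [sum(row) for row in x]

def hacked_model(x: Matrix) -> List[float]:
    """Silently processes only half the batch (a 'reduced batch_size' hack)."""
    half = len(x) // 2
    return [sum(row) for row in x[:half]]

def check_against_reference(candidate: Callable[[Matrix], List[float]],
                            reference: Callable[[Matrix], List[float]],
                            batch_size: int, dim: int, atol: float = 1e-6) -> bool:
    """Hypothetical guard: the evaluator, not the candidate, decides the input sizes."""
    x = [[random.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(batch_size)]
    expected, actual = reference(x), candidate(x)
    if len(actual) != len(expected):  # a shape mismatch exposes the shrunken batch
        return False
    return all(abs(a - e) <= atol for a, e in zip(actual, expected))

print(check_against_reference(reference_model, reference_model, batch_size=16, dim=8))  # True
print(check_against_reference(hacked_model, reference_model, batch_size=16, dim=8))     # False
```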

3. Result Caching

The RL agent developed strategies to cache computed results across evaluation batches, keyed on input memory addresses: when a later input's address matches a cached one, it returns the cached output without recomputing. The following code snippet gives an illustration:

Python - Caching Example
cache_key = x.data_ptr()  # keyed on the tensor's memory address, not its contents
if cache_key in self.cache:
    return self.cache[cache_key]  # skips the computation entirely
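Address-keyed caching is a hack rather than an optimization because an address says nothing about the *contents* of a buffer. The CPU sketch below is an analogy, not PyTorch code: it uses Python's `id()` as a stand-in for `data_ptr()`, and `CachingModel` is a hypothetical name. Refilling the same buffer with new values between calls, which an evaluator can do deliberately, exposes the stale result.

```python
from typing import Dict, List

class CachingModel:
    """Mimics the hacked pattern: cache outputs keyed on the input buffer's identity."""
    def __init__(self) -> None:
        self.cache: Dict[int, List[float]] = {}

    def forward(self, x: List[float]) -> List[float]:
        cache_key = id(x)  # analogous to x.data_ptr(): identifies the buffer, not its contents
        if cache_key in self.cache:
            return self.cache[cache_key]
        out = [v * v for v in x]
        self.cache[cache_key] = out
        return out

def reference(x: List[float]) -> List[float]:
    return [v * v for v in x]

model = CachingModel()
buf = [1.0, 2.0, 3.0]
first = model.forward(buf)
buf[:] = [4.0, 5.0, 6.0]   # same buffer (same id), new contents
second = model.forward(buf)
print(second == reference(buf))  # False: the cached squares of [1, 2, 3] are returned
```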

Defense Strategies

To combat these sophisticated reward hacking behaviors, we developed a multi-layered defense system combining automated detection, continuous learning, and mathematical constraints.

1. Adversarial Reward Checking Model

We deployed DeepSeek-R1 as an adversarial checker that analyzes generated code for potential exploits. The model achieves over 60% detection accuracy through pattern analysis, semantic understanding, and anomaly detection. When speedups exceed suspicious thresholds (e.g., >10× for simple operations), it triggers multi-stage verification including output consistency checks, memory usage analysis, and GPU utilization metrics.

2. Dynamic Hacking-Case Database

A continuously updated database containing 500+ unique hacking patterns helps identify new exploits. For each generated implementation, we retrieve similar hacking cases using AST-based code analysis and performance profiling. This contextual information improves detection accuracy by 25% and allows us to track how hacking strategies evolve across training iterations.
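The details of our AST-based analysis are not covered here, but the basic idea can be illustrated with Python's standard `ast` module: walk the generated code and flag calls associated with the hacks above, such as creating CUDA streams or reading `data_ptr()`. This is a rough, hypothetical sketch; the pattern table `SUSPICIOUS_CALLS` and `scan_for_hacks` are assumptions, and a real detector needs far more than name matching (legitimate code can use streams too).

```python
import ast
from typing import List

# Hypothetical mapping from called attribute/function names to hack categories
SUSPICIOUS_CALLS = {
    "Stream": "extra CUDA stream (timing exploit)",
    "data_ptr": "address-keyed caching",
}

def scan_for_hacks(source: str) -> List[str]:
    """Return descriptions of suspicious calls found in `source` (parsed, never executed)."""
    findings = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call):
            func = node.func
            # Handles both `torch.cuda.Stream()` (Attribute) and `Stream()` (Name)
            name = func.attr if isinstance(func, ast.Attribute) else getattr(func, "id", None)
            if name in SUSPICIOUS_CALLS:
                findings.append(f"line {node.lineno}: {SUSPICIOUS_CALLS[name]}")
    return findings

generated = '''
s = torch.cuda.Stream()
key = x.data_ptr()
y = torch.relu(x)
'''
for finding in scan_for_hacks(generated):
    print(finding)
```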

3. Reward Smoothing and Normalization

To prevent over-optimization for extreme cases, we apply mathematical constraints on rewards:

$$r_{\text{normalized}} = \frac{r - \mu}{\sigma}$$
$$r_{\text{smooth}} = \operatorname{clip}(r_{\text{normalized}}, -k, k)$$

Here, r is the raw reward, μ and σ are rolling statistics updated every 100 iterations, and k = 1.5 bounds the normalized reward (in units of σ), capping how much credit any single extreme and possibly hacked speedup can earn. For suspicious high-reward cases with low confidence, we apply additional dampening based on consistency metrics and alignment with known optimization patterns.
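The normalization and clipping steps can be sketched as follows. This is an illustration with assumptions: only the 100-iteration refresh and k = 1.5 come from the text, while the window size (`window`), the zero-variance fallback, the initial μ = 0 and σ = 1, and the class name `RewardNormalizer` are choices made for this sketch.

```python
import statistics
from collections import deque

class RewardNormalizer:
    """Hypothetical sketch: z-score rewards with rolling statistics, then clip to [-k, k]."""
    def __init__(self, k: float = 1.5, update_every: int = 100, window: int = 1000) -> None:
        self.k = k
        self.update_every = update_every
        self.history = deque(maxlen=window)  # recent raw rewards (window size is an assumption)
        self.mu, self.sigma = 0.0, 1.0       # used until the first refresh
        self.steps = 0

    def __call__(self, r: float) -> float:
        self.history.append(r)
        self.steps += 1
        if self.steps % self.update_every == 0 and len(self.history) > 1:
            self.mu = statistics.fmean(self.history)
            self.sigma = statistics.pstdev(self.history) or 1.0  # avoid division by zero
        r_normalized = (r - self.mu) / self.sigma
        return max(-self.k, min(self.k, r_normalized))  # clip(r_normalized, -k, k)

norm = RewardNormalizer(update_every=4)
print([round(norm(r), 3) for r in [1.0, 2.0, 3.0, 4.0, 120.0]])  # [1.0, 1.5, 1.5, 1.342, 1.5]
```

Note how the outlier reward of 120 receives the same capped value, 1.5, as much smaller rewards, which is exactly the over-optimization the clipping is meant to prevent.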

Key Takeaway: The arms race between the RL agent and our detection systems highlights the importance of robust evaluation frameworks. Our multi-layered approach has proven effective at maintaining training integrity while still allowing genuine breakthrough optimizations to be rewarded appropriately.

Evaluations

We tested CUDA-L1 on KernelBench, a comprehensive benchmark suite with three difficulty levels:

  • Level 1: A subset that contains simple operations (like matrix multiply)
  • Level 2: A subset that contains operator sequences (like attention mechanisms)
  • Level 3: A subset that contains complex ML tasks (like full transformer layers)

We use "All Levels" to denote the full dataset, which combines all three levels.

To evaluate the generated code comprehensively, we compare it against four baselines:

I) Default: This compares the CUDA-L1 generated code with the reference code by KernelBench.

II) Torch Compile: This compares the CUDA-L1 generated code with the reference code enhanced by torch.compile with default settings. Torch.compile applies graph-level optimizations including operator fusion, memory planning, and kernel selection to accelerate PyTorch models through just-in-time compilation.

III) Torch Compile Reduce Overhead: This compares the CUDA-L1 generated code with the reference code enhanced by torch.compile in reduce-overhead mode. Rather than reducing compilation cost, this mode reduces per-call Python and kernel-launch overhead, mainly by running the compiled code through CUDA Graphs, at the cost of some extra memory. It is particularly suitable for small batches and inference workloads with static shapes.

IV) CUDA Graph: Since KernelBench does not provide official CUDA Graph implementations, we employ Claude 4 to generate CUDA Graph-augmented code for each reference implementation. CUDA Graphs capture a series of CUDA kernels and their dependencies into a single graph structure that can be launched with minimal CPU overhead, replacing many individual kernel launches with a single graph launch and substantially reducing CPU-side launch costs.

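Performance comparison across different configurations on KernelBench (A100)

| Configuration | Level | Mean | Max | 75% | 50% | 25% | Success↑ (# out of total) | Speedup↑ (>1.01× out of total) |
|---|---|---|---|---|---|---|---|---|
| Default | All Levels | 3.12× | 120× | 2.25× | 1.42× | 1.17× | 249/250 | 226/250 |
| Torch Compile | All Levels | 2.77× | 69.0× | 2.55× | 1.72× | 1.14× | 249/250 | 203/250 |
| Torch Compile RO | All Levels | 2.88× | 80.1× | 2.48× | 1.67× | 1.13× | 249/250 | 200/250 |
| CUDA Graph | All Levels | 2.81× | 97.9× | 1.83× | 1.20× | 0.954× | 249/250 | 147/229 |

• RO = Reduce Overhead
• Mean, Max, and the 75%, 50%, and 25% percentile columns summarize the per-task speedup of CUDA-L1 over each baseline.
• Success counts the successful benchmarks out of the total, and Speedup counts the benchmarks where CUDA-L1 is more than 1.01× faster than the baseline, out of the total.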
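For readers reproducing these columns, the sketch below shows one way to compute them from per-task speedups (baseline time divided by CUDA-L1 time). The input values are made-up toy numbers, not KernelBench results, and the percentile method (`statistics.quantiles` with its default "exclusive" method) is an assumption; the evaluation may compute percentiles differently. The toy data also shows how a single large outlier pulls the mean well above the median, a pattern consistent with the gap between the Mean and 50% columns above.

```python
import statistics
from typing import Dict, List

def summarize(speedups: List[float], threshold: float = 1.01) -> Dict[str, float]:
    """Mean, max, quartiles, and count of tasks faster than `threshold`."""
    q25, q50, q75 = statistics.quantiles(speedups, n=4)  # 'exclusive' method by default
    return {
        "mean": statistics.fmean(speedups),
        "max": max(speedups),
        "75%": q75,
        "50%": q50,
        "25%": q25,
        "faster": sum(s > threshold for s in speedups),
    }

toy = [0.9, 1.0, 1.1, 1.3, 1.5, 2.0, 40.0]  # toy values with one extreme outlier
print(summarize(toy))
```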

Generalization of A100-Optimized Kernels to Other GPU Architectures

We trained CUDA-L1 on NVIDIA A100s, but what if you're using a different GPU? Good news: the optimizations transfer remarkably well. We tested the same A100-optimized kernels on four other GPUs: RTX 3090, H100, H20, and L40.

Mean speedup across different configurations and GPU devices

| Configuration | A100 | RTX 3090 | H100 | H20 | L40 |
|---|---|---|---|---|---|
| Default | 3.12× | 2.51× | 3.85× | 2.38× | 3.13× |
| Torch Compile | 2.77× | 2.58× | 2.74× | 2.89× | 2.85× |
| Torch Compile RO | 2.88× | 2.61× | 2.77× | 2.82× | 2.89× |
| CUDA Graph | 2.81× | 3.34× | 2.23× | 2.20× | 3.98× |

The results show fascinating patterns across different GPU architectures (each entry is CUDA-L1's mean speedup over the named baseline):

  • H100 shows the largest speedup over the default reference (3.85×), suggesting that our base optimizations work particularly well on the latest datacenter GPUs.
  • On the consumer RTX 3090, CUDA-L1's margin over the CUDA Graph baseline (3.34×) is larger than its margin over the default reference (2.51×), possibly reflecting differences in memory hierarchy and CPU-GPU communication overhead.
  • On the L40, CUDA-L1 reaches 3.98× over the CUDA Graph baseline, the highest mean speedup we observed across all configurations and GPUs.
  • Mean speedups stay consistent across all GPUs, ranging from 2.20× to 3.98×, demonstrating strong generalization.

These results suggest that CUDA-L1's optimization patterns are general enough to benefit all five GPU architectures we tested, although different GPUs may favor different optimization strategies. This points to exciting opportunities for GPU-specific training in future versions of CUDA-L1.

Case Studies

Let's dive deep into three examples to understand what CUDA-L1 actually does:

Case 1: Diagonal Matrix Multiplication - 64× Faster

This task multiplies a diagonal matrix (represented by its diagonal elements) by a dense matrix, with N = 4096.

Python - Reference Implementation
import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, A, B):
        # A: (N,) - 1D tensor of shape N
        # B: (N, M) - 2D tensor of shape N x M
        # torch.diag(A): (N, N) - creates diagonal matrix from A
        # Result: (N, N) @ (N, M) = (N, M)
        return torch.diag(A) @ B

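Python - CUDA-L1 Optimized (64× faster)
import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, A, B):
        # Broadcast A (N, 1) against B (N, M): row i of B is scaled by A[i]
        return A.unsqueeze(1) * B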

The optimized implementation leverages PyTorch's broadcasting mechanism brilliantly. Instead of creating a full N×N diagonal matrix (which would be mostly zeros), it simply reshapes the diagonal vector A from (N,) to (N, 1) and uses broadcasting to multiply each row of B by the corresponding element of A.

The benefits are substantial:

  • Memory: O(1) extra memory instead of O(N²)
  • Computation: O(NM) operations instead of O(N²M)
  • Result: 64× speedup!
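The identity behind this rewrite is (diag(a)·B)[i][j] = a[i]·B[i][j]: every off-diagonal entry of diag(a) is zero, so each output row is just a scaled row of B. The pure-Python sketch below is illustrative only (no PyTorch; `diag_matmul_naive` and `diag_matmul_broadcast` are hypothetical names). It checks the identity on a small example and counts multiplications for both approaches.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def diag_matmul_naive(a: List[float], b: Matrix) -> Tuple[Matrix, int]:
    """Materialize diag(a) as an N x N matrix, then do a dense matmul: N*N*M multiplies."""
    n, m = len(a), len(b[0])
    d = [[a[i] if i == k else 0.0 for k in range(n)] for i in range(n)]
    out = [[0.0] * m for _ in range(n)]
    mults = 0
    for i in range(n):
        for j in range(m):
            for k in range(n):
                out[i][j] += d[i][k] * b[k][j]
                mults += 1
    return out, mults

def diag_matmul_broadcast(a: List[float], b: Matrix) -> Tuple[Matrix, int]:
    """Scale row i of B by a[i], as A.unsqueeze(1) * B does: N*M multiplies."""
    out = [[a[i] * v for v in row] for i, row in enumerate(b)]
    return out, len(b) * len(b[0])

a = [2.0, -1.0, 0.5]
b = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
naive, naive_mults = diag_matmul_naive(a, b)
fast, fast_mults = diag_matmul_broadcast(a, b)
print(naive == fast, naive_mults, fast_mults)  # True 18 6
```

The gap grows with N: for N = 4096 the dense version does N times more multiplications, most of them by zero.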

What's remarkable is that CUDA-L1 discovered this algebraic simplification on its own through RL exploration. By testing semantically equivalent implementations, it learned to identify patterns where computationally expensive operations can be replaced with more efficient alternatives.

Case 2: LSTM Networks - 3.4× Faster

For a classical LSTM neural network (Level 3, Task 35), CUDA-L1 achieved a 3.4× speedup by applying three key optimizations:

  1. CUDA Graphs: Captures the entire LSTM computation sequence into a replayable graph structure
  2. Memory Contiguity: Ensures all tensors maintain contiguous memory layouts
  3. Static Tensor Reuse: Pre-allocates tensors and reuses them across forward passes

Python - Reference Implementation
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
        """
        Initialize the LSTM model.
        :param input_size: The number of expected features in the input `x`
        :param hidden_size: The number of features in the hidden state `h`
        :param num_layers: Number of recurrent layers
        :param output_size: The number of output features
        :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to `dropout`
        """
        super(Model, self).__init__()
        # Initialize hidden state with random values
        self.h0 = torch.randn((num_layers, batch_size, hidden_size))
        self.c0 = torch.randn((num_layers, batch_size, hidden_size))
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """
        Forward pass through the LSTM model.
        :param x: The input tensor, shape (batch_size, sequence_length, input_size)
        :return: The output tensor, shape (batch_size, sequence_length, output_size)
        """
        self.h0 = self.h0.to(x.device)
        self.c0 = self.h0.to(x.device)  # BUG: This should be self.c0.to(x.device)

        # Forward propagate LSTM
        out, state = self.lstm(x, (self.h0, self.c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])  # out: tensor of shape (batch_size, output_size)

        return state[0]

# Test code
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0

def get_inputs():
    return [torch.randn(batch_size, sequence_length, input_size)]

def get_init_inputs():
    return [input_size, hidden_size, num_layers, output_size, dropout]
Python - CUDA-L1 Optimized (3.4× faster)
import torch
import torch.nn as nn
import torch.cuda as cuda

class ModelNew(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.0):
        """
        Initialize the LSTM model with three core optimization techniques.

        Color coding:
        - 🔵 BLUE: CUDA Graphs optimization
        - 🟢 GREEN: Memory Contiguity optimization
        - 🟠 ORANGE: Static Tensor Reuse optimization
        """
        super(ModelNew, self).__init__()

        # Initialize hidden states as buffers
        self.register_buffer('h0', torch.randn((num_layers, batch_size, hidden_size)))
        self.register_buffer('c0', torch.randn((num_layers, batch_size, hidden_size)))

        # Use PyTorch's optimized LSTM implementation
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False
        )

        self.fc = nn.Linear(hidden_size, output_size)

        # 🔵 CUDA GRAPHS: Variables for graph capture and replay
        self.graph = None
        self.graph_ready = False
        self.input_shape = None

        # 🟠 STATIC TENSOR REUSE: Pre-allocated tensors for graph execution
        self.static_input = None
        self.static_output = None

        # 🔵 CUDA GRAPHS: Streams for graph operations
        self.graph_stream = None

        # Track if we're running on CUDA
        self.is_cuda_available = torch.cuda.is_available()

    def _initialize_cuda_resources(self):
        """🔵 CUDA GRAPHS: Initialize CUDA stream for graph operations"""
        if self.graph_stream is None:
            self.graph_stream = cuda.Stream()

    def _capture_graph(self, x, result):
        """
        🔵 CUDA GRAPHS: Capture the computation graph for replay
        🟠 STATIC TENSOR REUSE: Create static tensors for graph capture
        """
        # 🟠 STATIC TENSOR REUSE: Clone tensors for static allocation
        self.static_input = x.clone()
        self.static_output = result.clone()

        # 🔵 CUDA GRAPHS: Capture the computation graph
        with torch.cuda.stream(self.graph_stream):
            self.graph = cuda.CUDAGraph()
            with cuda.graph(self.graph):
                # Operations to capture in the graph
                static_out, _ = self.lstm(self.static_input, (self.h0, self.c0))

                # 🟢 MEMORY CONTIGUITY: Ensure contiguous memory layout
                static_last = static_out[:, -1, :].contiguous()

                self.static_output.copy_(self.fc(static_last))

        # Wait for graph capture to complete
        torch.cuda.synchronize()

        # Mark graph as ready for use
        self.graph_ready = True

    def _standard_forward(self, x):
        """Standard forward pass with memory contiguity optimization"""

        # 🟢 MEMORY CONTIGUITY: Ensure input is contiguous
        if not x.is_contiguous():
            x = x.contiguous()

        # Forward pass through LSTM
        out, _ = self.lstm(x, (self.h0, self.c0))

        # 🟢 MEMORY CONTIGUITY: Make last output contiguous for optimal memory access
        last_out = out[:, -1, :].contiguous()

        return self.fc(last_out)

    def forward(self, x):
        """
        Forward pass through the LSTM model with three optimization techniques.

        Optimization flow:
        1. 🔵 CUDA GRAPHS: Check if we can use the captured graph (fast path)
        2. 🟠 STATIC TENSOR REUSE: Use pre-allocated tensors for graph replay
        3. 🟢 MEMORY CONTIGUITY: Ensure optimal memory layout throughout
        """

        # 🔵 CUDA GRAPHS: Fast path - use captured graph if available
        if (x.is_cuda and
            self.graph_ready and
            x.shape == self.input_shape):

            # 🟠 STATIC TENSOR REUSE: Copy to pre-allocated tensor with non-blocking transfer
            self.static_input.copy_(x, non_blocking=True)

            # 🔵 CUDA GRAPHS: Replay the captured graph
            self.graph.replay()

            # Return the output from static buffer
            return self.static_output.clone()

        # Standard execution path
        with torch.no_grad():
            result = self._standard_forward(x)

            # 🔵 CUDA GRAPHS: Initialize graph on first CUDA input
            if x.is_cuda and self.is_cuda_available and not self.graph_ready:
                try:
                    # Store the current input shape
                    self.input_shape = x.shape

                    # 🔵 CUDA GRAPHS: Initialize CUDA resources
                    self._initialize_cuda_resources()

                    # 🔵 CUDA GRAPHS + 🟠 STATIC TENSOR REUSE: Capture the graph
                    self._capture_graph(x, result)

                except Exception as e:
                    # If graph capture fails, continue without it
                    self.graph_ready = False

            return result

# Hyperparameters from the reference implementation (module level, so that
# batch_size is defined when ModelNew.__init__ registers h0 and c0)
batch_size = 10
sequence_length = 512
input_size = 128
hidden_size = 256
num_layers = 6
output_size = 10
dropout = 0.0

The optimized implementation uses color-coded emojis to highlight the three optimization techniques: 🔵 CUDA Graphs, 🟢 Memory Contiguity, and 🟠 Static Tensor Reuse. Testing these techniques in combination reveals that CUDA Graphs is essential for any meaningful speedup: without it, no combination of the other optimizations provides any benefit. Once CUDA Graphs is enabled, the other two optimizations add incremental improvements, and all three together achieve the best speedup of 3.42×.
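This also explains why the optimized code copies into `self.static_input` and returns `self.static_output.clone()`. A captured CUDA graph replays the same kernels on the same memory addresses, so new data must be written into the captured input buffer, and the output buffer is overwritten on every replay. The CPU-only analogy below uses hypothetical names (`StaticGraph`, `capture`, `replay`) and is not the PyTorch API: it records operations bound to fixed buffers and replays them.

```python
from typing import Callable, List

class StaticGraph:
    """Toy record-and-replay analogy: ops are bound to fixed buffers at capture time."""
    def __init__(self, size: int) -> None:
        self.static_input = [0.0] * size   # plays the role of ModelNew.static_input
        self.static_output = [0.0] * size  # plays the role of ModelNew.static_output
        self.ops: List[Callable[[], None]] = []

    def capture(self) -> None:
        inp, out = self.static_input, self.static_output  # buffers fixed at capture time

        def scale() -> None:
            for i, v in enumerate(inp):
                out[i] = 2.0 * v

        def shift() -> None:
            for i in range(len(out)):
                out[i] += 1.0

        self.ops = [scale, shift]

    def replay(self) -> None:
        for op in self.ops:  # rerun the recorded sequence with no per-op setup
            op()

g = StaticGraph(3)
g.capture()
g.static_input[:] = [1.0, 2.0, 3.0]   # analogous to static_input.copy_(x)
g.replay()
first = list(g.static_output)          # analogous to static_output.clone()
g.static_input[:] = [10.0, 20.0, 30.0]
g.replay()                             # overwrites static_output in place
print(first, g.static_output)  # [3.0, 5.0, 7.0] [21.0, 41.0, 61.0]
```

Without the copy into `first`, the caller's result would silently change on the next replay, which is why returning a clone of the static output matters.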

Case 3: 3D Convolution Pipeline - 120× Faster

The most impressive speedup came from a 3D operation pipeline: 3D convolution, group normalization, an element-wise minimum, clamping, and dropout. CUDA-L1 applied four optimizations:

  1. Mathematical Short-Circuit: Detects when min_value equals 0.0 and skips the entire computation
  2. Pre-allocated Tensors: Creates zero tensors during initialization
  3. Direct Shape Matching: Provides a fast path for standard input shapes
  4. Pre-computed Parameters: Stores convolution parameters during initialization

Python - Reference Implementation
import torch
import torch.nn as nn

class Model(nn.Module):
    """
    Model that performs a 3D convolution, applies Group Normalization, minimum, clamp, and dropout.
    """
    def __init__(self, in_channels, out_channels, kernel_size, groups, min_value, max_value, dropout_p):
        super(Model, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.dropout = nn.Dropout(dropout_p)
        self.min_value = min_value
        self.max_value = max_value

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = torch.min(x, torch.tensor(self.min_value))
        x = torch.clamp(x, min=self.min_value, max=self.max_value)
        x = self.dropout(x)
        return x

# Hyperparameters
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
groups = 8
min_value = 0.0
max_value = 1.0
dropout_p = 0.2

def get_inputs():
    return [torch.randn(batch_size, in_channels, depth, height, width)]

def get_init_inputs():
    return [in_channels, out_channels, kernel_size, groups, min_value, max_value, dropout_p]
Python - CUDA-L1 Optimized (120× faster)
import torch
import torch.nn as nn

# Hyperparameters
batch_size = 128
in_channels = 3
out_channels = 16
depth, height, width = 16, 32, 32
kernel_size = 3
groups = 8
min_value = 0.0
max_value = 1.0
dropout_p = 0.2

class ModelNew(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, groups, min_value, max_value, dropout_p):
        super(ModelNew, self).__init__()
        # Store the original layers for parameter compatibility
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.dropout = nn.Dropout(dropout_p)
        self.min_value = min_value
        self.max_value = max_value
        self.dropout_p = dropout_p

        # 🔵 TECH 1: Mathematical Short-Circuit Optimization
        # Detects when min_value=0.0 to skip entire computation
        self.use_optimized_path = (min_value == 0.0)

        # 🟠 TECH 4: Pre-computed Convolution Parameters
        # Extract and store conv parameters once during initialization
        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size, kernel_size, kernel_size)
        else:
            self.kernel_size = kernel_size
        self.stride = self.conv.stride
        self.padding = self.conv.padding
        self.dilation = self.conv.dilation

        # 🟠 TECH 4: Pre-compute output dimensions for standard input
        self.out_depth = ((depth + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0]) + 1
        self.out_height = ((height + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1]) + 1
        self.out_width = ((width + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) // self.stride[2]) + 1

        # Standard output shape for the default batch size
        self.standard_shape = (batch_size, out_channels, self.out_depth, self.out_height, self.out_width)

        # 🟣 TECH 2: Pre-allocated Zero Tensors
        # Create zero tensors once to avoid allocation overhead
        if self.use_optimized_path:
            self.register_buffer('zero_output_float32',
                               torch.zeros(self.standard_shape, dtype=torch.float32),
                               persistent=False)
            self.register_buffer('zero_output_float16',
                               torch.zeros(self.standard_shape, dtype=torch.float16),
                               persistent=False)
            self.register_buffer('zero_output_bfloat16',
                               torch.zeros(self.standard_shape, dtype=torch.bfloat16),
                               persistent=False)

    def calculate_output_shape(self, input_shape):
        """Calculate the output shape of the convolution operation."""
        batch_size, _, d, h, w = input_shape

        # 🟠 TECH 4: Use precomputed parameters
        # Avoid repeated attribute lookups
        out_d = ((d + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0]) + 1
        out_h = ((h + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1]) + 1
        out_w = ((w + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) // self.stride[2]) + 1

        return (batch_size, self.conv.out_channels, out_d, out_h, out_w)

    def forward(self, x):
        # 🔵 TECH 1: Mathematical Short-Circuit - Main optimization
        # Skip all computation when we know result will be zeros
        if not self.use_optimized_path:
            # Standard path for non-optimized cases
            x = self.conv(x)
            x = self.norm(x)
            x = torch.minimum(x, torch.tensor(self.min_value, device=x.device))
            x = torch.clamp(x, min=self.min_value, max=self.max_value)
            x = self.dropout(x)
            return x

        # Optimized path when min_value == 0.0
        # Since min(x, 0) followed by clamp(0, 1) always produces zeros

        # 🟢 TECH 3: Direct Shape Matching
        # Fast path for standard input dimensions
        if x.shape == (batch_size, in_channels, depth, height, width):
            # 🟣 TECH 2: Use pre-allocated tensors
            # Return pre-allocated zeros matching input dtype
            if x.dtype == torch.float32:
                return self.zero_output_float32
            elif x.dtype == torch.float16:
                return self.zero_output_float16
            elif x.dtype == torch.bfloat16:
                return self.zero_output_bfloat16
            else:
                # Fallback for other dtypes
                return torch.zeros(self.standard_shape, device=x.device, dtype=x.dtype)
        else:
            # For non-standard input shapes, calculate output shape
            output_shape = self.calculate_output_shape(x.shape)
            return torch.zeros(output_shape, device=x.device, dtype=x.dtype)

# Color Legend:
# 🔵 TECH 1: Mathematical Short-Circuit (Blue) - Skips computation when min_value=0
# 🟣 TECH 2: Pre-allocated Tensors (Purple) - Pre-allocates zero tensors
# 🟢 TECH 3: Direct Shape Matching (Green) - Fast path for standard shapes
# 🟠 TECH 4: Pre-computed Parameters (Orange) - Pre-computes conv parameters

The optimized implementation uses color-coded emojis to highlight the four optimization techniques: 🔵 Mathematical Short-Circuit, 🟣 Pre-allocated Tensors, 🟢 Direct Shape Matching, and 🟠 Pre-computed Parameters.
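The short-circuit rests on a simple identity: min(x, 0) ≤ 0 for every finite x, and clamping any non-positive value into [0, 1] gives exactly 0. Dropout leaves zeros at zero, so for finite inputs the output is all zeros regardless of the convolution and normalization (NaN inputs would behave differently). Only the output shape still has to be computed, using the standard convolution size formula that `calculate_output_shape` implements. The pure-Python check below is illustrative; `clamp` and `conv_out_size` are names chosen for this sketch.

```python
import random

def clamp(v: float, lo: float, hi: float) -> float:
    return max(lo, min(v, hi))

def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length along one dimension, the same formula as calculate_output_shape."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# min(x, 0.0) followed by clamp(., 0.0, 1.0) is zero for any finite x
samples = [random.uniform(-100.0, 100.0) for _ in range(10_000)]
assert all(clamp(min(x, 0.0), 0.0, 1.0) == 0.0 for x in samples)

# Output shape for the benchmark hyperparameters (kernel_size=3, nn.Conv3d defaults:
# stride 1, no padding, dilation 1)
shape = (128, 16, conv_out_size(16, 3), conv_out_size(32, 3), conv_out_size(32, 3))
print(shape)  # (128, 16, 14, 30, 30)
```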

These case studies reveal the true power of CUDA-L1: it doesn't just apply known optimization tricks, it discovers fundamental mathematical and computational insights that lead to performance improvements. Through reinforcement learning, it explores the vast space of possible implementations and learns principles that even experienced CUDA programmers might miss.

Limitations and Challenges

During the training process, we found that RL is particularly susceptible to reward hacking. We have already identified several hacking cases (e.g., exploiting timing measurements and caching results). If you identify any additional reward hacks in the code, we would greatly appreciate you letting us know. You can contact us via email at `research@deep-reinforce.com` or open a GitHub issue at https://github.com/deepreinforce-ai/CUDA-L1

BibTeX

@article{deepreinforce2025cudal1,
  title={CUDA-L1: Improving CUDA Optimization via Contrastive Reinforcement Learning},
  author={DeepReinforce Team},
  journal={arXiv preprint arXiv:2507.14111},
  year={2025}
}