CRINN: Contrastive Reinforcement Learning for Approximate Nearest Neighbor Search

DeepReinforce Team
August 3, 2025

🥳 News & Updates

[August 3, 2025] The repository is live, and we have released the first version of CRINN!

📋 Upcoming Features

  • Add RL version based on ParlayANN
  • Incorporate both Euclidean and angular distances as RL training rewards

Introduction

Why Is Vector Search So Slow?

    If you've used ChatGPT with web search, asked an AI agent to find something in your documents, or experienced any modern AI application that "remembers" things, you've benefited from vector search. Behind the scenes, these systems convert your text into high-dimensional vectors and search through millions or billions of similar vectors to find relevant information. The problem is that vector search is computationally expensive—really expensive. Finding exact nearest neighbors in high dimensions is so slow it's practically unusable for real applications. That's why everyone uses approximate nearest neighbor search (ANNS), trading a bit of accuracy for massive speedups.

    The Manual Optimization Nightmare

    Optimizing ANNS algorithms is like tuning a Formula 1 car—there are hundreds of knobs to turn, and each adjustment affects everything else. Engineers spend months profiling cache misses and memory access patterns, hand-tuning graph construction parameters, experimenting with prefetching strategies, and balancing between search accuracy and speed. And here's the kicker: what works great on one dataset might be terrible on another. What runs fast on your laptop might crawl on a server. It's an endless optimization treadmill that requires deep expertise in computer architecture, parallel programming, and the mathematical properties of ANNS algorithms.

    Introduce CRINN

    What if we could teach an AI to do these optimizations for us? That's exactly what CRINN does. Instead of manually tweaking code, we let a reinforcement learning model generate thousands of variations, test them all, and learn what makes vector search fast. The secret sauce is "contrastive reinforcement learning"—we show the AI multiple code implementations along with their actual speed measurements. It's like having a coach that says "Look, implementation A is 2x faster than B because it does X differently." The AI learns these patterns and generates even better code.

    Summary of Main Experimental Results

    CRINN beats state-of-the-art implementations with remarkable consistency. On the GIST-960 dataset containing image descriptors, CRINN achieves a stunning 134% speedup. For MNIST-784 handwritten digits, the improvement reaches 108%. Overall, CRINN delivers best-in-class performance on three out of six standard benchmarks and matches state-of-the-art results on two others. The best part is that CRINN learned all these optimizations automatically, discovering techniques that human experts use plus some novel ones we hadn't seen before.

    How CRINN Works

    The Challenge: Speed vs. Accuracy Trade-offs

Unlike most optimization problems, where you have a single metric to improve, ANNS is tricky. You're balancing two competing goals: how fast you can search (queries per second) versus how accurate your results are (recall). Change one parameter and both metrics shift in opposite directions. This is why ANNS papers always show those curvy trade-off plots: there's no single "best" algorithm, just different points on the speed-accuracy curve.

    CRINN Training Recipe

    We break down ANNS optimization into three sequential stages, each building on the previous. This modular approach allows CRINN to focus on specific optimization opportunities while maintaining the overall algorithm structure.

    Background: Hierarchical Navigable Small World (HNSW)

    CRINN optimizes HNSW (Hierarchical Navigable Small World), a state-of-the-art graph-based ANNS algorithm. The key insight is that HNSW naturally decomposes into three distinct modules: graph construction (building the multi-layer structure), search (navigating through the layers), and refinement (fine-tuning results with techniques like quantization). By treating each module independently, CRINN can apply targeted optimizations without breaking the overall algorithm.

    Stage 1: Graph Construction

    In HNSW, vectors are inserted incrementally, with each new vector assigned to multiple layers based on an exponentially decaying probability. CRINN learns to adjust search effort based on target quality through adaptive EF scaling, implement smart memory prefetching to reduce cache misses, and use multiple entry points for parallel exploration. These optimizations form the foundation of fast vector search, as a well-constructed graph enables efficient navigation regardless of the search algorithm used.
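
As background for this stage, the layer assignment itself follows the standard HNSW rule. Below is a minimal sketch of that rule, assuming the usual normalization factor mL = 1/ln(M); this is illustrative background code, not output generated by CRINN.

C++ - HNSW Layer Assignment (illustrative)
#include <cmath>
#include <random>

// Each inserted vector is assigned to layers 0..level, where
// level = floor(-ln(u) * mL) with u ~ Uniform(0,1) and mL = 1/ln(M),
// giving the exponentially decaying layer occupancy described above.
int assign_level(std::mt19937& rng, int M /* max neighbors per node */) {
    std::uniform_real_distribution<double> unif(0.0, 1.0);
    double u = 1.0 - unif(rng);  // in (0, 1], avoids log(0)
    double mL = 1.0 / std::log(static_cast<double>(M));
    return static_cast<int>(std::floor(-std::log(u) * mL));
}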

    Stage 2: Search

    Once you have the graph, you need to search through it efficiently. The search starts from the top layer and greedily descends through each layer until reaching the bottom, where it switches to a more exhaustive exploration. The search phase is where most of the computation happens, so even small improvements here translate to significant overall speedups.
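
To make that two-phase control flow concrete, here is a compact, self-contained sketch: greedy descent through the upper layers, then an ef-bounded best-first pass at layer 0. The Graph layout and helper names are our own simplifications, not CRINN's actual data structures.

C++ - HNSW Search Skeleton (illustrative)
#include <algorithm>
#include <functional>
#include <queue>
#include <unordered_set>
#include <utility>
#include <vector>

struct Graph {
    std::vector<std::vector<float>> points;             // vector data
    std::vector<std::vector<std::vector<int>>> layers;  // layers[l][node] = neighbor ids
    int entry;                                          // top-layer entry point
    int max_level;
};

float l2(const std::vector<float>& a, const std::vector<float>& b) {
    float d = 0.f;
    for (size_t i = 0; i < a.size(); ++i) d += (a[i] - b[i]) * (a[i] - b[i]);
    return d;
}

// Hill-climb toward the query on one layer until no neighbor is closer.
int greedy_closest(const Graph& g, const std::vector<float>& q, int start, int level) {
    int curr = start;
    float best = l2(g.points[curr], q);
    bool improved = true;
    while (improved) {
        improved = false;
        for (int nb : g.layers[level][curr]) {
            float d = l2(g.points[nb], q);
            if (d < best) { best = d; curr = nb; improved = true; }
        }
    }
    return curr;
}

std::vector<int> hnsw_search(const Graph& g, const std::vector<float>& q, int ef, int k) {
    int curr = g.entry;
    for (int level = g.max_level; level > 0; --level)  // Phase 1: greedy descent
        curr = greedy_closest(g, q, curr, level);

    // Phase 2: best-first exploration at layer 0, keeping the ef closest seen.
    using P = std::pair<float, int>;
    std::priority_queue<P, std::vector<P>, std::greater<P>> frontier;  // nearest on top
    std::priority_queue<P> best;                                       // farthest on top
    std::unordered_set<int> visited{curr};
    float d0 = l2(g.points[curr], q);
    frontier.push({d0, curr});
    best.push({d0, curr});
    while (!frontier.empty()) {
        auto [d, n] = frontier.top(); frontier.pop();
        if ((int)best.size() >= ef && d > best.top().first) break;  // converged
        for (int nb : g.layers[0][n]) {
            if (!visited.insert(nb).second) continue;
            float dn = l2(g.points[nb], q);
            if ((int)best.size() < ef || dn < best.top().first) {
                frontier.push({dn, nb});
                best.push({dn, nb});
                if ((int)best.size() > ef) best.pop();
            }
        }
    }
    std::vector<int> out;
    while (!best.empty()) { out.push_back(best.top().second); best.pop(); }
    std::reverse(out.begin(), out.end());  // nearest first
    if ((int)out.size() > k) out.resize(k);
    return out;
}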

    Stage 3: Refinement

    The last stage fine-tunes the results. While less dramatic than the first two stages, refinement still contributes meaningful improvements through techniques like quantized preliminary search (using int4/int8 representations), adaptive memory prefetching with lookahead, and pre-computed edge metadata for faster access.
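
As an illustration of the first of these techniques, the sketch below shows one common form of quantized preliminary scoring: per-vector symmetric int8 scalar quantization, with full float-precision re-ranking left to a later pass. The scheme and names are assumptions for exposition; the exact format CRINN converged on is not specified here.

C++ - Quantized Preliminary Scoring (illustrative)
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct QuantVec {
    std::vector<int8_t> q;  // quantized codes
    float scale;            // x ≈ scale * q
};

// Per-vector symmetric scalar quantization to int8.
QuantVec quantize_int8(const std::vector<float>& x) {
    float maxabs = 0.f;
    for (float v : x) maxabs = std::max(maxabs, std::fabs(v));
    QuantVec out;
    out.scale = maxabs > 0.f ? maxabs / 127.f : 1.f;
    out.q.reserve(x.size());
    for (float v : x)
        out.q.push_back(static_cast<int8_t>(std::lround(v / out.scale)));
    return out;
}

// Cheap approximate score for the preliminary pass; survivors are
// then re-scored with exact float distances.
float approx_dot(const QuantVec& a, const QuantVec& b) {
    int32_t acc = 0;
    for (size_t i = 0; i < a.q.size(); ++i)
        acc += static_cast<int32_t>(a.q[i]) * static_cast<int32_t>(b.q[i]);
    return static_cast<float>(acc) * a.scale * b.scale;
}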

    The Magic: Contrastive Reinforcement Learning

In this work, we propose using contrastive supervision as the reward signal. This contrastive approach enables the LLM to explicitly reason about why certain ANNS implementations outperform others, learning to identify performance-critical patterns through direct comparison. The prompt has four key components: task description, previous implementations with speed scores, generation protocol, and critical requirements.

    CRINN's Contrastive Learning Prompt Structure:

    1. Task Description: "You're an ANNS optimization expert. Create an accelerated version that maintains identical functionality."

    2. Previous Implementations with Speed:

    • Implementation v1: Score 1.42 (baseline GLASS)
    • Implementation v2: Score 1.85 (+30% faster, used adaptive EF)

3. Generation Protocol: rules for how the model should produce and format its new candidate implementation.

4. Critical Requirements: Search quality must match the reference implementation exactly. Same interface, deterministic results.
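
To make the structure concrete, here is a hypothetical sketch of how such a contrastive prompt could be assembled. The Variant struct, field names, and wording are illustrative; they are not the exact training template.

C++ - Contrastive Prompt Assembly (hypothetical sketch)
#include <sstream>
#include <string>
#include <vector>

struct Variant {
    std::string code;   // a previous implementation
    double score;       // its measured speed reward
    std::string note;   // e.g. "+30% faster, used adaptive EF"
};

std::string build_contrastive_prompt(const std::vector<Variant>& history) {
    std::ostringstream p;
    p << "You're an ANNS optimization expert. Create an accelerated version "
         "that maintains identical functionality.\n\n"
      << "Previous implementations with speed scores:\n";
    for (size_t i = 0; i < history.size(); ++i)
        p << "Implementation v" << i + 1 << " (score " << history[i].score
          << ", " << history[i].note << "):\n" << history[i].code << "\n\n";
    p << "Critical requirements: search quality must match the reference "
         "implementation exactly; same interface; deterministic results.\n";
    return p.str();
}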

    The Reward Challenge: Solving the QPS-Recall Dilemma

Remember that tricky speed vs. accuracy trade-off? CRINN needs a single reward number to guide its learning, but ANNS gives you a curve, not a point. Our solution: we run each implementation with different ef values (which control the exploration budget), collect multiple (QPS, recall) points, then calculate the area under the curve in the "sweet spot" range of 85-95% recall. This range captures where most real-world applications operate: accuracy high enough to be useful, but not so high that performance becomes impractical.

    This reward design is crucial because it prevents CRINN from gaming the system. Without it, the model might just optimize for raw speed by returning random results (infinite QPS, zero recall) or maximize recall by doing exhaustive search (perfect accuracy, terrible speed). The area-under-curve reward forces CRINN to find implementations that perform well across the entire operating range.

    Training with GRPO

CRINN uses Group Relative Policy Optimization (GRPO) for training. For each input prompt q, augmented with selected demonstrations, we generate G code completions {d1, d2, ..., dG} from the current policy πθold (typically 4-8 variants). Each variant receives a reward score based on its execution performance.

    Speed Reward Calculation

    The reward signal is constructed by evaluating each implementation across different ef values (the number of neighbors explored during search). This produces a set of (QPS, recall) points. We filter these points to retain only those within the recall range of [0.85, 0.95] and compute the area under the curve (AUC) formed by these points. This AUC serves as our scalar reward ri for implementation i.
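
A minimal sketch of that computation follows (trapezoidal AUC of QPS over the recall window [0.85, 0.95]). The function name and point representation are ours, not the paper's.

C++ - AUC Speed Reward (illustrative)
#include <algorithm>
#include <utility>
#include <vector>

// points: (recall, QPS) pairs measured at different ef values.
double speed_reward(std::vector<std::pair<double, double>> points) {
    const double lo = 0.85, hi = 0.95;
    // Keep only the operating range the reward cares about.
    points.erase(std::remove_if(points.begin(), points.end(),
                     [&](const std::pair<double, double>& p) {
                         return p.first < lo || p.first > hi;
                     }),
                 points.end());
    std::sort(points.begin(), points.end());  // ascending recall
    double auc = 0.0;
    for (size_t i = 1; i < points.size(); ++i) {
        double dr = points[i].first - points[i - 1].first;
        auc += 0.5 * (points[i].second + points[i - 1].second) * dr;  // trapezoid
    }
    return auc;  // scalar reward r_i for this implementation
}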

    Reward Normalization

    To ensure training stability, rewards undergo normalization within each group:

$$
\hat{A}_i = \frac{r_i - \mathrm{mean}(\mathbf{r})}{\mathrm{std}(\mathbf{r})}
$$

where $\mathbf{r} = (r_1, r_2, \ldots, r_G)$ collects the reward scores of all generated variants and $\hat{A}_i$ is the resulting group-normalized advantage for variant $i$.
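
As a quick worked example (numbers invented for illustration): if a group of G = 4 variants scores $\mathbf{r} = (1.0, 1.2, 1.4, 2.0)$, then $\mathrm{mean}(\mathbf{r}) = 1.4$ and $\mathrm{std}(\mathbf{r}) \approx 0.374$, giving normalized advantages of roughly $(-1.07, -0.53, 0, +1.60)$: the fastest variant's tokens are reinforced, and the slowest variant's are penalized.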

    GRPO Training Objective

The GRPO training objective maximizes the following:

$$
\mathcal{L}_{\mathrm{GRPO}}(\theta) = \mathbb{E}_{q \sim P(q),\, \{d_i\}_{i=1}^{G} \sim \pi_{\theta_{\mathrm{old}}}(d \mid q)} \Bigg[ \frac{1}{G} \sum_{i=1}^{G} \frac{1}{|d_i|} \sum_{t=1}^{|d_i|} \min\!\Bigg( \frac{\pi_{\theta}(d_{i,t} \mid q, d_{i,<t})}{\pi_{\theta_{\mathrm{old}}}(d_{i,t} \mid q, d_{i,<t})} \hat{A}_i,\ \mathrm{clip}\!\Bigg( \frac{\pi_{\theta}(d_{i,t} \mid q, d_{i,<t})}{\pi_{\theta_{\mathrm{old}}}(d_{i,t} \mid q, d_{i,<t})},\ 1-\varepsilon,\ 1+\varepsilon \Bigg) \hat{A}_i \Bigg) - \beta\, D_{\mathrm{KL}}\big(\pi_{\theta} \,\|\, \pi_{\mathrm{ref}}\big) \Bigg]
$$

    where:

    • πθ represents the policy network under optimization
    • πθold denotes the policy from the preceding training step
    • ε controls the clipping range for policy updates (preventing drastic changes)
    • β is a regularization parameter balancing exploration and adherence to the reference policy
• DKL measures the Kullback-Leibler divergence between the current and reference distributions
• Âᵢ is the group-normalized advantage defined in the previous subsection

    This relative comparison is more stable than absolute rewards and helps the model focus on what makes one implementation faster than another, rather than getting distracted by dataset-specific performance characteristics. The contrastive learning ensures that each iteration builds on the lessons learned from previous attempts, creating a virtuous cycle of continuous improvement.

    We start with GLASS as our baseline and optimize each module sequentially. The advantage of this approach is that CRINN isn't limited to GLASS—it can start with any ANNS implementation and evolve it for better performance. Despite training exclusively on SIFT-128 (Euclidean distance), the optimized code generalizes well to angular-distance datasets, demonstrating the robustness of the learned optimization patterns.

    Results

    Performance Across Different Datasets

    We tested CRINN on six standard benchmarks used by the vector search community. These datasets span different dimensions, distance metrics, and real-world applications, providing a comprehensive evaluation of CRINN's optimization capabilities. The diversity of these benchmarks ensures that our results aren't just tuned to specific scenarios but demonstrate genuine algorithmic improvements.

CRINN Performance vs. Best Baselines at Different Recall Levels

| Dataset | Recall | CRINN QPS | Best Baseline | Baseline QPS | Improvement |
|---------|--------|-----------|---------------|--------------|-------------|
| Euclidean Distance | | | | | |
| SIFT-128 | 0.900 | 36,876 | ParlayANN | 29,368 | +25.57% |
| SIFT-128 | 0.950 | 27,499 | ParlayANN | 23,057 | +19.26% |
| SIFT-128 | 0.990 | 13,014 | ParlayANN | 11,808 | +10.21% |
| SIFT-128 | 0.999 | 5,158 | ParlayANN | 4,996 | +3.25% |
| GIST-960 | 0.900 | 4,288 | ParlayANN | 3,788 | +13.20% |
| GIST-960 | 0.950 | 2,925 | ParlayANN | 2,348 | +24.59% |
| GIST-960 | 0.990 | 1,149 | ParlayANN | 666 | +72.68% |
| MNIST-784 | 0.900 | 24,826 | ParlayANN | 19,324 | +28.47% |
| MNIST-784 | 0.950 | 22,008 | ParlayANN | 17,293 | +27.26% |
| MNIST-784 | 0.990 | 17,457 | ParlayANN | 11,728 | +48.85% |
| MNIST-784 | 0.999 | 10,600 | ParlayANN | 5,722 | +85.25% |
| Angular Distance | | | | | |
| GloVe-100 | 0.900 | 5,947 | Vearch | 5,768 | +3.09% |
| GloVe-100 | 0.950 | 3,024 | ParlayANN | 3,212 | -5.84% |
| GloVe-25 | 0.900 | 37,474 | GLASS | 31,611 | +18.55% |
| GloVe-25 | 0.950 | 28,909 | GLASS | 21,899 | +32.01% |
| GloVe-25 | 0.990 | 13,574 | GLASS | 11,804 | +14.99% |
| GloVe-25 | 0.999 | 4,588 | GLASS | 4,549 | +0.87% |
| NYTimes-256 | 0.900 | 1,623 | ParlayANN | 9,459 | -82.85% |

The table above presents the QPS (queries per second) performance of CRINN against the best-performing baselines across six benchmark datasets at recall levels 0.900, 0.950, 0.990, and 0.999. Where entries are absent for a given recall level, none of the tested methods reached that recall threshold.

The results demonstrate that CRINN consistently outperforms state-of-the-art methods across most configurations, with improvements ranging from modest gains of 3.09% to substantial speedups of 85.25%. The SIFT-128 dataset, which was used to train the RL agent, shows consistent improvements across all recall levels, with gains shrinking as recall requirements become more stringent. Among the angular-distance datasets, GloVe-25 exhibits significant improvements of up to 32.01%, while GloVe-100 shows mixed results, including a slight degradation of 5.84% at 0.95 recall. The clear failure case is NYTimes-256, where CRINN underperforms the best baseline by 82.85% at the 0.9 recall level.

    Progressive Improvements: How Each Stage Contributes

    One of the most insightful aspects of CRINN is understanding how each optimization stage contributes to the overall speedup. This breakdown helps us understand where the major performance gains come from and validates our three-stage optimization approach.

Stage-by-Stage Performance Improvements

| Dataset | Graph Construction (Individual) | Graph Construction (Cumulative) | Search (Individual) | Search (Cumulative) | Refinement (Individual) | Total |
|---------|--------------------------------|---------------------------------|---------------------|---------------------|-------------------------|-------|
| SIFT-128 | +30.12% | +30.12% | +25.86% | +55.98% | +11.19% | +67.17% |
| GIST-960 | +58.26% | +58.26% | +46.43% | +104.69% | +29.63% | +134.32% |
| MNIST-784 | +45.85% | +45.85% | +44.49% | +90.34% | +18.30% | +108.64% |
| GloVe-100 | +13.19% | +13.19% | +19.03% | +32.22% | +5.86% | +38.08% |
| GloVe-25 | +6.94% | +6.94% | +6.52% | +13.46% | +2.70% | +16.16% |
| NYTimes-256 | -21.68% | -21.68% | -32.54% | -54.22% | -9.56% | -63.78% |
| Overall Average | +22.11% | +22.11% | +18.30% | +40.41% | +9.69% | +50.10% |

CRINN achieves its performance improvements through three optimization stages with diminishing returns. The graph-construction module provides the largest gains (22.11% average improvement), particularly on high-dimensional datasets like GIST-960 (58.26%) and MNIST-784 (45.85%). The search-optimization module adds 18.30% improvement on average, while the refinement module contributes a more modest 9.69%.

One notable exception is the NYTimes-256 dataset, which shows performance degradation across all stages, suggesting that the learned optimization techniques may need dataset-specific tuning for certain angular-distance workloads. Overall, the results validate the effectiveness of the progressive optimization strategy, with five out of six datasets showing substantial cumulative improvements ranging from 16% to 134%.

    Case Studies: What CRINN Actually Optimizes

    Let's examine specific optimizations CRINN discovered. These examples showcase how AI can find both well-known techniques and novel improvements that human developers might overlook. Each optimization represents hours or days of manual tuning compressed into automated discovery.

    Graph Construction Optimizations

    1. Adaptive Search with Dynamic EF Scaling

    CRINN learned that not all searches need the same effort. When you need high recall, search harder. When approximate results suffice, save computation. This dynamic adjustment represents a fundamental shift from the traditional fixed-parameter approach. CRINN discovered that the relationship between recall requirements and search effort isn't linear—there's a critical threshold where additional effort yields diminishing returns.

    C++ - Original Fixed Budget
    // Old: Fixed search budget
    size_t ef = ef_construction; // Always constant
C++ - CRINN Adaptive Budget
30% faster
// New: Adaptive search budget based on recall needs
size_t dynamic_ef;
if (target_recall > critical_threshold)
    dynamic_ef = ef_search * (1.0 + recall_excess * 14.5);
else
    dynamic_ef = ef_search;

This optimization is striking in its simplicity. CRINN discovered that a scaling factor of 14.5 works best, a constant that would take human engineers extensive experimentation to find.

    2. Zero-Overhead Multi-Level Prefetching

    Modern CPUs are fast, but waiting for memory is slow. CRINN learned to predict which data it'll need next and fetch it early, implementing a sophisticated multi-level prefetching strategy that adapts to the search pattern. This optimization showcases CRINN's ability to understand hardware-level performance characteristics and translate them into algorithmic improvements.

C++ - Fixed Prefetch Window
// Old: Fixed prefetch window
for (int j = 0; j < std::min(5, size); ++j)
    computer.prefetch(neighbors[j], 1);
C++ - CRINN Adaptive Multi-Level Prefetch
25% faster
// New: Adaptive multi-level prefetching
int prefetch_depth = std::min(adaptive_depth, size); // 24-48 based on performance
for (int j = 0; j < prefetch_depth; ++j)
    computer.prefetch(neighbors[j], 3); // into L1 cache
if (high_recall_needed) {
    // Additional L2 prefetch for more neighbors
}

    3. Multi-Entry Point Search Architecture

    Instead of starting from a single point in the graph, CRINN discovered that maintaining multiple diverse entry points for parallel exploration significantly improves recall. This strategy explores diverse graph regions simultaneously, finding better neighbors faster.

C++ - Single Entry Point
// Old: Single entry point
start_node = enterpoint_node;
results = search(start_node, query);
C++ - CRINN Multiple Entry Points
18% better recall
// New: Multiple diverse entry points (up to 9)
for (auto node : strategic_entrypoints) {
    if (distance_to_others(node) > threshold)
        entry_points.add(node);
}
// Search from multiple starting points
for (auto ep : entry_points)
    results.merge(search(ep, query));

    Search Optimizations

    1. Multi-Tier Entry Point Selection

    CRINN replaces single entry point initialization with a sophisticated multi-tier system that selects from primary, secondary, and tertiary entry points based on search budget. This strategy improves search quality by starting from diverse, high-quality nodes.

C++ - Basic Entry Selection
// Old: Single entry point
initialize_search(single_entry_point);
C++ - CRINN Multi-Tier Selection
22% faster convergence
// New: Multi-tier entry selection
add_entry(primary_entry_point);
if (search_budget > threshold_1)
    add_entry(secondary_entry_point);
if (search_budget > threshold_2)
    add_entry(tertiary_entry_point);

    2. Batch Processing with Adaptive Prefetching

    CRINN optimizes neighbor processing by collecting edges into batches and using enhanced prefetch strategies. This reduces random memory access and improves cache utilization, particularly important for high-dimensional datasets.

C++ - Basic
// Old: Fixed prefetching
for (int i = 0; i < prefetch_count; ++i)
    prefetch(neighbor[i]);
C++ - CRINN Adaptive Batch Prefetching
30% better cache utilization
// New: Adaptive batch prefetching with lookahead
int prefetch_size = prefetch_count * batch_factor;
for (int i = 0; i < prefetch_size; ++i) {
    prefetch(neighbor[i]);
    if (processing_node[i])
        prefetch(neighbor[i + prefetch_size]); // Look ahead
}

    3. Intelligent Early Termination with Convergence Detection

    Why keep searching when you've already found the best results? CRINN learned to detect when search has converged and stop early. This optimization requires understanding the statistical properties of the search process—recognizing when additional exploration is unlikely to yield better results.

C++ - Basic
// Old: Explore until pool exhausted
while (has_candidates())
    process_neighbor();
C++ - CRINN Smart Termination
15% fewer iterations
// New: Smart termination via convergence detection
int no_improvement_count = 0;
while (has_candidates()) {
    int improvements = process_neighbor();
    if (improvements == 0) {
        no_improvement_count++;
        if (check_convergence(no_improvement_count))
            break; // Early termination
    } else {
        no_improvement_count = 0; // Reset once progress resumes
    }
}

    This optimization showcases CRINN's ability to balance aggressiveness with safety. The convergence detection provides the sweet spot between early termination benefits and maintaining search quality.

    Refinement Optimizations

    1. Adaptive Memory Prefetching with Lookahead

    CRINN replaces basic hierarchical search with an intelligent prefetching system that adapts based on edge patterns and node characteristics. This strategy significantly reduces memory latency during the refinement process.

C++ - Basic Traversal
// Old: Basic traversal without prefetching
for (auto v : node_edges) {
    if (distance(v) < best_distance)
        update_best_node(v);
}
C++ - CRINN Adaptive Prefetching
20% lower latency
// New: Adaptive prefetching with lookahead
if (should_prefetch)
    prefetch(edges[0]);
for (size_t i = 0; i < node_edges.size(); ++i) {
    if (i + lookahead < node_edges.size())
        prefetch(edges[i + lookahead]); // Prefetch future edges
    if (distance(node_edges[i]) < best_distance)
        update_best_node(node_edges[i]);
}

    2. Pre-computed Edge Metadata with Pattern Recognition

    CRINN enhances the refiner by pre-computing and storing edge counts for each node level. This eliminates redundant computations and enables pattern-based optimizations during refinement. The system learns to recognize common access patterns and optimizes accordingly.

C++ - Runtime Edge Counting
// Old: Runtime edge counting
int count = 0;
for (auto edge : node_edges)
    if (edge != -1)
        count++;
C++ - CRINN Pre-computed Metadata
35% fewer computations
// New: Pre-computed metadata access
auto metadata = get_precomputed_metadata(level, node);
int edge_count = metadata.count;
float pattern_score = metadata.intelligence_score;
// Use metadata for optimization decisions
if (pattern_score > threshold)
    apply_pattern_optimization();

    These refinement optimizations demonstrate CRINN's attention to detail—even when the major optimizations are complete, there's still room for incremental improvements that add up to measurable performance gains. The combination of adaptive prefetching and pre-computed metadata shows how CRINN learns to optimize both memory access patterns and computational efficiency simultaneously.

    What This Means for the Future

    CRINN's success has implications far beyond just making vector search faster. It demonstrates that AI can now tackle optimization problems that traditionally required deep human expertise. Consider the possibilities: database query optimization could benefit from AI that learns optimal index strategies from execution patterns. Network protocol design might see AI discovering better congestion control algorithms by experimenting with different approaches. Compiler optimizations could adapt to specific codebases, learning which transformations yield the best performance. Even operating system schedulers could use reinforcement learning to adapt to workload patterns in real-time, providing better resource allocation than static policies.

    BibTeX

    @article{deepreinforce2025crinn,
      title={CRINN: Contrastive Reinforcement Learning for Approximate Nearest Neighbor Search},
      author={DeepReinforce Team},
      journal={arXiv preprint},
      year={2025}
    }