The Perceiver Resampler: A Deep Dive into Visual Information Bottlenecks

Understanding how modern multimodal models compress visual information for language models

Deep Learning
Computer Vision
Multimodal Models
Attention Mechanisms
Author

Rishabh Mondal

Published

February 11, 2025

Welcome to the World of Multimodal Architecture!

Have you ever wondered how models like GPT-4V or Google Gemini can “see” images and have conversations about them? The secret lies in a beautiful architectural component called the Perceiver Resampler.

What you’ll learn in this blog:

  • Why we need special architectures to bridge vision and language
  • The elegant solution: Learned latent queries
  • How cross-attention creates information bottlenecks
  • Temporal encodings for video understanding
  • Building your intuition from simple concepts to advanced implementation

Prerequisites: Basic understanding of neural networks and attention mechanisms. Don’t worry if you’re rusty - we’ll build up from the basics!


Part 1: The Problem - Bridging Two Worlds

The Vision-Language Marriage

Modern AI has achieved something remarkable: we can now have natural conversations with AI about both text and images. But this marriage of vision and language faces a fundamental challenge:

Vision models produce a LOT of data. A single image encoded by a modern vision encoder might produce thousands of tokens. A video? That’s tens of thousands of tokens, easily.

Language models have a limited token budget. GPT-style models process sequences of tokens through attention layers, and the computational cost grows quadratically with sequence length, so every visual token we feed in competes with text for a finite context window.

The Quadratic Problem

In transformers, the attention mechanism computes relationships between every pair of tokens. If you have \(n\) tokens, you need to compute \(n^2\) relationships.

  • 1 image: ~2,000 tokens → ~4,000,000 computations
  • 8 images (video): ~16,000 tokens → ~256,000,000 computations

This quickly becomes computationally infeasible!
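
To make the quadratic scaling concrete, here is a rough back-of-the-envelope calculation (the token counts are the approximate figures used above, not exact numbers for any particular encoder):

# Rough illustration of quadratic attention cost (token counts are approximate)
tokens_per_image = 2_000

for num_images in [1, 8, 30]:
    n = num_images * tokens_per_image
    pairwise = n ** 2  # every token attends to every other token
    print(f"{num_images:>2} image(s): {n:>6,} tokens -> {pairwise:>13,} pairwise scores")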

The Challenge: Variable Input, Fixed Output

We need a component that can:

  1. Accept variable-length visual input - 1 image, 8 images, 30 images… it should work with any number
  2. Produce fixed-size output - The language model needs a consistent representation
  3. Preserve important information - We can’t just randomly sample pixels!
  4. Be learnable - The model should figure out WHAT information matters

This is where the Perceiver Resampler enters the story.


Part 2: The Intuition - A Smart Summarizer

The Library Analogy

Imagine you’re researching a topic and have access to a library with millions of books. You don’t have time to read everything, but you need enough information to write a comprehensive summary.

The Perceiver Resampler is like hiring a smart research assistant who:

  1. Has a set of pre-written research questions (learned latent queries)
  2. Searches through ALL the books (visual features)
  3. Extracts only the most relevant answers
  4. Returns a fixed-size summary regardless of library size

Learned Latent Queries: The Secret Sauce

The key insight is that we don’t process visual features directly. Instead, we create a set of “learned queries” - vectors that start random but are trained to ask the right questions.

Think of it like this:

import numpy as np

# Imagine we have 64 "questions" our model can ask
# Each question is a vector of, say, 1024 dimensions
num_latents = 64
latent_dim = 1024

# At the start of training, these are random
learned_latents = np.random.randn(num_latents, latent_dim)

print(f"Shape of learned latents: {learned_latents.shape}")
# Output: (64, 1024) - 64 questions, each with 1024 dimensions
Shape of learned latents: (64, 1024)

These “questions” aren’t text - they’re abstract vectors in a high-dimensional space. Through training, they learn to extract useful information from images.

Key Insight: Not Computed, But Learned!

The learned latents are NOT computed from the input images. They are:

  1. Initialized randomly at the start of training
  2. Learned through backpropagation as the model trains
  3. Used as queries to extract information from any input

This is why the Stack Exchange answer emphasized that latent vectors are “learned via gradient descent” - not pre-computed features!
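
A minimal sketch of that last point, with a made-up stand-in loss just to show the mechanics: the latents are ordinary nn.Parameters, so any downstream loss sends gradients back into them.

import torch
import torch.nn as nn

# Learned latents are plain parameters - not features computed from an image
latents = nn.Parameter(torch.randn(64, 1024))

# Pretend these are visual features and compute a dummy "task" loss
visual_features = torch.randn(2312, 1024)
loss = (latents @ visual_features.T).softmax(dim=-1).mean()  # stand-in for a real training loss

loss.backward()
print(latents.grad.shape)  # torch.Size([64, 1024]) - gradients flow into the latents themselves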


Part 3: Cross-Attention - The Information Filter

Understanding Cross-Attention

Before diving deeper, let’s understand cross-attention, the mechanism that powers the Perceiver Resampler.

Standard Self-Attention: Each token attends to every other token in the SAME sequence.

Cross-Attention: Tokens from one sequence attend to tokens from a DIFFERENT sequence.

In the Perceiver Resampler:

  • Queries: our learned latents (fixed size, e.g., 64 tokens)
  • Keys & Values: visual features from the image encoder (variable size)

The Database Query Analogy

Think of cross-attention as querying a database:

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Create a simple visualization
fig, ax = plt.subplots(figsize=(10, 4))

# Draw the database (visual features)
database = mpatches.FancyBboxPatch((0.2, 0.3), 0.4, 0.4, boxstyle="round,pad=0.05",
                                   edgecolor="blue", facecolor="lightblue", alpha=0.7)
ax.add_patch(database)
ax.text(0.4, 0.5, "Visual Features\n(Keys & Values)\nVariable Size",
        ha='center', va='center', fontsize=10, weight='bold')

# Draw the query
query = mpatches.FancyBboxPatch((0.6, 0.3), 0.25, 0.4, boxstyle="round,pad=0.05",
                                edgecolor="red", facecolor="lightcoral", alpha=0.7)
ax.add_patch(query)
ax.text(0.725, 0.5, "Learned\nQueries\nFixed Size",
        ha='center', va='center', fontsize=10, weight='bold')

# Draw the result
result = mpatches.FancyBboxPatch((0.6, -0.1), 0.25, 0.25, boxstyle="round,pad=0.05",
                                 edgecolor="green", facecolor="lightgreen", alpha=0.7)
ax.add_patch(result)
ax.text(0.725, 0.025, "Extracted\nInformation\nFixed Size",
        ha='center', va='center', fontsize=10, weight='bold')

# Draw arrows
ax.annotate('', xy=(0.45, 0.5), xytext=(0.6, 0.5),
            arrowprops=dict(arrowstyle='->', lw=2, color='purple'))
ax.text(0.525, 0.55, "Query", ha='center', fontsize=9, color='purple')

ax.annotate('', xy=(0.725, 0.05), xytext=(0.725, 0.3),
            arrowprops=dict(arrowstyle='->', lw=2, color='darkgreen'))

ax.set_xlim(0, 1)
ax.set_ylim(-0.2, 0.8)
ax.axis('off')
ax.set_title('Cross-Attention as a Database Query', fontsize=14, weight='bold')

plt.tight_layout()
plt.show()

The process:

  1. Query asks: “What information do I need?”
  2. Key responds: “Here’s what I have available”
  3. Value delivers: “Here’s the actual information”
  4. Attention weights determine: “How much each value contributes”
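
Here is that four-step process as a single-head cross-attention sketch with random placeholder tensors (a real layer would also project Q, K, and V through learned linear maps, as in Part 4):

import torch

R, n_visual, d = 64, 2312, 1024            # latent queries, visual tokens, feature dim

queries = torch.randn(R, d)                # learned latents play the role of queries
keys = torch.randn(n_visual, d)            # visual features provide the keys...
values = torch.randn(n_visual, d)          # ...and the values

scores = queries @ keys.T / d ** 0.5       # [R, n_visual] - "what I need" vs "what's available"
weights = scores.softmax(dim=-1)           # each latent's attention sums to 1 over visual tokens
extracted = weights @ values               # [R, d] - fixed-size result, whatever n_visual is

print(extracted.shape)                     # torch.Size([64, 1024])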

Part 4: Architecture Deep Dive

The Complete Perceiver Resampler

Let’s break down the full architecture, component by component:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch

def draw_perceiver_architecture():
    """Draw a detailed architecture diagram of the Perceiver Resampler"""
    fig, ax = plt.subplots(figsize=(14, 8))
    ax.set_xlim(0, 14)
    ax.set_ylim(0, 10)
    ax.axis('off')

    # Colors
    vision_color = '#3498db'      # Blue
    time_color = '#e74c3c'        # Red
    latent_color = '#9b59b6'      # Purple
    attn_color = '#f39c12'        # Orange
    output_color = '#27ae60'      # Green

    # Title
    ax.text(7, 9.5, 'Perceiver Resampler Architecture',
            ha='center', fontsize=16, weight='bold')

    # Step 1: Vision Encoder
    vision_box = FancyBboxPatch((0.5, 7), 2.5, 1.2, boxstyle="round,pad=0.05",
                                edgecolor=vision_color, facecolor=vision_color, alpha=0.3)
    ax.add_patch(vision_box)
    ax.text(1.75, 7.6, 'Vision Encoder\n(NFNet/CLIP)', ha='center', va='center',
            fontsize=9, weight='bold', color='darkblue')

    # Step 2: Temporal Encoding
    time_box = FancyBboxPatch((0.5, 5.5), 2.5, 0.8, boxstyle="round,pad=0.05",
                              edgecolor=time_color, facecolor=time_color, alpha=0.3)
    ax.add_patch(time_box)
    ax.text(1.75, 5.9, '+ Temporal Encoding', ha='center', va='center',
            fontsize=9, weight='bold', color='darkred')

    # Arrow from vision to time
    ax.annotate('', xy=(1.75, 6.3), xytext=(1.75, 7.0),
                arrowprops=dict(arrowstyle='->', lw=2, color='gray'))

    # Visual Features (flattened)
    features_box = FancyBboxPatch((0.5, 4), 2.5, 1, boxstyle="round,pad=0.05",
                                  edgecolor=vision_color, facecolor='lightblue', alpha=0.5)
    ax.add_patch(features_box)
    ax.text(1.75, 4.5, 'Visual Features\nFlattened [T*S, d]', ha='center', va='center',
            fontsize=8, style='italic')

    # Arrow from time to features
    ax.annotate('', xy=(1.75, 4.0), xytext=(1.75, 5.5),
                arrowprops=dict(arrowstyle='->', lw=2, color='gray'))

    # Learned Latents
    latent_box = FancyBboxPatch((5.5, 7), 2.5, 1, boxstyle="round,pad=0.05",
                                edgecolor=latent_color, facecolor=latent_color, alpha=0.3)
    ax.add_patch(latent_box)
    ax.text(6.75, 7.5, 'Learned Latents\n[R, d]', ha='center', va='center',
            fontsize=9, weight='bold', color='darkmagenta')

    # Cross-Attention Block
    attn_box = FancyBboxPatch((4, 2), 5.5, 1.5, boxstyle="round,pad=0.1",
                              edgecolor=attn_color, facecolor=attn_color, alpha=0.3)
    ax.add_patch(attn_box)
    ax.text(6.75, 2.75, 'Cross-Attention Layer', ha='center', va='center',
            fontsize=10, weight='bold', color='darkorange')
    ax.text(6.75, 2.35, 'Q: Latents | K,V: Visual Features', ha='center', va='center',
            fontsize=7)

    # Arrows to attention
    # From features
    ax.annotate('', xy=(4.5, 2.7), xytext=(3.0, 4.5),
                arrowprops=dict(arrowstyle='->', lw=2, color=vision_color, alpha=0.7))
    ax.text(3.3, 3.8, 'Keys, Values', fontsize=7, color=vision_color)

    # From latents
    ax.annotate('', xy=(5.5, 2.7), xytext=(6.75, 7.0),
                arrowprops=dict(arrowstyle='->', lw=2, color=latent_color, alpha=0.7))
    ax.text(6.5, 5.2, 'Query', fontsize=7, color='darkmagenta')

    # Skip connection
    skip_box = FancyBboxPatch((4, 0.8), 5.5, 0.5, boxstyle="round,pad=0.05",
                              edgecolor='gray', facecolor='lightgray', alpha=0.5)
    ax.add_patch(skip_box)
    ax.text(6.75, 1.05, 'Skip Connection (+)', ha='center', va='center',
            fontsize=8, style='italic')

    # Arrow from attention to skip
    ax.annotate('', xy=(6.75, 1.3), xytext=(6.75, 2.0),
                arrowprops=dict(arrowstyle='->', lw=2, color='gray'))

    # Feed Forward
    ff_box = FancyBboxPatch((4, -0.5), 5.5, 0.5, boxstyle="round,pad=0.05",
                           edgecolor='gray', facecolor='lightgray', alpha=0.5)
    ax.add_patch(ff_box)
    ax.text(6.75, -0.25, 'Feed Forward Network', ha='center', va='center',
            fontsize=8, style='italic')

    # Arrow from skip to FF
    ax.annotate('', xy=(6.75, 0.3), xytext=(6.75, 0.8),
                arrowprops=dict(arrowstyle='->', lw=2, color='gray'))

    # Output
    output_box = FancyBboxPatch((10.5, 4), 2.5, 1, boxstyle="round,pad=0.05",
                                edgecolor=output_color, facecolor=output_color, alpha=0.3)
    ax.add_patch(output_box)
    ax.text(11.75, 4.5, 'Output Tokens\n[R, d]', ha='center', va='center',
            fontsize=9, weight='bold', color='darkgreen')

    # Arrow from FF to output
    ax.annotate('', xy=(10.5, 4.5), xytext=(9.5, -0.25),
                arrowprops=dict(arrowstyle='->', lw=2, color='green'))

    # Iteration indicator
    ax.text(6.75, -1, '↑ Repeat for N layers ↑', ha='center', fontsize=8, style='italic')

    # Add labels for dimensions
    ax.text(1.75, 3.5, 'T × S tokens', ha='center', fontsize=7, color='blue')
    ax.text(6.75, 6.5, 'R tokens\n(learned)', ha='center', fontsize=7, color='purple')
    ax.text(11.75, 3.5, 'R tokens\n(fixed output)', ha='center', fontsize=7, color='green')

    plt.tight_layout()
    plt.show()

draw_perceiver_architecture()

Step-by-Step Breakdown

Step 1: Vision Encoding

import torch
import torch.nn as nn

# Vision encoder output shape
# T = number of temporal frames (images)
# S = spatial tokens per image
# d = feature dimension
T, S, d = 8, 289, 1024  # Example: 8 frames, 289 patches per frame, 1024-dim features

visual_features = torch.randn(T, S, d)
print(f"Visual features shape: {visual_features.shape}")
# Output: torch.Size([8, 289, 1024]) - 8 images, 289 tokens each
Visual features shape: torch.Size([8, 289, 1024])

Step 2: Flattening Space and Time

The Perceiver Resampler flattens the spatial and temporal dimensions into a single sequence:

# Flatten: [T, S, d] -> [T*S, d]
visual_features_flat = visual_features.reshape(T * S, d)
print(f"After flattening: {visual_features_flat.shape}")
# Output: torch.Size([2312, 1024]) - 2312 total tokens
After flattening: torch.Size([2312, 1024])

Why does this work?

  • Spatial information is encoded in the feature vectors themselves (CNNs preserve locality)
  • Temporal information is explicitly added (see next section)
  • The flattening is consistent, so the model learns to handle this specific ordering
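
Continuing with the tensors defined just above, a quick sanity check of that consistent ordering (frame index 3 is an arbitrary choice): rows t*S through (t+1)*S of the flattened tensor are exactly frame t.

# Rows [t*S, (t+1)*S) of the flattened tensor correspond to frame t
t = 3
frame_block = visual_features_flat[t * S : (t + 1) * S]
print(torch.equal(frame_block, visual_features[t]))  # True - the ordering is deterministic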

Step 3: Temporal Position Encodings

class TemporalEncoding(nn.Module):
    """Learned temporal position encodings for video frames"""

    def __init__(self, max_frames: int = 8, feature_dim: int = 1024):
        super().__init__()
        # Learnable embedding for each time position
        self.temporal_embeds = nn.Parameter(torch.randn(max_frames, feature_dim))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: [T, S, d] tensor of visual features
        Returns:
            features with temporal encoding added: [T, S, d]
        """
        T = features.shape[0]

        # Add temporal encoding to each frame
        # For each frame t, add self.temporal_embeds[t] to all S spatial tokens
        encoded = []
        for t in range(T):
            encoded.append(features[t] + self.temporal_embeds[t])

        return torch.stack(encoded)

# Example usage
temporal_encoder = TemporalEncoding(max_frames=8, feature_dim=1024)
encoded_features = temporal_encoder(visual_features)
print(f"Encoded features shape: {encoded_features.shape}")
Encoded features shape: torch.Size([8, 289, 1024])
Interpolation at Inference Time

Here’s something fascinating: Flamingo was trained with 8 frames, but at inference, it processes 30 frames!

How? By interpolating the temporal embeddings:

# Trained with 8 frames, need 30
trained_embeds = temporal_encoder.temporal_embeds  # Shape: [8, 1024]

# Linear interpolate to get 30 embeddings
from torch.nn.functional import interpolate
inference_embeds = interpolate(
    trained_embeds.T.unsqueeze(0),  # [1, 1024, 8]
    size=30,
    mode='linear',
    align_corners=False
).squeeze(0).T  # [30, 1024]

This remarkable property allows the model to generalize to different frame counts!

Step 4: Cross-Attention Layer

class CrossAttentionLayer(nn.Module):
    """Cross-attention where latents query visual features"""

    def __init__(self, latent_dim: int = 1024, num_heads: int = 8):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads

        # Query projection: from latents
        self.q_proj = nn.Linear(latent_dim, latent_dim)

        # Key and Value projections: from visual features
        self.k_proj = nn.Linear(latent_dim, latent_dim)
        self.v_proj = nn.Linear(latent_dim, latent_dim)

        # Output projection
        self.out_proj = nn.Linear(latent_dim, latent_dim)

        # Skip connection
        self.skip = nn.Identity()

    def forward(self, latents: torch.Tensor, visual_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            latents: [R, d] - learned latent queries
            visual_features: [T*S, d] - flattened visual features
        Returns:
            Updated latents: [R, d]
        """
        R = latents.shape[0]
        T_S = visual_features.shape[0]

        # Project to Q, K, V
        Q = self.q_proj(latents)      # [R, d]
        K = self.k_proj(visual_features)  # [T*S, d]
        V = self.v_proj(visual_features)  # [T*S, d]

        # Reshape for multi-head attention
        Q = Q.view(R, self.num_heads, self.head_dim).transpose(0, 1)  # [heads, R, head_dim]
        K = K.view(T_S, self.num_heads, self.head_dim).transpose(0, 1)  # [heads, T*S, head_dim]
        V = V.view(T_S, self.num_heads, self.head_dim).transpose(0, 1)  # [heads, T*S, head_dim]

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [heads, R, T*S]

        # Apply softmax to get attention weights
        attn_weights = torch.softmax(scores, dim=-1)  # [heads, R, T*S]

        # Apply attention to values
        attended = torch.matmul(attn_weights, V)  # [heads, R, head_dim]

        # Merge heads
        attended = attended.transpose(0, 1).contiguous().view(R, self.latent_dim)  # [R, d]

        # Output projection
        output = self.out_proj(attended)

        # Skip connection
        return output + self.skip(latents)

Step 5: The Complete Perceiver Resampler

class PerceiverResampler(nn.Module):
    """
    The Perceiver Resampler: compresses variable-length visual input
    into a fixed-size representation using learned latent queries.
    """

    def __init__(
        self,
        feature_dim: int = 1024,
        num_latents: int = 64,
        num_layers: int = 4,
        num_heads: int = 8,
        max_frames: int = 8
    ):
        super().__init__()

        # Learnable latent queries
        self.latents = nn.Parameter(torch.randn(num_latents, feature_dim))

        # Temporal encoding
        self.temporal_encoding = TemporalEncoding(max_frames, feature_dim)

        # Cross-attention layers
        self.layers = nn.ModuleList([
            CrossAttentionLayer(feature_dim, num_heads)
            for _ in range(num_layers)
        ])

        # Feed-forward networks
        self.ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(feature_dim, feature_dim * 4),
                nn.GELU(),
                nn.Linear(feature_dim * 4, feature_dim)
            )
            for _ in range(num_layers)
        ])

        # Layer normalization
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(feature_dim)
            for _ in range(num_layers * 2)
        ])

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            visual_features: [T, S, d] - visual features from encoder
        Returns:
            resampled: [R, d] - fixed-size representation
        """
        # Add temporal encoding
        encoded = self.temporal_encoding(visual_features)  # [T, S, d]

        # Flatten spatial and temporal dimensions
        T, S, d = encoded.shape
        visual_flat = encoded.reshape(T * S, d)  # [T*S, d]

        # Initialize latents (learned parameter, not computed from input!)
        x = self.latents  # [R, d]

        # Process through layers
        for i, (attn_layer, ffn) in enumerate(zip(self.layers, self.ffns)):
            # Cross-attention with skip connection
            x = x + self.layer_norms[2*i](attn_layer(x, visual_flat))

            # Feed-forward with skip connection
            x = x + self.layer_norms[2*i+1](ffn(x))

        return x  # [R, d] - same size as latents!

# Example usage
resampler = PerceiverResampler(
    feature_dim=1024,
    num_latents=64,
    num_layers=4,
    num_heads=8
)

# Process 8 frames
visual_input = torch.randn(8, 289, 1024)  # T=8, S=289, d=1024
output = resampler(visual_input)

print(f"Input: {visual_input.shape[0]} frames × {visual_input.shape[1]} tokens = {visual_input.shape[0] * visual_input.shape[1]} tokens")
print(f"Output: {output.shape[0]} tokens (fixed size!)")
Input: 8 frames × 289 tokens = 2312 tokens
Output: 64 tokens (fixed size!)

Part 5: The “Soft Clustering” Intuition

Understanding What Latents Learn

You might wonder: What do these learned latents actually represent?

The Perceiver paper describes this beautifully: the latents act like cluster centers in a soft clustering algorithm. Each latent specializes in extracting a certain type of information.

import matplotlib.pyplot as plt
import numpy as np

def visualize_latent_specialization():
    """Conceptual visualization of how latents might specialize"""
    fig, axes = plt.subplots(2, 4, figsize=(14, 7))
    fig.suptitle('How Learned Latents Might Specialize', fontsize=14, weight='bold')

    # Conceptual specializations
    specializations = [
        ("Objects", "focuses on discrete objects"),
        ("Actions", "captures motion and activity"),
        ("Spatial", "understands spatial relationships"),
        ("Colors", "attends to color patterns"),
        ("Textures", "focuses on surface properties"),
        ("Context", "understands scene context"),
        ("Temporal", "tracks changes over time"),
        ("Global", "integrates holistic understanding")
    ]

    for idx, (ax, (name, desc)) in enumerate(zip(axes.flat, specializations)):
        # Create a conceptual attention pattern
        attn_pattern = np.random.rand(16, 16)
        attn_pattern = attn_pattern / attn_pattern.sum()

        im = ax.imshow(attn_pattern, cmap='viridis')
        ax.set_title(f'Latent {idx}: {name}\n{desc}', fontsize=9)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

visualize_latent_specialization()

Key insight from the Stack Exchange discussion:

“The latent vectors in this case act as queries which extract information from the input data and need to be aligned in such a way that they extract the necessary information to perform the prediction task.”

Through training, gradients flow back and adjust these latents so they ask “better questions” - questions that extract information useful for the task at hand.

Not Traditional Clustering!

The Perceiver authors call this “end-to-end clustering” but it’s important to understand:

  • NOT K-means or any traditional clustering algorithm
  • NOT pre-computed from features
  • IS a soft, differentiable clustering that emerges from gradient descent
  • IS task-dependent - latents learn to cluster information relevant to the task

The “clusters” are implicit in the attention patterns that develop during training!
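
To make the soft-clustering picture concrete, here is a toy example (random, untrained vectors, so the assignments are meaningless; it only illustrates what "soft assignment" means mechanically):

import torch

num_latents, num_tokens, d = 4, 10, 32
latents = torch.randn(num_latents, d)        # 4 "cluster centres"
visual_tokens = torch.randn(num_tokens, d)   # 10 visual tokens to be "clustered"

# Attention weights = soft assignment of every visual token to every latent
weights = (latents @ visual_tokens.T / d ** 0.5).softmax(dim=-1)  # [4, 10]

print(weights.sum(dim=-1))    # each latent's weights sum to 1 (soft, not hard, membership)
print(weights.argmax(dim=0))  # which latent "claims" each visual token most strongly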


Part 6: The Flamingo Integration

How Flamingo Uses the Perceiver Resampler

The Perceiver Resampler was popularized by DeepMind’s Flamingo model. Let’s see how it fits into the complete architecture:

def draw_flamingo_architecture():
    """Draw the Flamingo architecture showing Perceiver Resampler integration"""
    fig, ax = plt.subplots(figsize=(16, 6))
    ax.set_xlim(0, 16)
    ax.set_ylim(0, 8)
    ax.axis('off')

    # Title
    ax.text(8, 7.5, 'Flamingo: Vision-Language Model with Perceiver Resampler',
            ha='center', fontsize=14, weight='bold')

    # Input section
    ax.add_patch(plt.Rectangle((0.5, 5.5), 3, 1, fill=True, color='lightblue', alpha=0.7))
    ax.text(2, 6, 'Images/Video\n(Variable Size)', ha='center', va='center', weight='bold')

    # Vision Encoder
    ax.add_patch(plt.Rectangle((4.5, 5.5), 2, 1, fill=True, color='blue', alpha=0.5))
    ax.text(5.5, 6, 'Vision\nEncoder', ha='center', va='center', weight='bold', color='white')

    # Perceiver Resampler
    ax.add_patch(plt.Rectangle((7, 5.5), 2.5, 1, fill=True, color='purple', alpha=0.5))
    ax.text(8.25, 6, 'Perceiver\nResampler', ha='center', va='center', weight='bold', color='white')

    # Fixed visual tokens
    ax.add_patch(plt.Rectangle((10, 5.5), 1.5, 1, fill=True, color='green', alpha=0.5))
    ax.text(10.75, 6, 'Fixed\nTokens', ha='center', va='center', weight='bold', color='white')

    # Text input
    ax.add_patch(plt.Rectangle((0.5, 3.5), 4, 1, fill=True, color='lightcoral', alpha=0.7))
    ax.text(2.5, 4, 'Text Input (with <image> tokens)', ha='center', va='center', weight='bold')

    # Gated Cross-Attention layers
    for i in range(4):
        y_pos = 2.5 - i * 0.6
        ax.add_patch(plt.Rectangle((5.5, y_pos), 5, 0.4, fill=True, color='orange', alpha=0.6))
        ax.text(8, y_pos + 0.2, f'Gated Cross-Attention Layer {i+1}',
                ha='center', va='center', fontsize=8, weight='bold', color='darkred')

    # LLM blocks
    for i in range(4):
        y_pos = 2.5 - i * 0.6
        ax.add_patch(plt.Rectangle((11, y_pos), 1, 0.4, fill=True, color='lightgray', alpha=0.7))
        ax.text(11.5, y_pos + 0.2, f'LLM', ha='center', va='center', fontsize=7)

    # Output
    ax.add_patch(plt.Rectangle((6, -0.2), 4, 0.6, fill=True, color='lightgreen', alpha=0.7))
    ax.text(8, 0.1, 'Generated Text Output', ha='center', va='center', weight='bold')

    # Arrows
    # Vision path
    ax.annotate('', xy=(4.5, 6), xytext=(3.5, 6), arrowprops=dict(arrowstyle='->', lw=2, color='blue'))
    ax.annotate('', xy=(7, 6), xytext=(6.5, 6), arrowprops=dict(arrowstyle='->', lw=2, color='purple'))
    ax.annotate('', xy=(10, 6), xytext=(9.5, 6), arrowprops=dict(arrowstyle='->', lw=2, color='green'))

    # Vision to cross-attention
    for i in range(4):
        y_pos = 2.7 - i * 0.6
        ax.annotate('', xy=(10.75, y_pos), xytext=(10.75, 5.5),
                    arrowprops=dict(arrowstyle='->', lw=1, color='green', alpha=0.5))

    # Text to LLM
    ax.annotate('', xy=(5.5, 2.5), xytext=(4.5, 4), arrowprops=dict(arrowstyle='->', lw=2, color='coral'))

    # LLM output
    ax.annotate('', xy=(8, 0.4), xytext=(11.5, 2.5), arrowprops=dict(arrowstyle='->', lw=2, color='gray'))

    plt.tight_layout()
    plt.show()

draw_flamingo_architecture()

Gated Cross-Attention: Mixing Vision and Language

After the Perceiver Resampler produces fixed-size visual tokens, Flamingo uses gated cross-attention to inject this information into the language model:

class GatedCrossAttention(nn.Module):
    """
    Cross-attention with gating for controlled visual information injection.
    The gate starts closed (alpha=0) and gradually opens during training.
    """

    def __init__(self, visual_dim: int, lang_dim: int, num_heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=lang_dim,
            kdim=visual_dim,   # visual tokens may have a different width than the LM
            vdim=visual_dim,
            num_heads=num_heads,
            batch_first=True
        )

        # Learnable gate parameter
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, lang_tokens: torch.Tensor, visual_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            lang_tokens: [batch, seq_len, lang_dim] - language model tokens
            visual_tokens: [batch, num_visual, visual_dim] - resampled visual tokens
        Returns:
            Updated lang_tokens with visual information injected
        """
        # Cross-attention: language queries, visual keys/values
        attn_out, _ = self.cross_attn(
            query=lang_tokens,
            key=visual_tokens,
            value=visual_tokens
        )

        # Tanh gate: controls how much visual information flows through
        gate = torch.tanh(self.alpha)
        gated_output = gate * attn_out

        # Skip connection: original language tokens
        return lang_tokens + gated_output

# Training progression
def simulate_training_progression():
    """Show how the gate opens during training"""
    fig, ax = plt.subplots(figsize=(10, 4))

    # Simulate gate values during training
    steps = np.arange(0, 10000)
    # Gate gradually opens from 0 towards 1
    gate_values = np.tanh(steps / 2000)  # Smooth transition

    ax.plot(steps, gate_values, linewidth=2, color='purple')
    ax.set_xlabel('Training Steps')
    ax.set_ylabel('Gate Value (tanh(α))')
    ax.set_title('Gated Cross-Attention: Progressive Information Injection')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.1, 1.1)

    # Add annotations
    ax.annotate('Start: Gate Closed\nα≈0, tanh(α)≈0\nNo visual info',
                xy=(0, 0), xytext=(1000, 0.2),
                arrowprops=dict(arrowstyle='->', color='red'),
                fontsize=9, ha='center')

    ax.annotate('Mid: Gate Opening\nGradual\nvisual injection',
                xy=(3000, 0.8), xytext=(3500, 0.5),
                arrowprops=dict(arrowstyle='->', color='orange'),
                fontsize=9, ha='center')

    ax.annotate('End: Gate Open\nα large, tanh(α)≈1\nFull visual integration',
                xy=(8000, 1), xytext=(6500, 0.8),
                arrowprops=dict(arrowstyle='->', color='green'),
                fontsize=9, ha='center')

    plt.tight_layout()
    plt.show()

simulate_training_progression()

Why the Gradual Gate?

When training Flamingo, the cross-attention gate starts closed (α=0 → tanh(α)=0). This is crucial:

  1. Protects the pre-trained LLM: The language model already knows how to process text
  2. Allows gradual adaptation: Visual information is slowly integrated
  3. Prevents catastrophic forgetting: The LLM doesn’t suddenly “forget” how to handle text
  4. Stabilizes training: The model learns to balance textual and visual information

It’s like introducing a new team member gradually, not throwing them into the deep end!
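
A quick sanity check of point 1 above, assuming the GatedCrossAttention module sketched earlier: with α initialised to zero, the block is exactly the identity on the language tokens, so the pre-trained LLM's behaviour is untouched at the start of training.

gated = GatedCrossAttention(visual_dim=1024, lang_dim=1024)

lang_tokens = torch.randn(1, 12, 1024)     # [batch, text_len, lang_dim]
visual_tokens = torch.randn(1, 64, 1024)   # [batch, num_visual_tokens, visual_dim]

out = gated(lang_tokens, visual_tokens)
print(torch.equal(out, lang_tokens))       # True - tanh(0) = 0, so no visual information flows yet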


Part 7: Advanced Topics

Masking Strategies in Flamingo

Flamingo uses an interesting masking strategy for cross-attention: text tokens only attend to the most recent image.

def create_flamingo_mask(
    text_len: int,
    image_positions: list,  # Positions where images appear
    num_visual_tokens: int  # Tokens per image (R from resampler)
) -> torch.Tensor:
    """
    Create a mask where text tokens can only attend to the most recent image.

    Args:
        text_len: Length of text sequence
        image_positions: List of positions where <image> tokens appear
        num_visual_tokens: Number of visual tokens per image

    Returns:
        Boolean mask for cross-attention [text_len, num_images * num_visual_tokens]
    """
    num_images = len(image_positions)
    total_visual = num_images * num_visual_tokens

    # Start fully masked: True = masked, False = can attend
    mask = torch.ones(text_len, total_visual, dtype=torch.bool)

    for text_idx in range(text_len):
        # Find the most recent image before this text token
        recent_image_idx = -1
        for img_idx, img_pos in enumerate(image_positions):
            if img_pos < text_idx:
                recent_image_idx = img_idx

        # Un-mask only that image; text before the first image stays fully masked
        if recent_image_idx >= 0:
            start = recent_image_idx * num_visual_tokens
            end = start + num_visual_tokens
            mask[text_idx, start:end] = False  # False = can attend

    return mask

# Example
text_len = 20
image_positions = [2, 10, 15]  # <image> tokens at positions 2, 10, 15
num_visual_tokens = 64

mask = create_flamingo_mask(text_len, image_positions, num_visual_tokens)

# Visualize
fig, ax = plt.subplots(figsize=(10, 6))
im = ax.imshow(~mask.numpy(), cmap='Blues', aspect='auto')

ax.set_xlabel('Visual Tokens', weight='bold')
ax.set_ylabel('Text Tokens', weight='bold')
ax.set_title('Flamingo Cross-Attention Mask\n(Blue = Can Attend, White = Masked)')

# Add image boundaries
for i in range(1, len(image_positions)):
    ax.axvline(x=i * num_visual_tokens - 0.5, color='red', linestyle='--', alpha=0.5)
    ax.text(i * num_visual_tokens - 32, text_len + 0.5, f'Image {i}',
            ha='center', fontsize=8, color='red')

# Add image position markers
for pos in image_positions:
    ax.axhline(y=pos + 0.5, color='green', linestyle=':', alpha=0.7)
    ax.text(-1, pos, '<img>', ha='right', va='center', fontsize=7, color='green')

plt.tight_layout()
plt.show()

Spatial vs. Temporal Encodings

A common question: Why only temporal encodings? What about position?

def explain_encoding_choices():
    """Explain why Flamingo uses temporal but not spatial encodings"""

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Spatial: Already in features
    ax = axes[0]
    ax.text(0.5, 0.9, 'Spatial Information', ha='center', weight='bold', transform=ax.transAxes)

    reasons = [
        "CNNs (like NFNet) inherently",
        "preserve spatial relationships",
        "in their feature maps.",
        "",
        "Early layers detect edges, textures",
        "at specific locations.",
        "",
        "Deeper layers build spatial",
        "relationships from these.",
        "",
        "No explicit encoding needed!"
    ]

    for i, reason in enumerate(reasons):
        color = 'darkgreen' if 'No explicit' in reason else 'black'
        weight = 'bold' if 'No explicit' in reason else 'normal'
        ax.text(0.1, 0.75 - i * 0.08, reason, transform=ax.transAxes,
                fontsize=10, color=color, weight=weight)

    ax.axis('off')
    ax.add_patch(plt.Rectangle((0.05, 0.05), 0.9, 0.9, fill=False,
                               edgecolor='green', linewidth=2, transform=ax.transAxes))

    # Temporal: Needs explicit encoding
    ax = axes[1]
    ax.text(0.5, 0.9, 'Temporal Information', ha='center', weight='bold', transform=ax.transAxes)

    reasons = [
        "Attention shuffles tokens,",
        "losing temporal order.",
        "",
        "Frames at t=0 and t=7 become",
        "indistinguishable after attention!",
        "",
        "Solution: Add learned",
        "temporal embeddings.",
        "",
        "Each frame gets a unique",
        "time vector added to it."
    ]

    for i, reason in enumerate(reasons):
        color = 'darkred' if 'Solution' in reason or 'time vector' in reason else 'black'
        weight = 'bold' if 'Solution' in reason or 'time vector' in reason else 'normal'
        ax.text(0.1, 0.75 - i * 0.08, reason, transform=ax.transAxes,
                fontsize=10, color=color, weight=weight)

    ax.axis('off')
    ax.add_patch(plt.Rectangle((0.05, 0.05), 0.9, 0.9, fill=False,
                               edgecolor='red', linewidth=2, transform=ax.transAxes))

    plt.suptitle('Why Temporal but Not Spatial Encodings?', fontsize=13, weight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

explain_encoding_choices()

“Note that we only use temporal encodings and no explicit spatial grid position encodings; we did not observe improvements from the latter. The rationale behind this is likely that CNNs, such as our NFNet encoder, are known to implicitly include spatial information.” — The Flamingo Paper


Part 8: Practical Implementation

A Minimal Working Example

import torch
import torch.nn as nn

class MinimalPerceiverResampler(nn.Module):
    """Simplified Perceiver Resampler for educational purposes"""

    def __init__(self, dim=512, num_latents=32, num_layers=2):
        super().__init__()

        # Learnable latents
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # Temporal embeddings (for up to 8 frames)
        self.temporal_embeds = nn.Parameter(torch.randn(8, dim))

        # Cross-attention layers
        self.cross_attn_layers = nn.ModuleList([
            nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
            for _ in range(num_layers)
        ])

        # Feed-forward networks
        self.ffns = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
            for _ in range(num_layers)
        ])

        # Layer norms
        self.norms1 = nn.ModuleList([nn.LayerNorm(dim) for _ in range(num_layers)])
        self.norms2 = nn.ModuleList([nn.LayerNorm(dim) for _ in range(num_layers)])

    def forward(self, x):
        """
        Args:
            x: [B, T, S, D] - batch of video features
        Returns:
            [B, R, D] - resampled features
        """
        B, T, S, D = x.shape

        # Add temporal encoding: one learned embedding per frame, broadcast over spatial tokens
        x = x + self.temporal_embeds[:T][None, :, None, :]

        # Flatten
        x = x.reshape(B, T * S, D)

        # Expand latents for batch
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)

        # Process through layers
        for cross_attn, ffn, norm1, norm2 in zip(
            self.cross_attn_layers, self.ffns, self.norms1, self.norms2
        ):
            # Cross-attention
            attn_out, _ = cross_attn(latents, x, x)
            latents = latents + norm1(attn_out)

            # Feed-forward
            ff_out = ffn(latents)
            latents = latents + norm2(ff_out)

        return latents

# Test the model
model = MinimalPerceiverResampler(dim=512, num_latents=32, num_layers=2)

# Input: 2 videos, 4 frames each, 49 spatial tokens, 512 dimensions
input_features = torch.randn(2, 4, 49, 512)

# Forward pass
output = model(input_features)

print(f"Input shape: {input_features.shape}")
print(f"Output shape: {output.shape}")
print(f"Compression ratio: {input_features.shape[1] * input_features.shape[2]}{output.shape[1]}")
Input shape: torch.Size([2, 4, 49, 512])
Output shape: torch.Size([2, 32, 512])
Compression ratio: 196 → 32

Key Hyperparameters and Their Effects

import pandas as pd

# Create a hyperparameter reference table
hyperparams = {
    'Hyperparameter': [
        'num_latents (R)',
        'num_layers',
        'num_heads',
        'feature_dim (d)',
        'max_frames'
    ],
    'Typical Values': [
        '64 (Flamingo)',
        '4-6',
        '8',
        '1024',
        '8 (trainable)'
    ],
    'Effect': [
        'More latents → more info preserved, but slower',
        'More layers → more processing, diminishing returns',
        'More heads → can attend to more things in parallel',
        'Larger → more expressive, but more compute',
        'Higher → need more memory, but interpolation helps'
    ],
    'Trade-offs': [
        '32-128 is common sweet spot',
        '2-8 layers typical',
        '4-16 heads typical',
        '512-2048 common',
        '8 is standard, interpolate at inference'
    ]
}

df = pd.DataFrame(hyperparams)
df
| Hyperparameter  | Typical Values | Effect                                              | Trade-offs                              |
|-----------------|----------------|-----------------------------------------------------|-----------------------------------------|
| num_latents (R) | 64 (Flamingo)  | More latents → more info preserved, but slower      | 32-128 is a common sweet spot           |
| num_layers      | 4-6            | More layers → more processing, diminishing returns  | 2-8 layers typical                      |
| num_heads       | 8              | More heads → can attend to more things in parallel  | 4-16 heads typical                      |
| feature_dim (d) | 1024           | Larger → more expressive, but more compute          | 512-2048 common                         |
| max_frames      | 8 (trainable)  | Higher → needs more memory, but interpolation helps | 8 is standard, interpolate at inference |

Part 9: Theoretical Insights

Why This Works: Theoretical Perspectives

1. Information Bottleneck Perspective

The Perceiver Resampler creates an information bottleneck:

import matplotlib.pyplot as plt
import numpy as np

def plot_information_flow():
    """Visualize the information bottleneck"""
    fig, ax = plt.subplots(figsize=(12, 4))

    # Information content at each stage
    stages = ['Raw\nVideo', 'Vision\nEncoder', 'Perceiver\nInput', 'Perceiver\nOutput']
    info_content = [10000, 5000, 5000, 1000]  # Conceptual units

    # Plot bars
    bars = ax.bar(range(len(stages)), info_content,
                   color=['lightgray', 'blue', 'purple', 'green'], alpha=0.7, edgecolor='black')

    # Add labels
    for i, (bar, stage, info) in enumerate(zip(bars, stages, info_content)):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{info:,}', ha='center', va='bottom', weight='bold')

    ax.set_xticks(range(len(stages)))
    ax.set_xticklabels(stages, fontsize=11)
    ax.set_ylabel('Information Content (arbitrary units)', fontsize=11)
    ax.set_title('The Information Bottleneck: Compression for Efficiency', fontsize=13, weight='bold')
    ax.set_ylim(0, 11000)

    # Add bottleneck annotation
    ax.annotate('Bottleneck\nHere!', xy=(2.5, 500), xytext=(2.5, 2000),
                arrowprops=dict(arrowstyle='->', lw=2, color='red'),
                ha='center', fontsize=10, color='red')

    plt.tight_layout()
    plt.show()

plot_information_flow()

The insight: By forcing information through a narrow bottleneck, the model learns to preserve only the most task-relevant information. This is a form of representation learning - the latents learn to capture the essence of the visual input.

2. The Set Function Perspective

The Perceiver Resampler operates on sets, not sequences:

  • Visual features form a set (order doesn’t matter after encoding)
  • Latents form another set
  • Cross-attention operates on set-to-set relationships

This makes the architecture permutation invariant - it doesn’t depend on the specific ordering of spatial tokens.
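
You can verify this with the MinimalPerceiverResampler from Part 8 (a rough numerical check; tiny floating-point differences are expected, hence the tolerance): shuffling the spatial tokens within each frame leaves the output essentially unchanged.

model = MinimalPerceiverResampler(dim=512, num_latents=32, num_layers=2).eval()

x = torch.randn(1, 4, 49, 512)        # [B, T, S, D]
perm = torch.randperm(49)
x_shuffled = x[:, :, perm, :]         # shuffle spatial tokens inside every frame

with torch.no_grad():
    out_original = model(x)
    out_shuffled = model(x_shuffled)

print(torch.allclose(out_original, out_shuffled, atol=1e-5))  # True - a set, not a sequence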

3. The “Soft Database” Perspective

From the Stack Exchange discussion:

“You can think of attention, as used in Transformers or Perceiver as a soft differentiable database.”

  • Keys: Indices in the database
  • Values: Stored information
  • Queries: What we’re looking for
  • Attention weights: Soft retrieval (not just 0 or 1, but weighted)

This “soft” nature allows gradients to flow through, making the entire system learnable!
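
A toy version of that "soft database", assuming random keys and values just for illustration: the query does not fetch a single row, it returns a differentiable weighted blend of all values.

import torch

keys = torch.randn(5, 8)      # 5 database entries, each indexed by an 8-dim key
values = torch.randn(5, 3)    # each entry stores a 3-dim value
query = torch.randn(8)        # what we're looking for

weights = (keys @ query / 8 ** 0.5).softmax(dim=0)  # soft retrieval weights over the 5 entries
retrieved = weights @ values                        # weighted blend of values, shape [3]

print(weights)          # a distribution over rows, not a hard 0/1 lookup
print(retrieved.shape)  # torch.Size([3])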


Part 10: Common Questions and Gotchas

Q: Why not just use average pooling to compress the visual tokens?

A: Pooling is a fixed operation - it can't learn WHAT to preserve. The Perceiver Resampler learns to ask task-specific questions.

Imagine you're writing a paper about birds:

  • Average pooling = reading every sentence and taking the average
  • Perceiver = having specific questions like “What color is it?”, “Where does it live?”, etc.

Q: Are the learned latents really just random vectors at the start of training?

A: Yes! They're randomly initialized, just like any other neural network weights. The magic is that gradients from the task teach them what to ask for.

This is why the Stack Exchange answer emphasized “learned via gradient descent” - they're not computed from features, they're learned parameters!

Q: Can the Perceiver handle modalities other than images and video?

A: Absolutely! The original Perceiver paper showed the approach works for:

  • Images
  • Audio
  • Point clouds
  • Video
  • Video + Audio

Any modality that can be encoded into feature vectors can be resampled!

Q: What does Perceiver IO add on top of the Perceiver?

A: Perceiver IO (the follow-up paper) adds cross-attention on the OUTPUT side too, allowing the model to produce structured outputs (like segmentation masks) while keeping the input bottleneck.


Quick Reference Card

Core Concepts

| Concept                | Description                                                  |
|------------------------|--------------------------------------------------------------|
| Learned Latents        | Randomly initialized, task-specialized query vectors         |
| Cross-Attention        | Latents query visual features (Q = latents; K, V = features) |
| Temporal Encoding      | Learned vectors added to each frame to preserve time order   |
| Information Bottleneck | Variable input → fixed output via learned compression        |

Architecture Flow

Images/Video
    ↓
Vision Encoder
    ↓
+ Temporal Encoding
    ↓
Flatten [T×S, d]
    ↓
Cross-Attention (Q: latents, K,V: features)
    ↓
Skip + LayerNorm
    ↓
Feed Forward
    ↓
Skip + LayerNorm
    ↓
Repeat N layers
    ↓
Output [R, d]

Key Formulas

# Cross-attention
Attention(Q, K, V) = softmax(QK^T / √d_k)V

# Temporal encoding
Encoded_t = Feature_t + TemporalEmbed_t

# Skip connection
Output = LayerNorm(Input + Sublayer(Input))

Congratulations! You Now Understand the Perceiver Resampler!

You’ve learned:

  • The problem of variable-length visual input
  • How learned latent queries create an information bottleneck
  • Cross-attention as a differentiable database query
  • Temporal encodings for video understanding
  • The Flamingo integration and gated cross-attention
  • Practical implementation details

Further Reading:

  • The Perceiver paper: Jaegle et al., “Perceiver: General Perception with Iterative Attention” (DeepMind, 2021)
  • The Flamingo paper: Alayrac et al., “Flamingo: a Visual Language Model for Few-Shot Learning” (DeepMind, 2022)


Acknowledgments

This blog is based on insights from:

  • The Perceiver paper by Jaegle et al. (DeepMind, 2021)
  • The Flamingo paper by Alayrac et al. (DeepMind, 2022)
  • The excellent Stack Exchange discussion
  • The Flamingo explainer by Daniel Warfield


Happy Learning! May your attention mechanisms always attend to what matters!