The Perceiver Resampler: A Deep Dive into Visual Information Bottlenecks
Understanding how modern multimodal models compress visual information for language models
Deep Learning
Computer Vision
Multimodal Models
Attention Mechanisms
Author
Rishabh Mondal
Published
February 11, 2025
Welcome to the World of Multimodal Architecture!
Have you ever wondered how models like GPT-4V or Google Gemini can “see” images and have conversations about them? The secret lies in a beautiful architectural component called the Perceiver Resampler.
What you’ll learn in this blog:
Why we need special architectures to bridge vision and language
The elegant solution: Learned latent queries
How cross-attention creates information bottlenecks
Temporal encodings for video understanding
Building your intuition from simple concepts to advanced implementation
Prerequisites: Basic understanding of neural networks and attention mechanisms. Don’t worry if you’re rusty - we’ll build up from the basics!
Part 1: The Problem - Bridging Two Worlds
The Vision-Language Marriage
Modern AI has achieved something remarkable: we can now have natural conversations with AI about both text and images. But this marriage of vision and language faces a fundamental challenge:
Vision models produce a LOT of data. A single image encoded by a modern vision encoder might produce thousands of tokens. A video? That’s tens of thousands of tokens, easily.
Language models need a compact, consistent input. GPT-style models process token sequences through attention layers, and the computational cost grows quadratically with sequence length, so feeding them tens of thousands of raw visual tokens is prohibitively expensive.
The Quadratic Problem
In transformers, the attention mechanism computes relationships between every pair of tokens. If you have \(n\) tokens, you need to compute \(n^2\) relationships.
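To make that scaling concrete, here is a back-of-the-envelope comparison; the token counts are illustrative (using the 30-frame, 289-tokens-per-frame, 64-latent numbers that appear later in this post), not measurements from any specific model:

# Rough cost comparison: attention over n tokens computes ~n^2 pairwise interactions
text_tokens = 2_000
video_tokens = 30 * 289            # e.g., 30 frames x 289 visual tokens per frame
resampled_tokens = 64              # a fixed-size visual summary

print(f"Text only:          {text_tokens**2:,} pairs")
print(f"Text + raw video:   {(text_tokens + video_tokens)**2:,} pairs")
print(f"Text + 64 latents:  {(text_tokens + resampled_tokens)**2:,} pairs")

To avoid that blow-up without throwing away what the image or video actually contains, the bridge between vision and language needs to: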
Accept variable-length visual input - 1 image, 8 images, 30 images… it should work with any number
Produce fixed-size output - The language model needs a consistent representation
Preserve important information - We can’t just randomly sample pixels!
Be learnable - The model should figure out WHAT information matters
This is where the Perceiver Resampler enters the story.
Part 2: The Intuition - A Smart Summarizer
The Library Analogy
Imagine you’re researching a topic and have access to a library with millions of books. You don’t have time to read everything, but you need enough information to write a comprehensive summary.
The Perceiver Resampler is like hiring a smart research assistant who:
Has a set of pre-written research questions (learned latent queries)
Searches through ALL the books (visual features)
Extracts only the most relevant answers
Returns a fixed-size summary regardless of library size
Learned Latent Queries: The Secret Sauce
The key insight is that we don’t process visual features directly. Instead, we create a set of “learned queries” - vectors that start random but are trained to ask the right questions.
Think of it like this:
import numpy as np

# Imagine we have 64 "questions" our model can ask
# Each question is a vector of, say, 1024 dimensions
num_latents = 64
latent_dim = 1024

# At the start of training, these are random
learned_latents = np.random.randn(num_latents, latent_dim)

print(f"Shape of learned latents: {learned_latents.shape}")
# Output: (64, 1024) - 64 questions, each with 1024 dimensions
Shape of learned latents: (64, 1024)
These “questions” aren’t text - they’re abstract vectors in a high-dimensional space. Through training, they learn to extract useful information from images.
Key Insight: Not Computed, But Learned!
The learned latents are NOT computed from the input images. They are:
Initialized randomly at the start of training
Learned through backpropagation as the model trains
Used as queries to extract information from any input
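A minimal sketch of what "learned, not computed" means in practice (the loss here is a stand-in for a real task loss):

import torch
import torch.nn as nn

# The latents are ordinary trainable parameters: random at initialization
latents = nn.Parameter(torch.randn(64, 1024))
optimizer = torch.optim.Adam([latents], lr=1e-4)

# Stand-in for a real task loss that depends on the latents
loss = (latents ** 2).mean()
loss.backward()
optimizer.step()   # gradient descent nudges the "questions" toward what the task needs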
This is why the Stack Exchange answer emphasized that latent vectors are “learned via gradient descent” - not pre-computed features!
Part 3: Cross-Attention - The Information Filter
Understanding Cross-Attention
Before diving deeper, let’s understand cross-attention, the mechanism that powers the Perceiver Resampler.
Standard Self-Attention: Each token attends to every other token in the SAME sequence.
Cross-Attention: Tokens from one sequence attend to tokens from a DIFFERENT sequence.
In the Perceiver Resampler:
Queries: Our learned latents (fixed size, e.g., 64 tokens)
Keys & Values: Visual features from the image encoder (variable size)
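As a minimal sketch (shapes are illustrative), cross-attention with a fixed set of latent queries always returns a fixed number of outputs, no matter how many visual tokens come in:

import torch
import torch.nn as nn

dim = 1024
cross_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

latents = torch.randn(1, 64, dim)      # queries: 64 learned latents
visual = torch.randn(1, 2312, dim)     # keys/values: e.g., 8 frames x 289 tokens

out, attn_weights = cross_attn(latents, visual, visual)
print(out.shape)           # torch.Size([1, 64, 1024]) - always 64 outputs
print(attn_weights.shape)  # torch.Size([1, 64, 2312]) - each latent's weights over the visual tokens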
Before cross-attention, the per-frame features are flattened into one long sequence of tokens. That might look like it throws away structure, but it doesn't, because:
Spatial information is encoded in the feature vectors themselves (CNNs preserve locality)
Temporal information is explicitly added (see next section)
The flattening is consistent, so the model learns to handle this specific ordering
Step 3: Temporal Position Encodings
import torch
import torch.nn as nn

class TemporalEncoding(nn.Module):
    """Learned temporal position encodings for video frames"""
    def __init__(self, max_frames: int = 8, feature_dim: int = 1024):
        super().__init__()
        # Learnable embedding for each time position
        self.temporal_embeds = nn.Parameter(torch.randn(max_frames, feature_dim))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: [T, S, d] tensor of visual features
        Returns:
            features with temporal encoding added: [T, S, d]
        """
        T = features.shape[0]
        # Add temporal encoding to each frame:
        # for each frame t, add self.temporal_embeds[t] to all S spatial tokens
        encoded = []
        for t in range(T):
            encoded.append(features[t] + self.temporal_embeds[t])
        return torch.stack(encoded)

# Example usage
# Stand-in for the encoder output from the earlier steps: [T=8 frames, S=289 spatial tokens, d=1024]
visual_features = torch.randn(8, 289, 1024)

temporal_encoder = TemporalEncoding(max_frames=8, feature_dim=1024)
encoded_features = temporal_encoder(visual_features)
print(f"Encoded features shape: {encoded_features.shape}")
Encoded features shape: torch.Size([8, 289, 1024])
Interpolation at Inference Time
Here’s something fascinating: Flamingo was trained with 8 frames, but at inference, it processes 30 frames!
How? By interpolating the temporal embeddings:
# Trained with 8 frames, need 30
trained_embeds = temporal_encoder.temporal_embeds   # Shape: [8, 1024]

# Linearly interpolate to get 30 embeddings
from torch.nn.functional import interpolate

inference_embeds = interpolate(
    trained_embeds.T.unsqueeze(0),   # [1, 1024, 8]
    size=30,
    mode='linear',
    align_corners=False
).squeeze(0).T                       # [30, 1024]
This remarkable property allows the model to generalize to different frame counts!
You might wonder: What do these learned latents actually represent?
The Perceiver paper describes this beautifully: the latents act like cluster centers in a soft clustering algorithm. Each latent specializes in extracting a certain type of information.
import matplotlib.pyplot as plt
import numpy as np

def visualize_latent_specialization():
    """Conceptual visualization of how latents might specialize"""
    fig, axes = plt.subplots(2, 4, figsize=(14, 7))
    fig.suptitle('How Learned Latents Might Specialize', fontsize=14, weight='bold')

    # Conceptual specializations
    specializations = [
        ("Objects", "focuses on discrete objects"),
        ("Actions", "captures motion and activity"),
        ("Spatial", "understands spatial relationships"),
        ("Colors", "attends to color patterns"),
        ("Textures", "focuses on surface properties"),
        ("Context", "understands scene context"),
        ("Temporal", "tracks changes over time"),
        ("Global", "integrates holistic understanding")
    ]

    for idx, (ax, (name, desc)) in enumerate(zip(axes.flat, specializations)):
        # Create a conceptual (random) attention pattern, for illustration only
        attn_pattern = np.random.rand(16, 16)
        attn_pattern = attn_pattern / attn_pattern.sum()
        ax.imshow(attn_pattern, cmap='viridis')
        ax.set_title(f'Latent {idx}: {name}\n{desc}', fontsize=9)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

visualize_latent_specialization()
Key insight from the Stack Exchange discussion:
“The latent vectors in this case act as queries which extract information from the input data and need to be aligned in such a way that they extract the necessary information to perform the prediction task.”
Through training, gradients flow back and adjust these latents so they ask “better questions” - questions that extract information useful for the task at hand.
Not Traditional Clustering!
The Perceiver authors call this “end-to-end clustering” but it’s important to understand:
NOT K-means or any traditional clustering algorithm
NOT pre-computed from features
IS a soft, differentiable clustering that emerges from gradient descent
IS task-dependent - latents learn to cluster information relevant to the task
The “clusters” are implicit in the attention patterns that develop during training!
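One way to see this soft-clustering behaviour, as a rough sketch with random tensors: each latent's attention weights over the visual tokens form a non-negative distribution that sums to one, much like soft cluster memberships.

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
latents = torch.randn(1, 64, 512)      # 64 "cluster centers"
visual = torch.randn(1, 2312, 512)     # visual tokens

_, weights = attn(latents, visual, visual)   # [1, 64, 2312], averaged over heads
soft_assignments = weights[0]

print(bool((soft_assignments >= 0).all()))    # True: weights are non-negative
print(soft_assignments.sum(dim=-1)[:4])       # each latent's weights sum to ~1.0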
Part 6: The Flamingo Integration
How Flamingo Uses the Perceiver Resampler
The Perceiver Resampler was popularized by DeepMind’s Flamingo model. Let’s see how it fits into the complete architecture:
After the Perceiver Resampler produces fixed-size visual tokens, Flamingo uses gated cross-attention to inject this information into the language model:
class GatedCrossAttention(nn.Module):
    """
    Cross-attention with gating for controlled visual information injection.
    The gate starts closed (alpha=0) and gradually opens during training.
    """
    def __init__(self, visual_dim: int, lang_dim: int, num_heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=lang_dim,
            num_heads=num_heads,
            kdim=visual_dim,    # keys come from the visual tokens
            vdim=visual_dim,    # values come from the visual tokens
            batch_first=True
        )
        # Learnable gate parameter
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, lang_tokens: torch.Tensor, visual_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            lang_tokens: [batch, seq_len, lang_dim] - language model tokens
            visual_tokens: [batch, num_visual, visual_dim] - resampled visual tokens
        Returns:
            Updated lang_tokens with visual information injected
        """
        # Cross-attention: language queries, visual keys/values
        attn_out, _ = self.cross_attn(
            query=lang_tokens,
            key=visual_tokens,
            value=visual_tokens
        )
        # Tanh gate: controls how much visual information flows through
        gate = torch.tanh(self.alpha)
        gated_output = gate * attn_out
        # Skip connection: original language tokens
        return lang_tokens + gated_output


# Training progression
def simulate_training_progression():
    """Show how the gate opens during training"""
    fig, ax = plt.subplots(figsize=(10, 4))

    # Simulate gate values during training: the gate gradually opens from 0 towards 1
    steps = np.arange(0, 10000)
    gate_values = np.tanh(steps / 2000)   # Smooth transition

    ax.plot(steps, gate_values, linewidth=2, color='purple')
    ax.set_xlabel('Training Steps')
    ax.set_ylabel('Gate Value (tanh(α))')
    ax.set_title('Gated Cross-Attention: Progressive Information Injection')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.1, 1.1)

    # Add annotations
    ax.annotate('Start: Gate Closed\nα≈0, tanh(α)≈0\nNo visual info',
                xy=(0, 0), xytext=(1000, 0.2),
                arrowprops=dict(arrowstyle='->', color='red'),
                fontsize=9, ha='center')
    ax.annotate('Mid: Gate Opening\nGradual\nvisual injection',
                xy=(3000, 0.8), xytext=(3500, 0.5),
                arrowprops=dict(arrowstyle='->', color='orange'),
                fontsize=9, ha='center')
    ax.annotate('End: Gate Open\nα large, tanh(α)≈1\nFull visual integration',
                xy=(8000, 1), xytext=(6500, 0.8),
                arrowprops=dict(arrowstyle='->', color='green'),
                fontsize=9, ha='center')

    plt.tight_layout()
    plt.show()

simulate_training_progression()
Why the Gradual Gate?
When training Flamingo, the cross-attention gate starts closed (α=0 → tanh(α)=0). This is crucial:
Protects the pre-trained LLM: The language model already knows how to process text
Allows gradual adaptation: Visual information is slowly integrated
Prevents catastrophic forgetting: The LLM doesn’t suddenly “forget” how to handle text
Stabilizes training: The model learns to balance textual and visual information
It’s like introducing a new team member gradually, not throwing them into the deep end!
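A quick sanity check on the GatedCrossAttention sketch above: with α initialized to zero, tanh(α) = 0, so at the very start of training the layer leaves the language tokens untouched (the shapes below are illustrative).

import torch

layer = GatedCrossAttention(visual_dim=1024, lang_dim=768)
lang = torch.randn(2, 16, 768)     # language tokens
vis = torch.randn(2, 64, 1024)     # resampled visual tokens

out = layer(lang, vis)
print(torch.allclose(out, lang))   # True: the gate is closed, output == input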
Part 7: Advanced Topics
Masking Strategies in Flamingo
Flamingo uses an interesting masking strategy for cross-attention: text tokens only attend to the most recent image.
def create_flamingo_mask(
    text_len: int,
    image_positions: list,     # Positions where images appear
    num_visual_tokens: int     # Tokens per image (R from the resampler)
) -> torch.Tensor:
    """
    Create a mask where text tokens can only attend to the most recent image.

    Args:
        text_len: Length of text sequence
        image_positions: List of positions where <image> tokens appear
        num_visual_tokens: Number of visual tokens per image
    Returns:
        Boolean mask for cross-attention [text_len, num_images * num_visual_tokens]
        (True = masked out, False = can attend)
    """
    num_images = len(image_positions)
    total_visual = num_images * num_visual_tokens

    # Start with everything masked out
    mask = torch.ones(text_len, total_visual, dtype=torch.bool)

    for text_idx in range(text_len):
        # Find the most recent image before this text token
        recent_image_idx = -1
        for img_idx, img_pos in enumerate(image_positions):
            if img_pos < text_idx:
                recent_image_idx = img_idx

        # Only allow attending to that image
        if recent_image_idx >= 0:
            start = recent_image_idx * num_visual_tokens
            end = start + num_visual_tokens
            mask[text_idx, start:end] = False   # False = can attend
        # If no image precedes this token, everything stays masked (True)

    return mask


# Example
text_len = 20
image_positions = [2, 10, 15]   # <image> tokens at positions 2, 10, 15
num_visual_tokens = 64

mask = create_flamingo_mask(text_len, image_positions, num_visual_tokens)

# Visualize
fig, ax = plt.subplots(figsize=(10, 6))
ax.imshow(~mask.numpy(), cmap='Blues', aspect='auto')
ax.set_xlabel('Visual Tokens', weight='bold')
ax.set_ylabel('Text Tokens', weight='bold')
ax.set_title('Flamingo Cross-Attention Mask\n(Blue = Can Attend, White = Masked)')

# Add image boundaries
for i in range(1, len(image_positions)):
    ax.axvline(x=i * num_visual_tokens - 0.5, color='red', linestyle='--', alpha=0.5)
    ax.text(i * num_visual_tokens - 32, text_len + 0.5, f'Image {i}',
            ha='center', fontsize=8, color='red')

# Add image position markers
for pos in image_positions:
    ax.axhline(y=pos + 0.5, color='green', linestyle=':', alpha=0.7)
    ax.text(-1, pos, '<img>', ha='right', va='center', fontsize=7, color='green')

plt.tight_layout()
plt.show()
Spatial vs. Temporal Encodings
A common question: Why only temporal encodings? What about position?
def explain_encoding_choices():
    """Explain why Flamingo uses temporal but not spatial encodings"""
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Spatial: already in the features
    ax = axes[0]
    ax.text(0.5, 0.9, 'Spatial Information', ha='center', weight='bold',
            transform=ax.transAxes)
    reasons = [
        "CNNs (like NFNet) inherently",
        "preserve spatial relationships",
        "in their feature maps.",
        "",
        "Early layers detect edges, textures",
        "at specific locations.",
        "",
        "Deeper layers build spatial",
        "relationships from these.",
        "",
        "No explicit encoding needed!"
    ]
    for i, reason in enumerate(reasons):
        color = 'darkgreen' if 'No explicit' in reason else 'black'
        weight = 'bold' if 'No explicit' in reason else 'normal'
        ax.text(0.1, 0.75 - i * 0.08, reason, transform=ax.transAxes,
                fontsize=10, color=color, weight=weight)
    ax.axis('off')
    ax.add_patch(plt.Rectangle((0.05, 0.05), 0.9, 0.9, fill=False,
                               edgecolor='green', linewidth=2,
                               transform=ax.transAxes))

    # Temporal: needs explicit encoding
    ax = axes[1]
    ax.text(0.5, 0.9, 'Temporal Information', ha='center', weight='bold',
            transform=ax.transAxes)
    reasons = [
        "Attention shuffles tokens,",
        "losing temporal order.",
        "",
        "Frames at t=0 and t=7 become",
        "indistinguishable after attention!",
        "",
        "Solution: Add learned",
        "temporal embeddings.",
        "",
        "Each frame gets a unique",
        "time vector added to it."
    ]
    for i, reason in enumerate(reasons):
        color = 'darkred' if 'Solution' in reason or 'time vector' in reason else 'black'
        weight = 'bold' if 'Solution' in reason or 'time vector' in reason else 'normal'
        ax.text(0.1, 0.75 - i * 0.08, reason, transform=ax.transAxes,
                fontsize=10, color=color, weight=weight)
    ax.axis('off')
    ax.add_patch(plt.Rectangle((0.05, 0.05), 0.9, 0.9, fill=False,
                               edgecolor='red', linewidth=2,
                               transform=ax.transAxes))

    plt.suptitle('Why Temporal but Not Spatial Encodings?', fontsize=13,
                 weight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

explain_encoding_choices()
“Note that we only use temporal encodings and no explicit spatial grid position encodings; we did not observe improvements from the latter. The rationale behind this is likely that CNNs, such as our NFNet encoder, are known to implicitly include spatial information.” — The Flamingo Paper
Part 8: Practical Implementation
A Minimal Working Example
import torch
import torch.nn as nn

class MinimalPerceiverResampler(nn.Module):
    """Simplified Perceiver Resampler for educational purposes"""
    def __init__(self, dim=512, num_latents=32, num_layers=2):
        super().__init__()
        # Learnable latents
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        # Temporal embeddings (for up to 8 frames)
        self.temporal_embeds = nn.Parameter(torch.randn(8, dim))
        # Cross-attention layers
        self.cross_attn_layers = nn.ModuleList([
            nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
            for _ in range(num_layers)
        ])
        # Feed-forward networks
        self.ffns = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
            for _ in range(num_layers)
        ])
        # Layer norms
        self.norms1 = nn.ModuleList([nn.LayerNorm(dim) for _ in range(num_layers)])
        self.norms2 = nn.ModuleList([nn.LayerNorm(dim) for _ in range(num_layers)])

    def forward(self, x):
        """
        Args:
            x: [B, T, S, D] - batch of video features
        Returns:
            [B, R, D] - resampled features
        """
        B, T, S, D = x.shape

        # Add temporal encoding: one embedding per frame, broadcast over spatial tokens
        # (out of place, so the caller's tensor is not modified)
        x = x + self.temporal_embeds[:T].view(1, T, 1, D)

        # Flatten frames and spatial positions into one sequence
        x = x.reshape(B, T * S, D)

        # Expand latents for the batch
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)

        # Process through layers
        for cross_attn, ffn, norm1, norm2 in zip(
            self.cross_attn_layers, self.ffns, self.norms1, self.norms2
        ):
            # Cross-attention: latents query the visual tokens
            attn_out, _ = cross_attn(latents, x, x)
            latents = latents + norm1(attn_out)
            # Feed-forward
            ff_out = ffn(latents)
            latents = latents + norm2(ff_out)

        return latents


# Test the model
model = MinimalPerceiverResampler(dim=512, num_latents=32, num_layers=2)

# Input: 2 videos, 4 frames each, 49 spatial tokens, 512 dimensions
input_features = torch.randn(2, 4, 49, 512)

# Forward pass
output = model(input_features)

print(f"Input shape: {input_features.shape}")
print(f"Output shape: {output.shape}")
print(f"Compression ratio: {input_features.shape[1] * input_features.shape[2]} → {output.shape[1]}")
import pandas as pd

# Create a hyperparameter reference table
hyperparams = {
    'Hyperparameter': [
        'num_latents (R)',
        'num_layers',
        'num_heads',
        'feature_dim (d)',
        'max_frames'
    ],
    'Typical Values': [
        '64 (Flamingo)',
        '4-6',
        '8',
        '1024',
        '8 (trainable)'
    ],
    'Effect': [
        'More latents → more info preserved, but slower',
        'More layers → more processing, diminishing returns',
        'More heads → can attend to more things in parallel',
        'Larger → more expressive, but more compute',
        'Higher → need more memory, but interpolation helps'
    ],
    'Trade-offs': [
        '32-128 is common sweet spot',
        '2-8 layers typical',
        '4-16 heads typical',
        '512-2048 common',
        '8 is standard, interpolate at inference'
    ]
}

df = pd.DataFrame(hyperparams)
df
| Hyperparameter | Typical Values | Effect | Trade-offs |
| --- | --- | --- | --- |
| num_latents (R) | 64 (Flamingo) | More latents → more info preserved, but slower | 32-128 is common sweet spot |
| num_layers | 4-6 | More layers → more processing, diminishing returns | 2-8 layers typical |
| num_heads | 8 | More heads → can attend to more things in parallel | 4-16 heads typical |
| feature_dim (d) | 1024 | Larger → more expressive, but more compute | 512-2048 common |
| max_frames | 8 (trainable) | Higher → need more memory, but interpolation helps | 8 is standard, interpolate at inference |
Part 9: Theoretical Insights
Why This Works: Theoretical Perspectives
1. Information Bottleneck Perspective
The Perceiver Resampler creates an information bottleneck: thousands of incoming visual tokens must squeeze through a small, fixed set of latent tokens.
The insight: By forcing information through a narrow bottleneck, the model learns to preserve only the most task-relevant information. This is a form of representation learning - the latents learn to capture the essence of the visual input.
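With the numbers used elsewhere in this post (8 frames, 289 spatial tokens per frame, 64 latents), the bottleneck is roughly a 36x reduction:

visual_tokens = 8 * 289      # 2312 tokens out of the vision encoder
latent_tokens = 64           # fixed-size bottleneck
print(f"{visual_tokens} -> {latent_tokens} tokens "
      f"(~{visual_tokens / latent_tokens:.0f}x fewer for the LLM to attend to)")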
2. The Set Function Perspective
The Perceiver Resampler operates on sets, not sequences:
Visual features form a set (order doesn’t matter after encoding)
Latents form another set
Cross-attention operates on set-to-set relationships
This makes the architecture permutation invariant - it doesn’t depend on the specific ordering of spatial tokens.
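A small check of this property, as a sketch with random tensors: shuffling the order of the key/value tokens leaves the cross-attention output unchanged, as long as no positional encoding has been added to them.

import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
latents = torch.randn(1, 16, 256)     # queries: a fixed set of latents
visual = torch.randn(1, 100, 256)     # keys/values: a set of visual tokens

out1, _ = attn(latents, visual, visual)
perm = torch.randperm(100)
out2, _ = attn(latents, visual[:, perm], visual[:, perm])

print(torch.allclose(out1, out2, atol=1e-5))   # True: output ignores K/V ordering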
3. The “Soft Database” Perspective
From the Stack Exchange discussion:
“You can think of attention, as used in Transformers or Perceiver as a soft differentiable database.”
Keys: Indices in the database
Values: Stored information
Queries: What we’re looking for
Attention weights: Soft retrieval (not just 0 or 1, but weighted)
This “soft” nature allows gradients to flow through, making the entire system learnable!
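Here is that idea in miniature (a toy "database" with orthogonal keys, purely for illustration): the query retrieves a weighted blend of the stored values rather than a single hard lookup.

import torch
import torch.nn.functional as F

keys = torch.eye(4)                                   # 4 database "indices"
values = torch.tensor([[1.0], [2.0], [3.0], [4.0]])   # stored information
query = torch.tensor([[0.0, 10.0, 0.0, 0.0]])         # strongly matches key 1

weights = F.softmax(query @ keys.T / keys.shape[1] ** 0.5, dim=-1)
retrieved = weights @ values

print(weights)     # ≈ [0.007, 0.980, 0.007, 0.007] - soft, differentiable lookup
print(retrieved)   # ≈ 2.0 - mostly the value stored under key 1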
Part 10: Common Questions and Gotchas
Q: Why not just use pooling (average/max)?
A: Pooling is a fixed operation - it can’t learn WHAT to preserve. The Perceiver Resampler learns to ask task-specific questions.
Imagine you’re writing a paper about birds:
Average pooling = reading every sentence and taking the average
Perceiver = having specific questions like “What color is it?”, “Where does it live?”, etc.
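A rough side-by-side sketch of the two options (shapes illustrative): pooling has nothing to learn, while latent-query attention has trainable parameters that decide what gets kept.

import torch
import torch.nn as nn

visual = torch.randn(1, 2312, 1024)    # many visual tokens

# Average pooling: a fixed operation - every token weighted equally, nothing learned
pooled = visual.mean(dim=1)                       # [1, 1024]

# Learned queries: trainable "questions" decide which tokens matter for the task
queries = nn.Parameter(torch.randn(1, 64, 1024))
attn = nn.MultiheadAttention(1024, num_heads=8, batch_first=True)
summary, _ = attn(queries, visual, visual)        # [1, 64, 1024]

print(pooled.shape, summary.shape)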
Q: Are the latents really just random at initialization?
A: Yes! They’re randomly initialized, just like any other neural network weights. The magic is that gradients from the task teach them what to ask for.
This is why the Stack Exchange emphasized “learned via gradient descent” - they’re not computed from features, they’re learned parameters!
Q: Can I use the Perceiver Resampler for other modalities?
A: Absolutely! The original Perceiver paper showed it works for:
Images
Audio
Point clouds
Video
Video + Audio
Any modality that can be encoded into feature vectors can be resampled!
Q: What’s the difference between Perceiver and Perceiver IO?
A: Perceiver IO (the follow-up paper) adds cross-attention on the OUTPUT side too, allowing the model to produce structured outputs (like segmentation masks) while keeping the input bottleneck.
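Conceptually (a rough sketch, not the paper's exact module), Perceiver IO adds a second cross-attention at the output: a set of output queries, one per desired output element, reads from the processed latents.

import torch
import torch.nn as nn

latents = torch.randn(1, 64, 512)           # latents after the input bottleneck
output_queries = torch.randn(1, 196, 512)   # e.g., one query per output patch/pixel

decode = nn.MultiheadAttention(512, num_heads=8, batch_first=True)
decoded, _ = decode(output_queries, latents, latents)
print(decoded.shape)   # torch.Size([1, 196, 512]) - structured output, same latent bottleneck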
This blog is based on insights from:
The Perceiver paper by Jaegle et al. (DeepMind, 2021)
The Flamingo paper by Alayrac et al. (DeepMind, 2022)
The excellent Stack Exchange discussion
The Flamingo explainer by Daniel Warfield
Happy Learning! May your attention mechanisms always attend to what matters!