Concept Bottleneck Models: Making Neural Networks Speak Human

How interpretable intermediate concepts bridge the gap between black-box predictions and human understanding

  • Interpretable ML
  • Deep Learning
  • Explainable AI
  • Computer Vision
Author

Rishabh Mondal

Published

March 16, 2026

Paper

Koh et al., ICML 2020

Part 1: Thinking Lens - Why Interpretability Matters

Before we write any equations or code, let us ask a fundamental question: why do we need interpretable models at all?

Imagine a doctor using an AI system to diagnose skin cancer. The model looks at a dermoscopy image and outputs: “Malignant, 87% confidence.” The doctor asks:

“Why do you think it is malignant?”

The model cannot answer. It is a black box. All the reasoning is hidden inside millions of parameters that no human can understand.

Now imagine a different model that says:

“Malignant, 87% confidence. Reasoning: asymmetric shape (YES), irregular borders (YES), multiple colors (YES), diameter > 6mm (YES).”

This is much more useful. The doctor can verify each observation. If the model incorrectly detected “irregular borders” on a clearly smooth lesion, the doctor can mentally correct that and reconsider the diagnosis.

This shift from “what” to “why” is the core motivation behind Concept Bottleneck Models.

An intuitive analogy is a student showing their work on a math exam: even if the final answer is wrong, the teacher can see exactly where the mistake happened. Without the work shown, no targeted feedback is possible.

Black-box model:
input → ??? → prediction

Concept Bottleneck Model:
input → concepts (interpretable) → prediction
        ↑
        Human can inspect and correct here

Think of a prediction task in your domain. What intermediate concepts would help a human verify the model’s reasoning?

  • Medical imaging: anatomical structures, lesion characteristics
  • Bird identification: wing color, beak shape, size
  • Loan approval: income stability, debt ratio, employment history

Part 2: The Core Idea - What Are Concept Bottleneck Models?

2.1 Definition

A Concept Bottleneck Model (CBM) is a neural network architecture where predictions must pass through an intermediate layer of human-interpretable concepts.

Instead of learning: \[ f: X \rightarrow Y \quad \text{(input directly to output)} \]

A CBM learns two functions: \[ g: X \rightarrow C \quad \text{(input to concepts)} \] \[ f: C \rightarrow Y \quad \text{(concepts to output)} \]

Where:

  • \(X\) is the input space (images, text, etc.)
  • \(C\) is a vector of interpretable concepts (e.g., “has yellow wings”, “is large”)
  • \(Y\) is the output space (class labels)

The key constraint: all information must flow through \(C\). The model cannot smuggle hidden information past the concepts.
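The factorization above can be sketched directly in PyTorch. This is a minimal illustration of the composition \(f(g(x))\), not the full model built later in this post; the layer sizes (8 features, 4 concepts, 3 classes) are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Minimal sketch of the CBM factorization y = f(g(x)).
g = nn.Sequential(nn.Linear(8, 4), nn.Sigmoid())  # g: X -> C, concept probabilities
f = nn.Linear(4, 3)                               # f: C -> Y, class logits

x = torch.randn(1, 8)  # one input with 8 features
c = g(x)               # concepts: the ONLY information passed forward
y = f(c)               # prediction computed from concepts alone

print(c.shape, y.shape)  # torch.Size([1, 4]) torch.Size([1, 3])
```

Because `f` is called only on `c`, any information that `g` discards is unavailable to the label predictor; that restriction is the bottleneck.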

2.2 What makes a concept “good”?

Not all intermediate representations are concepts. Good concepts must be:

Property     Meaning                        Example
-----------  -----------------------------  -------------------------------------
Meaningful   Expert understands it          “has yellow wings” ✓, “feature_47” ✗
Observable   Can be determined from input   “wing color” ✓, “bird’s mood” ✗
Predictive   Helps determine output         “beak shape” for bird species ✓

The fundamental difference: standard models hide their reasoning, while CBMs make intermediate concepts visible and correctable.
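The “Predictive” property can be checked empirically. Below is a rough sketch, with made-up data and a hypothetical `predictive_accuracy` helper, that measures how well a single binary concept predicts the label:

```python
import numpy as np

rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=200)            # binary class labels
good_concept = labels ^ (rng.random(200) < 0.1)  # tracks the label, 10% noise
bad_concept = rng.integers(0, 2, size=200)       # independent of the label

def predictive_accuracy(concept, labels):
    """Best accuracy achievable predicting the label from one binary concept."""
    acc = (concept == labels).mean()
    return max(acc, 1 - acc)  # allow the inverted mapping too

print(f"good concept: {predictive_accuracy(good_concept, labels):.2f}")  # high
print(f"bad concept:  {predictive_accuracy(bad_concept, labels):.2f}")   # near chance
```

A concept that scores near chance level carries no information about the output, so including it in the bottleneck adds annotation cost without predictive benefit.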

2.3 Code: Setting up our toy dataset

Let us create a simple dataset to understand CBMs concretely. We will simulate bird classification with 4 interpretable concepts.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)
np.random.seed(42)

def create_bird_dataset(n_samples=500):
    """
    Toy bird classification dataset.

    Concepts (4 binary attributes):
      c0: has_yellow_wings
      c1: has_long_beak
      c2: is_large
      c3: has_forked_tail

    Species (3 classes):
      0: Sparrow - small, short beak, no yellow
      1: Warbler - yellow wings, small
      2: Hawk - large, long beak, forked tail
    """
    species = np.random.choice([0, 1, 2], n_samples, p=[0.4, 0.35, 0.25])
    concepts = np.zeros((n_samples, 4))

    for i, s in enumerate(species):
        if s == 0:  # Sparrow
            concepts[i] = [np.random.random() < 0.1,  # rarely yellow
                          np.random.random() < 0.2,   # short beak
                          np.random.random() < 0.1,   # small
                          np.random.random() < 0.2]   # no forked tail
        elif s == 1:  # Warbler
            concepts[i] = [np.random.random() < 0.9,  # usually yellow
                          np.random.random() < 0.3,   # short beak
                          np.random.random() < 0.2,   # small
                          np.random.random() < 0.3]   # sometimes forked
        else:  # Hawk
            concepts[i] = [np.random.random() < 0.15, # rarely yellow
                          np.random.random() < 0.85,  # long beak
                          np.random.random() < 0.9,   # large
                          np.random.random() < 0.8]   # forked tail

    # Simulate image features from concepts
    W = np.random.randn(4, 8)
    X = concepts @ W + np.random.randn(n_samples, 8) * 0.5

    return torch.FloatTensor(X), torch.FloatTensor(concepts), torch.LongTensor(species)

# Create dataset
X, C, Y = create_bird_dataset(500)
concept_names = ['Yellow Wings', 'Long Beak', 'Large Size', 'Forked Tail']
species_names = ['Sparrow', 'Warbler', 'Hawk']

print(f"Dataset: {len(X)} samples")
print(f"Features: {X.shape[1]} dimensions")
print(f"Concepts: {C.shape[1]} ({concept_names})")
print(f"Classes: {len(species_names)} ({species_names})")
Dataset: 500 samples
Features: 8 dimensions
Concepts: 4 (['Yellow Wings', 'Long Beak', 'Large Size', 'Forked Tail'])
Classes: 3 (['Sparrow', 'Warbler', 'Hawk'])

Part 3: The Black Box Problem

3.1 Thinking Lens

Before building a CBM, let us first see what is wrong with standard neural networks. We will train a black-box model and show that it cannot explain its predictions.

3.2 Definition: Standard Neural Network

A standard classifier learns a direct mapping:

\[ \hat{y} = f_\theta(x) = \text{softmax}(\text{MLP}(x)) \]

The intermediate layers are not constrained to be interpretable. They encode whatever features help minimize the loss, regardless of human understanding.

3.3 Code: Train a black-box model

class BlackBoxModel(nn.Module):
    """Standard neural network: X -> hidden -> Y"""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(8, 16), nn.ReLU(),
            nn.Linear(16, 16), nn.ReLU(),
            nn.Linear(16, 3)
        )

    def forward(self, x):
        return self.net(x)

# Train
blackbox = BlackBoxModel()
optimizer = torch.optim.Adam(blackbox.parameters(), lr=0.01)

for epoch in range(100):
    optimizer.zero_grad()
    loss = F.cross_entropy(blackbox(X), Y)
    loss.backward()
    optimizer.step()

# Evaluate
with torch.no_grad():
    acc = (blackbox(X).argmax(1) == Y).float().mean()
print(f"Black-box accuracy: {acc:.1%}")
Black-box accuracy: 84.6%

3.4 The problem: No explanation

# Pick a sample and try to understand the prediction
idx = 42
sample_x = X[idx:idx+1]
true_concepts = C[idx]
true_species = Y[idx].item()

with torch.no_grad():
    logits = blackbox(sample_x)
    pred = logits.argmax().item()
    conf = F.softmax(logits, dim=1)[0, pred].item()

print(f"Sample {idx}:")
print(f"  True species: {species_names[true_species]}")
print(f"  Predicted: {species_names[pred]} ({conf:.1%} confidence)")
print(f"\n  True concepts: {dict(zip(concept_names, true_concepts.numpy().astype(int)))}")
print(f"\n  Model's reasoning: ??? (hidden in {sum(p.numel() for p in blackbox.parameters())} parameters)")
Sample 42:
  True species: Sparrow
  Predicted: Sparrow (96.2% confidence)

  True concepts: {'Yellow Wings': np.int64(0), 'Long Beak': np.int64(0), 'Large Size': np.int64(0), 'Forked Tail': np.int64(0)}

  Model's reasoning: ??? (hidden in 467 parameters)

The model makes a prediction but cannot tell us why. This is the black-box problem.

Part 4: The CBM Architecture

4.1 Thinking Lens

The key insight of CBMs is architectural: force all information to pass through a bottleneck of interpretable concepts. Think of it like a checkpoint where we can inspect what the model “sees” before it makes its final decision.

Standard:     Image → [hidden layers] → Prediction
                         ↓
                    (uninterpretable)

CBM:          Image → [concept predictor] → Concepts → [label predictor] → Prediction
                                               ↓
                                    Human can read these!

4.2 Definition: CBM Architecture

A Concept Bottleneck Model has two components:

1. Concept Predictor \(g_\theta: X \rightarrow [0,1]^k\)

Maps input to \(k\) concept probabilities: \[ \hat{c} = \sigma(W_c \cdot \text{encoder}(x) + b_c) \]

2. Label Predictor \(f_\phi: [0,1]^k \rightarrow \mathbb{R}^L\)

Maps concepts to class logits: \[ \hat{y} = W_y \cdot \hat{c} + b_y \]

The bottleneck constraint: \(f_\phi\) only receives \(\hat{c}\) as input. It cannot access \(x\) or any hidden features.

The CBM architecture forces all information through interpretable concepts before making the final prediction.

4.3 Code: Implement CBM

class ConceptBottleneckModel(nn.Module):
    """
    Concept Bottleneck Model: X -> Concepts -> Y
    """
    def __init__(self, input_dim=8, num_concepts=4, num_classes=3):
        super().__init__()

        # g(x): Input -> Concepts
        self.concept_predictor = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Linear(16, num_concepts),
            nn.Sigmoid()  # Concepts are probabilities
        )

        # f(c): Concepts -> Labels
        self.label_predictor = nn.Sequential(
            nn.Linear(num_concepts, 8),
            nn.ReLU(),
            nn.Linear(8, num_classes)
        )

    def forward(self, x):
        concepts = self.concept_predictor(x)   # Bottleneck!
        logits = self.label_predictor(concepts)
        return logits, concepts

    def predict_concepts(self, x):
        return self.concept_predictor(x)

# Create model
cbm = ConceptBottleneckModel()
print("CBM Architecture:")
print(f"  Concept predictor: 8 -> 16 -> 4 (with sigmoid)")
print(f"  Label predictor: 4 -> 8 -> 3")
print(f"  Bottleneck size: 4 concepts")
CBM Architecture:
  Concept predictor: 8 -> 16 -> 4 (with sigmoid)
  Label predictor: 4 -> 8 -> 3
  Bottleneck size: 4 concepts

4.4 Code: Trace the data flow

# See how data flows through the bottleneck
sample = X[0:1]

with torch.no_grad():
    # Step 1: Input -> Concepts
    concepts = cbm.predict_concepts(sample)
    print("Step 1: Input features -> Concept predictions")
    for i, name in enumerate(concept_names):
        print(f"  {name}: {concepts[0, i]:.2f}")

    # Step 2: Concepts -> Prediction
    logits, _ = cbm(sample)
    probs = F.softmax(logits, dim=1)[0]
    print(f"\nStep 2: Concepts -> Label prediction")
    for i, name in enumerate(species_names):
        print(f"  {name}: {probs[i]:.2f}")
Step 1: Input features -> Concept predictions
  Yellow Wings: 0.52
  Long Beak: 0.48
  Large Size: 0.48
  Forked Tail: 0.44

Step 2: Concepts -> Label prediction
  Sparrow: 0.48
  Warbler: 0.15
  Hawk: 0.37

Part 5: Training Strategies

5.1 Thinking Lens

How should we train the two components of a CBM? This is not as simple as training a single network. We have three choices, each with different trade-offs:

  1. Independent: Train each part separately
  2. Sequential: Train concept predictor first, then label predictor
  3. Joint: Train everything end-to-end together

Think of it like teaching a student:

  • Independent: Learn facts first, then learn to apply them (separately)
  • Sequential: Learn facts, then practice applying them
  • Joint: Learn both together, adjusting facts based on what helps applications

5.2 Definition: Three Training Strategies

Independent Training \[ \theta^* = \arg\min_\theta \mathcal{L}_{\text{concept}}(X, C; \theta) \] \[ \phi^* = \arg\min_\phi \mathcal{L}_{\text{task}}(C_{\text{true}}, Y; \phi) \]

The label predictor trains on ground truth concepts.

Sequential Training \[ \theta^* = \arg\min_\theta \mathcal{L}_{\text{concept}}(X, C; \theta) \] \[ \phi^* = \arg\min_\phi \mathcal{L}_{\text{task}}(\hat{C}, Y; \phi) \]

The label predictor trains on predicted concepts.

Joint Training \[ \theta^*, \phi^* = \arg\min_{\theta,\phi} \mathcal{L}_{\text{task}} + \lambda \mathcal{L}_{\text{concept}} \]

Both components optimize together.

Three training strategies with different trade-offs between concept accuracy and task performance.

5.3 Code: Implement all three strategies

# Split data
n_train = 400
X_train, X_test = X[:n_train], X[n_train:]
C_train, C_test = C[:n_train], C[n_train:]
Y_train, Y_test = Y[:n_train], Y[n_train:]

def train_independent(epochs=100):
    """Train concept and label predictors separately."""
    model = ConceptBottleneckModel()

    # Train concept predictor
    opt = torch.optim.Adam(model.concept_predictor.parameters(), lr=0.01)
    for _ in range(epochs):
        opt.zero_grad()
        c_pred = model.predict_concepts(X_train)
        loss = F.binary_cross_entropy(c_pred, C_train)
        loss.backward()
        opt.step()

    # Train label predictor on TRUE concepts
    opt = torch.optim.Adam(model.label_predictor.parameters(), lr=0.01)
    for _ in range(epochs):
        opt.zero_grad()
        logits = model.label_predictor(C_train)  # Ground truth!
        loss = F.cross_entropy(logits, Y_train)
        loss.backward()
        opt.step()

    return model

def train_sequential(epochs=100):
    """Train concept predictor first, then label predictor on predictions."""
    model = ConceptBottleneckModel()

    # Train concept predictor
    opt = torch.optim.Adam(model.concept_predictor.parameters(), lr=0.01)
    for _ in range(epochs):
        opt.zero_grad()
        c_pred = model.predict_concepts(X_train)
        loss = F.binary_cross_entropy(c_pred, C_train)
        loss.backward()
        opt.step()

    # Train label predictor on PREDICTED concepts
    opt = torch.optim.Adam(model.label_predictor.parameters(), lr=0.01)
    for _ in range(epochs):
        opt.zero_grad()
        with torch.no_grad():
            c_pred = model.predict_concepts(X_train)
        logits = model.label_predictor(c_pred)  # Predicted!
        loss = F.cross_entropy(logits, Y_train)
        loss.backward()
        opt.step()

    return model

def train_joint(epochs=100, lambd=1.0):
    """Train both components end-to-end."""
    model = ConceptBottleneckModel()
    opt = torch.optim.Adam(model.parameters(), lr=0.01)

    for _ in range(epochs):
        opt.zero_grad()
        logits, c_pred = model(X_train)

        concept_loss = F.binary_cross_entropy(c_pred, C_train)
        task_loss = F.cross_entropy(logits, Y_train)
        total_loss = task_loss + lambd * concept_loss

        total_loss.backward()
        opt.step()

    return model

# Train all three
cbm_ind = train_independent()
cbm_seq = train_sequential()
cbm_joint = train_joint()
print("Trained all three strategies!")
Trained all three strategies!

5.4 Code: Compare strategies

def evaluate(model, X, C, Y):
    with torch.no_grad():
        logits, c_pred = model(X)
        task_acc = (logits.argmax(1) == Y).float().mean().item()
        concept_acc = ((c_pred > 0.5) == C).float().mean().item()
    return task_acc, concept_acc

print("=" * 55)
print(f"{'Strategy':<15} {'Task Accuracy':>15} {'Concept Accuracy':>18}")
print("=" * 55)

for name, model in [('Independent', cbm_ind),
                    ('Sequential', cbm_seq),
                    ('Joint', cbm_joint)]:
    task_acc, concept_acc = evaluate(model, X_test, C_test, Y_test)
    print(f"{name:<15} {task_acc:>14.1%} {concept_acc:>17.1%}")

print("=" * 55)
=======================================================
Strategy          Task Accuracy   Concept Accuracy
=======================================================
Independent              78.0%             97.3%
Sequential               77.0%             97.8%
Joint                    78.0%             97.8%
=======================================================
  • Independent has high concept accuracy but may struggle at test time (label predictor never saw noisy concept predictions)
  • Sequential handles prediction noise better
  • Joint typically achieves the best task accuracy by allowing trade-offs

Part 6: The Intervention Mechanism

6.1 Thinking Lens

This is the killer feature of CBMs. Because concepts are interpretable, humans can correct them at test time!

Imagine a doctor reviewing the AI’s concept predictions:

AI predicts: Malignant (75%)
  - Asymmetric shape: YES (0.92)
  - Irregular borders: YES (0.78)  <- Doctor: "No, borders are smooth"
  - Multiple colors: YES (0.85)
  - Large diameter: NO (0.23)

The doctor corrects “irregular borders” to NO. The model re-computes:

After correction: Benign (68%)

This is human-AI collaboration. The model provides its best guesses, humans fix errors where they have expertise, and the combined system is better than either alone.

6.2 Definition: Intervention

An intervention replaces predicted concepts with ground truth values.

Let \(\hat{c} = g_\theta(x)\) be predicted concepts. Given an intervention set \(\mathcal{I}\) (indices to correct):

\[ \tilde{c}_j = \begin{cases} c_j^{\text{true}} & \text{if } j \in \mathcal{I} \\ \hat{c}_j & \text{otherwise} \end{cases} \]

The corrected prediction is: \[ \tilde{y} = f_\phi(\tilde{c}) \]

Humans can inspect concept predictions and correct errors, improving the final prediction.
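The piecewise rule above translates almost line-for-line into code. A minimal sketch, with made-up tensors standing in for a trained model’s outputs:

```python
import torch

def intervene(c_hat, c_true, intervention_set):
    """Replace predicted concepts at the given indices with ground truth."""
    c_tilde = c_hat.clone()
    idx = sorted(intervention_set)
    c_tilde[idx] = c_true[idx]  # corrected entries; the rest keep predictions
    return c_tilde

c_hat = torch.tensor([0.9, 0.2, 0.7, 0.4])   # predicted concept probabilities
c_true = torch.tensor([1.0, 1.0, 0.0, 0.0])  # ground-truth concept values

c_tilde = intervene(c_hat, c_true, {1, 2})   # expert corrects concepts 1 and 2
print(c_tilde)  # tensor([0.9000, 1.0000, 0.0000, 0.4000])
```

Feeding `c_tilde` to the label predictor yields the corrected prediction \(\tilde{y}\).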

6.3 Code: Perform an intervention

def inspect_and_intervene(model, x, c_true, y_true):
    """Show model reasoning and demonstrate intervention."""
    with torch.no_grad():
        logits, c_pred = model(x)
        pred = logits.argmax().item()

    print("MODEL REASONING")
    print("=" * 50)
    print(f"True species: {species_names[y_true]}")
    print(f"Predicted: {species_names[pred]}")
    print("\nConcept predictions:")

    wrong_concepts = []
    for i, name in enumerate(concept_names):
        true_val = c_true[i].item()
        pred_val = c_pred[0, i].item()
        correct = (pred_val > 0.5) == (true_val > 0.5)
        status = "✓" if correct else "✗ WRONG"
        if not correct:
            wrong_concepts.append(i)
        print(f"  {name:<15}: {pred_val:.2f} (true: {int(true_val)}) {status}")

    if wrong_concepts and pred != y_true:
        print(f"\n🔧 INTERVENTION: Correcting {len(wrong_concepts)} concept(s)...")

        # Correct the concepts
        c_corrected = c_pred.clone()
        for idx in wrong_concepts:
            c_corrected[0, idx] = c_true[idx]

        # Get new prediction
        with torch.no_grad():
            new_logits = model.label_predictor(c_corrected)
            new_pred = new_logits.argmax().item()

        print(f"  Before: {species_names[pred]}")
        print(f"  After:  {species_names[new_pred]}")
        if new_pred == y_true:
            print("  ✅ Intervention fixed the prediction!")

# Find a sample where intervention helps
for idx in range(len(X_test)):
    x, c, y = X_test[idx:idx+1], C_test[idx], Y_test[idx].item()
    with torch.no_grad():
        pred = cbm_joint(x)[0].argmax().item()
    if pred != y:
        inspect_and_intervene(cbm_joint, x, c, y)
        break
MODEL REASONING
==================================================
True species: Hawk
Predicted: Sparrow

Concept predictions:
  Yellow Wings   : 0.04 (true: 0) ✓
  Long Beak      : 0.14 (true: 1) ✗ WRONG
  Large Size     : 0.00 (true: 0) ✓
  Forked Tail    : 1.00 (true: 1) ✓

🔧 INTERVENTION: Correcting 1 concept(s)...
  Before: Sparrow
  After:  Hawk
  ✅ Intervention fixed the prediction!

6.4 Code: Intervention-accuracy curve

How much does accuracy improve as we correct more concepts?

def intervention_curve(model, X, C, Y):
    """Compute accuracy at different intervention levels."""
    n_concepts = C.shape[1]
    accuracies = []

    for n_correct in range(n_concepts + 1):
        correct = 0
        for i in range(len(X)):
            with torch.no_grad():
                _, c_pred = model(X[i:i+1])

                # Correct first n concepts
                c_use = c_pred[0].clone()
                for j in range(n_correct):
                    c_use[j] = C[i, j]

                logits = model.label_predictor(c_use.unsqueeze(0))
                pred = logits.argmax().item()

            if pred == Y[i].item():
                correct += 1

        accuracies.append(correct / len(X))

    return accuracies

# Plot
fig, ax = plt.subplots(figsize=(9, 5))

for name, model, color in [('Independent', cbm_ind, '#FF6B6B'),
                            ('Sequential', cbm_seq, '#4ECDC4'),
                            ('Joint', cbm_joint, '#45B7D1')]:
    accs = intervention_curve(model, X_test, C_test, Y_test)
    ax.plot(range(5), accs, 'o-', label=name, color=color, linewidth=2, markersize=8)

ax.set_xlabel('Number of Concepts Corrected', fontsize=12)
ax.set_ylabel('Task Accuracy', fontsize=12)
ax.set_title('How Much Do Human Corrections Help?', fontsize=14)
ax.set_xticks(range(5))
ax.set_xticklabels(['0\n(no help)'] + [str(i) for i in range(1, 5)])
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1.05)
plt.tight_layout()
plt.show()

Key Insight

As we correct more concepts, accuracy improves toward 100%. This shows the value of human-AI collaboration with CBMs.

Part 7: Putting It All Together

7.1 Thinking Lens

Let us now compare black-box and CBM models side-by-side on the same sample. This will clearly show the interpretability advantage.

7.2 Code: Side-by-side comparison

# Pick a test sample
idx = 10
x = X_test[idx:idx+1]
c_true = C_test[idx]
y_true = Y_test[idx].item()

print("=" * 60)
print("COMPARISON: Black-Box vs Concept Bottleneck Model")
print("=" * 60)
print(f"\nTrue species: {species_names[y_true]}")
print(f"True concepts: {dict(zip(concept_names, c_true.numpy().astype(int)))}")

# Black-box
with torch.no_grad():
    bb_logits = blackbox(x)
    bb_pred = bb_logits.argmax().item()
    bb_conf = F.softmax(bb_logits, dim=1)[0, bb_pred].item()

print(f"\n🔲 BLACK-BOX MODEL")
print(f"   Prediction: {species_names[bb_pred]} ({bb_conf:.1%})")
print(f"   Reasoning: ??? (uninterpretable)")

# CBM
with torch.no_grad():
    cbm_logits, c_pred = cbm_joint(x)
    cbm_pred = cbm_logits.argmax().item()
    cbm_conf = F.softmax(cbm_logits, dim=1)[0, cbm_pred].item()

print(f"\n🔍 CONCEPT BOTTLENECK MODEL")
print(f"   Prediction: {species_names[cbm_pred]} ({cbm_conf:.1%})")
print(f"   Reasoning:")
for i, name in enumerate(concept_names):
    val = c_pred[0, i].item()
    bar = "█" * int(val * 10)
    print(f"      {name:<15}: {val:.2f} {bar}")
============================================================
COMPARISON: Black-Box vs Concept Bottleneck Model
============================================================

True species: Sparrow
True concepts: {'Yellow Wings': np.int64(0), 'Long Beak': np.int64(0), 'Large Size': np.int64(0), 'Forked Tail': np.int64(0)}

🔲 BLACK-BOX MODEL
   Prediction: Sparrow (86.1%)
   Reasoning: ??? (uninterpretable)

🔍 CONCEPT BOTTLENECK MODEL
   Prediction: Sparrow (83.4%)
   Reasoning:
      Yellow Wings   : 0.30 ██
      Long Beak      : 0.01 
      Large Size     : 0.25 ██
      Forked Tail    : 0.00 

7.3 Definition: Summary

Aspect            Black-Box          Concept Bottleneck Model
----------------  -----------------  ---------------------------------
Architecture      X → hidden → Y     X → concepts → Y
Interpretability  None               Full (inspect concepts)
Intervention      Not possible       Correct concepts at test time
Requirements      (X, Y) pairs       (X, C, Y) triples
Accuracy          Often highest      Competitive (with joint training)

CBMs achieve competitive accuracy while maintaining full interpretability.

Part 8: Real-World Applications

8.1 Thinking Lens

Where are CBMs most valuable? In domains where:

  1. Mistakes are costly - healthcare, legal, financial decisions
  2. Trust is essential - users need to verify AI reasoning
  3. Experts can help - domain knowledge can correct errors
  4. Regulation requires explanations - GDPR, medical device approvals

8.2 Definition: Application Domains

Healthcare: Predict diagnoses through medical concepts (symptoms, biomarkers, imaging findings). Doctors verify concepts before accepting diagnoses.

Autonomous Systems: Predict actions through perception concepts (object detected, road conditions, traffic signals). Operators can verify scene understanding.

Scientific Discovery: Encode known scientific concepts and reveal which ones the model relies on for predictions.

In medical imaging, CBMs enable safe human-AI collaboration by making intermediate reasoning inspectable.

8.3 Code: Complete CBM pipeline

class CompleteCBM:
    """Production-ready Concept Bottleneck Model."""

    def __init__(self, input_dim, num_concepts, num_classes, concept_names):
        self.model = ConceptBottleneckModel(input_dim, num_concepts, num_classes)
        self.concept_names = concept_names

    def train(self, X, C, Y, epochs=100, strategy='joint'):
        if strategy != 'joint':
            raise NotImplementedError("Only 'joint' training is implemented in this demo")
        opt = torch.optim.Adam(self.model.parameters(), lr=0.01)
        for _ in range(epochs):
            opt.zero_grad()
            logits, c_pred = self.model(X)
            loss = F.cross_entropy(logits, Y) + F.binary_cross_entropy(c_pred, C)
            loss.backward()
            opt.step()

    def predict(self, x):
        with torch.no_grad():
            logits, concepts = self.model(x)
            return logits.argmax(1), concepts

    def explain(self, x, class_names):
        pred, concepts = self.predict(x)

        lines = [f"Prediction: {class_names[pred[0]]}"]
        lines.append("Reasoning:")
        for i, name in enumerate(self.concept_names):
            val = concepts[0, i].item()
            status = "YES" if val > 0.5 else "NO"
            lines.append(f"  • {name}: {status} ({val:.0%})")

        return "\n".join(lines)

    def intervene(self, x, corrections):
        """corrections: dict of {concept_idx: new_value}"""
        with torch.no_grad():
            _, concepts = self.model(x)
            for idx, val in corrections.items():
                concepts[0, idx] = val
            logits = self.model.label_predictor(concepts)
            return logits.argmax(1)

# Demo
complete_cbm = CompleteCBM(8, 4, 3, concept_names)
complete_cbm.train(X_train, C_train, Y_train)

sample = X_test[0:1]
print(complete_cbm.explain(sample, species_names))
Prediction: Sparrow
Reasoning:
  • Yellow Wings: NO (7%)
  • Long Beak: NO (4%)
  • Large Size: NO (5%)
  • Forked Tail: NO (0%)

Part 9: Summary

Key Takeaways

What We Learned
  1. The Problem: Black-box models make predictions but cannot explain their reasoning

  2. The Solution: Force predictions through interpretable concepts (the “bottleneck”)

  3. The Architecture:

    • Concept predictor: \(X \rightarrow C\)
    • Label predictor: \(C \rightarrow Y\)
    • All information flows through \(C\)
  4. Training Options:

    • Independent: highest concept accuracy
    • Sequential: handles prediction noise
    • Joint: best task accuracy
  5. The Killer Feature: Humans can correct concept predictions at test time, improving accuracy through collaboration

  6. When to Use CBMs: High-stakes decisions where trust, transparency, and human oversight matter

References

  1. Koh, P. W., et al. (2020). Concept Bottleneck Models. ICML.

  2. Kim, B., et al. (2018). Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV). ICML.

  3. Zarlenga, M. E., et al. (2022). Concept Embedding Models. NeurIPS.