A Vision Transformer (ViT) implementation for pneumonia detection in chest X-ray images with attention visualization capabilities.
Model: ViT-Small trained on 5,856 chest X-ray images
- Accuracy: 90.38%
- F1 Score: 92.02%
- AUC-ROC: 96.78%
- Sensitivity: 88.72%
- Specificity: 93.16%
- Total Test Samples: 624
- Correct Predictions: 572
- Incorrect Predictions: 52
- Error Rate: 0.0833
- False Positives: 26
- False Negatives: 26
- Average Inference Time: 0.0549 seconds per image
- Optimization Metric: f1
- Best Threshold: 0.150 (selection sweep sketched below)
- Best Score: 0.9333
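The 0.150 decision threshold above comes from sweeping candidate thresholds and keeping the one that maximizes F1. A minimal sketch of such a sweep; the function and variable names are illustrative, not the project's actual evaluator code:

```python
import numpy as np
from sklearn.metrics import f1_score

def find_best_threshold(y_true, y_prob, metric=f1_score, steps=201):
    """Sweep candidate thresholds and keep the one maximizing the metric.

    y_true: ground-truth labels as a NumPy array (0 = NORMAL, 1 = PNEUMONIA)
    y_prob: predicted probability of the PNEUMONIA class
    """
    best_t, best_score = 0.5, -1.0
    for t in np.linspace(0.0, 1.0, steps):
        score = metric(y_true, (y_prob >= t).astype(int))
        if score > best_score:
            best_t, best_score = t, score
    return best_t, best_score

# e.g. best_t, best_score = find_best_threshold(labels, probs)
```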
Evaluation plots: ROC Curve | Confusion Matrix | Precision-Recall Curve | Threshold Analysis
- Vision Transformer Models: ViT-Small, ViT-Base, and ViT-Large configurations (see the model-building sketch below)
- Attention Visualization: Multi-head attention maps and attention rollout
- Medical AI Optimizations: Class balancing and medical-specific augmentations
- Real-time Inference: Optimized for clinical deployment
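The three model sizes are standard ViT variants; a hedged sketch of how they might be instantiated with `timm` (which the project depends on). The name mapping and the `build_model` helper below are illustrative assumptions; `src/model.py` may wrap the backbone differently:

```python
import timm
import torch.nn as nn

# Mapping from --model_size to a timm ViT checkpoint (illustrative names)
VIT_VARIANTS = {
    "small": "vit_small_patch16_224",   # ~22M parameters
    "base":  "vit_base_patch16_224",    # ~86M parameters
    "large": "vit_large_patch16_224",   # ~307M parameters
}

def build_model(model_size: str = "small", num_classes: int = 2) -> nn.Module:
    """Create a pretrained ViT and size its head for binary classification."""
    return timm.create_model(VIT_VARIANTS[model_size],
                             pretrained=True,
                             num_classes=num_classes)
```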
```bash
python -m venv vit_env
vit_env\Scripts\activate    # Windows
pip install -r requirements.txt
```

This project uses the Chest X-Ray Images (Pneumonia) dataset for binary classification:
- Dataset Source: Kaggle - Chest X-Ray Images (Pneumonia)
- Total Images: 5,856 chest X-ray images
- Classes: NORMAL and PNEUMONIA
- Image Format: JPEG files
- Original Split: Pre-divided into train/validation/test sets
Option 1: Download from Kaggle
- Visit: https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia
- Click "Download" (requires Kaggle account)
- Extract the zip file to your project directory
Option 2: Using Kaggle API
```bash
# Install Kaggle API
pip install kaggle

# Download dataset (requires kaggle.json in ~/.kaggle/)
kaggle datasets download -d paultimothymooney/chest-xray-pneumonia

# Extract the dataset
unzip chest-xray-pneumonia.zip
```

After downloading, organize your dataset as follows:
```
data/chest_xray_pneumonia/
├── chest_xray/
│   ├── train/
│   │   ├── NORMAL/      # 1,341 normal X-ray images
│   │   └── PNEUMONIA/   # 3,875 pneumonia X-ray images
│   ├── val/
│   │   ├── NORMAL/      # 8 normal X-ray images
│   │   └── PNEUMONIA/   # 8 pneumonia X-ray images
│   └── test/
│       ├── NORMAL/      # 234 normal X-ray images
│       └── PNEUMONIA/   # 390 pneumonia X-ray images
```
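Because each split uses the standard class-per-folder layout, it can be loaded directly with `torchvision.datasets.ImageFolder`. A minimal loading sketch, assuming 224x224 inputs and ImageNet normalization; the project's actual transforms and augmentations may differ:

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Resize to the ViT input resolution and normalize with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

root = "data/chest_xray_pneumonia/chest_xray"
train_set = datasets.ImageFolder(f"{root}/train", transform)
test_set = datasets.ImageFolder(f"{root}/test", transform)

train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False, num_workers=4)
```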
Important Notes:
- The provided validation set is very small (16 images in total), so the training code may rely on an internal train/test split instead
- Images are in JPEG format with varying resolutions
- The dataset is imbalanced (more pneumonia cases than normal); see the class-weighting sketch after this list
- All images are pediatric chest X-rays (ages 1-5 years)
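A common way to compensate for this imbalance (the "class balancing" mentioned under features) is to weight the loss by inverse class frequency. A sketch under that assumption; the exact scheme in `src/trainer.py` may differ:

```python
import torch
import torch.nn as nn

# Training-set counts from the directory listing above: [NORMAL, PNEUMONIA]
class_counts = torch.tensor([1341.0, 3875.0])

# Inverse-frequency weights, normalized so the two weights average to 1
weights = class_counts.sum() / (len(class_counts) * class_counts)

criterion = nn.CrossEntropyLoss(weight=weights)
# loss = criterion(logits, targets)  # logits: (batch, 2), targets: (batch,)
```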
```bash
# Quick training (2 epochs)
python main.py --epochs 2 --batch_size 16

# With attention visualization
python main.py --epochs 2 --generate_attention

# Different model sizes
python main.py --model_size small --epochs 2 --batch_size 32
```

```bash
# Evaluate existing model
python main.py --eval_only --checkpoint checkpoints/best_checkpoint.pth

# Generate attention maps
python main.py --eval_only --checkpoint checkpoints/best_checkpoint.pth --generate_attention
```

| Model | Parameters | Batch Size | Training Time |
|---|---|---|---|
| ViT-Small | 22M | 32 | Fast |
| ViT-Base | 86M | 16 | Medium |
| ViT-Large | 307M | 8 | Slow |
```
pneumonia_vit/
├── src/                             # Source code
│   ├── model.py                     # ViT model implementation
│   ├── trainer.py                   # Training logic
│   ├── evaluator.py                 # Evaluation and metrics
│   └── attention_visualization.py
├── configs/                         # Model configurations
├── main.py                          # Main script
└── requirements.txt                 # Dependencies
```
```yaml
training:
  learning_rate: 3e-4
  batch_size: 16
  num_epochs: 50      # Use 2 for quick testing
  warmup_epochs: 5
  weight_decay: 0.3
```

The model provides interpretable attention maps showing:
- Attention Rollout: How the model focuses on lung regions (computation sketched below)
- Multi-Head Attention: Different attention patterns across heads
- Layer-wise Evolution: Attention development through the network
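Attention rollout combines the per-layer attention matrices into a single input-attribution map: heads are averaged, the residual connection is added as an identity term, rows are re-normalized, and the layers are multiplied together. A simplified sketch of that computation; the implementation in `src/attention_visualization.py` may differ in detail:

```python
import torch

def attention_rollout(attentions):
    """Fuse per-layer attention into one map of where the [CLS] token looks.

    attentions: list of tensors shaped (num_heads, tokens, tokens),
                one entry per transformer layer.
    """
    tokens = attentions[0].shape[-1]
    rollout = torch.eye(tokens)
    for attn in attentions:
        attn = attn.mean(dim=0)                       # average over heads
        attn = attn + torch.eye(tokens)               # add the residual path
        attn = attn / attn.sum(dim=-1, keepdim=True)  # re-normalize rows
        rollout = attn @ rollout                      # accumulate across layers
    # Row 0 ([CLS]) attending to the patch tokens gives the spatial heatmap
    return rollout[0, 1:]
```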
Example outputs: attention rollout maps (model focus on pneumonia-affected regions) and multi-head attention maps (diverse attention patterns across heads).
- Explainable AI: Attention maps for radiologist confidence
- Real-time Processing: Fast inference for clinical workflows (see the inference sketch below)
- Global Context: Captures relationships between distant anatomical regions
- Robust Performance: Less sensitive to local artifacts
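For a concrete picture of the inference path, here is a hedged single-image prediction sketch. The 0.15 threshold mirrors the tuned value reported above; the preprocessing, device handling, and grayscale-to-RGB conversion are assumptions, and `main.py` remains the supported entry point:

```python
import time

import torch
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),  # X-rays are single-channel
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

@torch.no_grad()
def predict(model, image_path, threshold=0.15, device="cpu"):
    """Return the predicted label, PNEUMONIA probability, and latency."""
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    start = time.time()
    prob = torch.softmax(model(image), dim=-1)[0, 1].item()
    elapsed = time.time() - start
    label = "PNEUMONIA" if prob >= threshold else "NORMAL"
    return label, prob, elapsed
```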
```bash
# Recommended training test
python main.py --model_size base --epochs 10 --batch_size 16

# Fast training test
python main.py --model_size small --epochs 2 --batch_size 32

# Full training with visualization
python main.py --epochs 10 --generate_attention

# Custom configuration
python main.py --config configs/vit_base_config.yaml

# Evaluation only
python main.py --eval_only --checkpoint path/to/checkpoint.pth
```

CUDA Out of Memory: Reduce the batch size or use ViT-Small

```bash
python main.py --model_size small --batch_size 8
```

Missing Dependencies: Install the required packages

```bash
pip install "timm>=0.9.0" matplotlib seaborn
```

This implementation is part of a Master of Computer Science/Cybersecurity capstone project, demonstrating advanced machine learning techniques for medical AI applications.
Note: This system is for research and educational purposes only. Not intended for actual clinical diagnosis without proper validation and medical oversight.