This repository contains the code and description for AlterNet-LC, a deep learning model for pneumonia detection from chest X-ray images. The primary codebase is provided in a Jupyter Notebook (`Code & Description.ipynb`).
The notebook includes:
- Complete training code for AlterNet-LC, encompassing data preprocessing, model definition, training, validation, and testing procedures.
Authors: Li Jiawei, Chen Mingfang, Yao Zehan
The code in this notebook is for reference only. It has not undergone strict data leakage prevention or logic optimization and should not be directly used in production environments. Results may vary due to device differences. Default configurations are provided for reference; adjust and debug according to your actual setup.
- Primary Training/Validation/Testing Dataset: PneumoniaMNIST (part of MedMNIST).
  - The notebook expects this dataset to be available as `pneumoniamnist_224.npz`.
- Generalization Testing Dataset: Chest X-Ray Images (Pneumonia) from Kaggle (`paultimothymooney/chest-xray-pneumonia`).
- Clone the repository (or download the notebook).
- Install dependencies:
  - MedMNIST: For accessing the PneumoniaMNIST dataset.
    ```bash
    pip install medmnist
    ```
  - Kaggle Hub: For downloading the Kaggle dataset.
    ```bash
    pip install kagglehub
    ```
  - Core Libraries: The notebook uses `numpy`, `torch`, `torchvision`, `scikit-learn`, `matplotlib`, `seaborn`, and `tqdm`. Install them as needed (e.g., via `pip` or `conda`).
    ```bash
    pip install numpy torch torchvision scikit-learn matplotlib seaborn tqdm
    ```
- Download Datasets:
  - PneumoniaMNIST: The notebook can download this dataset if it's not found locally, using the `medmnist` library. The downloaded data will be processed into `pneumoniamnist_224.npz` or used directly.
    ```python
    # Example snippet from the notebook for downloading
    import medmnist
    from medmnist import INFO
    import torchvision.transforms as transforms

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5])
    ])

    data_flag = 'pneumoniamnist'
    info = INFO[data_flag]
    DataClass = getattr(medmnist, info['python_class'])

    # This will download the data if not present
    train_dataset = DataClass(split='train', transform=data_transform, download=True, size=224, mmap_mode='r')
    ```
    Note: Downloading MedMNIST datasets may require a VPN connection in some regions. The notebook is configured to load data from `pneumoniamnist_224.npz`.
  - Kaggle Chest X-Ray Dataset (for generalization testing): The notebook provides instructions to download this using `kagglehub`.
    ```python
    import kagglehub

    # May require login for the first time
    # kagglehub.login()
    path = kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")
    print("Path to dataset files:", path)
    ```
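Before running the notebook, it can help to confirm that the `.npz` archive contains the arrays you expect. The following is a minimal sketch, not code from the notebook: `summarize_npz` is a hypothetical helper, and the demo archive is synthetic, built only to mimic the usual MedMNIST key layout (`train_images`, `train_labels`, and so on), which is assumed here.

```python
import os
import tempfile

import numpy as np


def summarize_npz(path):
    """Return {array_name: shape} for every array stored in an .npz archive."""
    with np.load(path) as archive:
        return {name: archive[name].shape for name in archive.files}


# Build a tiny synthetic archive that mimics the assumed MedMNIST key layout.
# These arrays are placeholders, not real data.
demo_path = os.path.join(tempfile.mkdtemp(), "demo.npz")
arrays = {}
for split, count in (("train", 4), ("val", 2), ("test", 2)):
    arrays[f"{split}_images"] = np.zeros((count, 224, 224), dtype=np.uint8)
    arrays[f"{split}_labels"] = np.zeros((count, 1), dtype=np.uint8)
np.savez(demo_path, **arrays)

summary = summarize_npz(demo_path)
for name, shape in summary.items():
    print(name, shape)
```

To inspect the real dataset, call `summarize_npz("pneumoniamnist_224.npz")` with the path you use as `data_path`.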
AlterNet-LC is a hybrid model integrating convolutional neural networks (CNNs) with self-attention mechanisms. Key components include:
- Stem: Initial convolutional layers (`nn.Conv2d`, `nn.BatchNorm2d`, `nn.ReLU`, `nn.MaxPool2d`) for low-level feature extraction.
- Pre-activated Residual Blocks (`PreActResidualBlock`): Used for building deeper CNN stages, featuring batch normalization and ReLU activation before convolution.
- WindowAttention: A utility to partition feature maps into non-overlapping windows for localized self-attention, and to reverse this process.
- Multiple Self-attention Blocks (`MSABlock`): These blocks perform self-attention within windows. They include:
  - Layer normalization (`nn.LayerNorm`).
  - Multi-head self-attention (QKV computation, scaled dot-product attention).
  - MLP layers with GELU activation and dropout.
  - Residual connections.
- Contrastive Self-attention Blocks (`ContrastiveMSABlock`): An extension of `MSABlock` that incorporates a contrastive learning head.
  - A projection head maps features to a space for contrastive learning.
  - Calculates a contrastive loss based on feature similarity and class labels during training.
- Hybrid Stages: The model architecture consists of an initial pure CNN stage, followed by stages that alternate `PreActResidualBlock`s with `ContrastiveMSABlock`s. This allows the model to learn both local and global features, enhanced by contrastive learning.
- Classification Head: Global average pooling (`nn.AdaptiveAvgPool2d`), dropout, and a final fully connected layer (`nn.Linear`) for classification.
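The window partitioning and its reverse, as described for WindowAttention, can be illustrated with plain array reshapes. This is a sketch in NumPy for readability, not the notebook's PyTorch implementation; the function names `window_partition` and `window_reverse` and the channels-last layout `(B, H, W, C)` are assumptions.

```python
import numpy as np


def window_partition(x, window_size):
    """Split (B, H, W, C) feature maps into (B * num_windows, ws, ws, C) windows."""
    b, h, w, c = x.shape
    assert h % window_size == 0 and w % window_size == 0, "H and W must be divisible by window_size"
    x = x.reshape(b, h // window_size, window_size, w // window_size, window_size, c)
    # Move the two window-grid axes next to each other, then fold them into the batch axis.
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, c)


def window_reverse(windows, window_size, h, w):
    """Inverse of window_partition: rebuild (B, H, W, C) from the windows."""
    b = windows.shape[0] // ((h // window_size) * (w // window_size))
    x = windows.reshape(b, h // window_size, w // window_size, window_size, window_size, -1)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(b, h, w, -1)


features = np.arange(2 * 8 * 8 * 3, dtype=np.float32).reshape(2, 8, 8, 3)
windows = window_partition(features, window_size=4)
restored = window_reverse(windows, window_size=4, h=8, w=8)

print(windows.shape)                       # (8, 4, 4, 3)
print(np.array_equal(features, restored))  # True
```

Self-attention is then computed independently inside each window, which keeps the cost proportional to the window area rather than the full feature map.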
The model is trained using a custom loss function named LMFLoss. This loss combines:
- Focal Loss: To address class imbalance by down-weighting the loss assigned to well-classified examples.
- Margin Loss: To enforce a margin between class probabilities, enhancing separability.
The LMFLoss also incorporates class weights and specific parameters (`gamma`, `margin`, `alpha`) to fine-tune its behavior. An additional weighting is applied for misclassified negative samples.
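To make the idea of combining a focal term with a margin term concrete, here is a rough NumPy sketch. It is not the notebook's `LMFLoss`: the hinge-style margin on class probabilities, the use of `alpha` as a mixing weight, and the helper names `focal_term`, `margin_term`, and `combined_loss` are all assumptions chosen for illustration, and the extra weighting for misclassified negative samples is left out.

```python
import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def focal_term(probs, targets, gamma=2.0, class_weights=None):
    """Per-sample focal loss: cross-entropy scaled by (1 - p_t) ** gamma."""
    p_t = probs[np.arange(len(targets)), targets]
    weights = 1.0 if class_weights is None else np.asarray(class_weights)[targets]
    return -weights * (1.0 - p_t) ** gamma * np.log(p_t + 1e-12)


def margin_term(probs, targets, margin=0.5):
    """Assumed hinge penalty: true-class probability should beat the best other class by `margin`."""
    idx = np.arange(len(targets))
    p_t = probs[idx, targets]
    others = probs.copy()
    others[idx, targets] = -np.inf  # exclude the true class from the max
    return np.maximum(0.0, margin - (p_t - others.max(axis=1)))


def combined_loss(logits, targets, gamma=2.0, margin=0.5, alpha=0.5, class_weights=None):
    """Assumed combination: alpha * focal + (1 - alpha) * margin, averaged over the batch."""
    probs = softmax(logits)
    focal = focal_term(probs, targets, gamma, class_weights)
    hinge = margin_term(probs, targets, margin)
    return float(np.mean(alpha * focal + (1.0 - alpha) * hinge))


easy = combined_loss(np.array([[6.0, -6.0]]), np.array([0]))  # confident and correct
hard = combined_loss(np.array([[-1.0, 1.0]]), np.array([0]))  # confident and wrong
print(easy < hard)  # True
```

With `gamma=0` and no class weights, `focal_term` reduces to ordinary cross-entropy, which is a quick sanity check for any focal loss implementation.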
The Jupyter Notebook is structured to provide a comprehensive workflow:
- Data Loading and Preprocessing (`load_and_preprocess_data`):
  - Loads data from the `pneumoniamnist_224.npz` file.
  - Ensures images are 3-channel (RGB).
  - Applies data augmentation to the training set (RandomHorizontalFlip, RandomRotation, ColorJitter).
  - Normalizes images and converts them to PyTorch tensors.
  - Creates `DataLoader` instances for training, validation, and testing.
- Custom Dataset Class (`MedicalDataset`):
  - A `torch.utils.data.Dataset` subclass to handle medical images and their labels, applying transformations.
- Model Training (`train_model`):
  - Implements the training and validation loop.
  - Uses the AdamW optimizer and ReduceLROnPlateau learning rate scheduler.
  - Calculates the combined loss: `LMF Loss + contrastive_weight * contrastive_loss`.
  - Tracks and logs training/validation metrics (loss, accuracy, AUC).
  - Implements early stopping based on validation loss.
  - Saves the best performing model state during training.
- Model Evaluation (`evaluate_model`):
  - Assesses the trained model on the test set.
  - Calculates a comprehensive set of metrics: accuracy, precision, recall, specificity, F1-score, Negative Predictive Value (NPV), ROC AUC, and PR AUC (Average Precision).
  - Generates and displays a confusion matrix and classification report.
  - Saves detailed evaluation results to a `.txt` file in the `evaluation_results` directory.
- Visualization:
  - `plot_training`: Plots training and validation accuracy and loss curves over epochs.
  - The notebook also includes code to plot ROC and PR curves for the test set results.
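The early-stopping and best-model bookkeeping listed under `train_model` can be sketched in a few lines of plain Python. This is an illustration under assumptions, not the notebook's code: the `EarlyStopping` class, its `patience` and `min_delta` parameters, and the validation losses in the demo are all hypothetical.

```python
class EarlyStopping:
    """Track the best validation loss and signal when to stop training."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch and return (improved, should_stop)."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience


# Made-up validation losses, used only to exercise the logic.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.60]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    improved, should_stop = stopper.step(epoch, val_loss)
    if improved:
        print(f"epoch {epoch}: new best {val_loss:.2f} (a checkpoint would be saved here)")
    if should_stop:
        stopped_at = epoch
        break

print("stopped at epoch", stopped_at, "best epoch", stopper.best_epoch)
```

In a real training loop, the `improved` branch is where the model state would be written to disk, so the file on disk always corresponds to the lowest validation loss seen so far.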
- Ensure all prerequisites (dependencies, datasets) are met as described in the "Setup and Installation" section.
- The primary dataset file `pneumoniamnist_224.npz` should be in the location specified by `data_path` in the notebook (defaults to the same directory).
- Open and run the cells in the `Code & Description.ipynb` Jupyter Notebook.
- Key sections to execute:
  - Data Loading: The cell calling `load_and_preprocess_data`.
  - Model Initialization: The cell defining and instantiating `AlterNet_LC`.
  - Training: The cell calling `train_model`. This will train the model and save the best version (e.g., `alternet_contrastive_YYYYMMDD.pth`).
  - Plotting Training History: The cell calling `plot_training`.
  - Evaluation: The cell calling `evaluate_model`. This will print metrics and save results to a text file.
  - Plotting Evaluation Curves: Cells for plotting ROC and PR curves.
  - Saving Final Model: The final cell saves the model state (e.g., `alternet_lc_YYYYMMDD.pth`).
- Trained Model Files:
  - `alternet_contrastive_YYYYMMDD.pth` (best model saved during training based on validation loss).
  - `alternet_lc_YYYYMMDD.pth` (final model saved at the end of the notebook).
- Evaluation Results Directory (`evaluation_results/`):
  - A text file (e.g., `evaluation_results_YYYYMMDD_HHMMSS.txt`) containing detailed performance metrics, confusion matrix, and classification report.
  - An `images/` subdirectory is also created. The notebook does not explicitly save plots there via `plt.savefig()`, but the directory is prepared for such use.
- Plots (displayed within the notebook):
  - Training/Validation accuracy and loss curves.
  - ROC curve with AUC.
  - Precision-Recall (PR) curve with Average Precision (AP).
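Several of the reported metrics, specificity and NPV in particular, follow directly from the four confusion-matrix counts. The sketch below shows those formulas in plain Python; `binary_metrics` is a hypothetical helper rather than the notebook's `evaluate_model`, and the counts in the demo are made up for illustration, not results from the model.

```python
def binary_metrics(tn, fp, fn, tp):
    """Compute threshold-based metrics from binary confusion-matrix counts."""

    def safe_div(numerator, denominator):
        # Return 0.0 instead of raising when a class is absent.
        return numerator / denominator if denominator else 0.0

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)  # also called sensitivity
    return {
        "accuracy": safe_div(tp + tn, tp + tn + fp + fn),
        "precision": precision,
        "recall": recall,
        "specificity": safe_div(tn, tn + fp),
        "npv": safe_div(tn, tn + fn),
        "f1": safe_div(2 * precision * recall, precision + recall),
    }


# Made-up counts for illustration only.
metrics = binary_metrics(tn=80, fp=20, fn=10, tp=90)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

ROC AUC and Average Precision are different in kind: they depend on the predicted scores across all thresholds, so they cannot be derived from a single confusion matrix.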
.
├── Code & Description.ipynb # Main Jupyter Notebook with all code and explanations