This repository contains a hybrid deep learning model that combines a pre-trained VGG19 convolutional backbone with Kolmogorov-Arnold Network (KAN) linear layers for flexible and powerful representation learning on image classification tasks.
This implementation is based on the approach proposed in the paper: “Refining Crop Pest Recognition Performance through Dynamically Adaptable Activation Patterns of Kolmogorov Arnold Networks” (IEEE Link).
The model leverages:
- VGG19: A pre-trained CNN for extracting high-level image features.
- KANLinear layers: Fully connected layers inspired by the Kolmogorov-Arnold representation theorem. They learn their activation functions instead of using fixed ones, which is intended to give a more expressive non-linear mapping.
- Data augmentation: Techniques like random flips, rotations, color jitter, and random erasing for robust training.
The architecture replaces the fully connected classifier layers of a standard CNN with KANLinear layers, with the aim of improving classification performance.
Python 3.x and the following packages:
torch
torchvision
scikit-learn
matplotlib
numpy
Install via pip:
pip install torch torchvision scikit-learn matplotlib numpy
VGG19KAN:
- Feature Extractor: Pre-trained VGG19 convolutional layers
- Adaptive Average Pooling: Reduce feature maps to fixed size (7x7)
- KANLinear Layers:
  - kan1: 25088 → 512
  - kan2: 512 → 1024
  - kan3: 1024 → Output
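The layer sizes above follow from the VGG19 feature extractor, which emits 512 channels; pooling to 7x7 and flattening gives 512 × 7 × 7 = 25088 inputs to kan1. The sketch below checks those shapes for the pooling and head only. `KANHeadSketch` and its `layer_cls` argument are hypothetical names, and `nn.Linear` is just a placeholder default so the code runs without the KAN implementation; it assumes the repository's KANLinear class can be constructed as `KANLinear(in_features, out_features)`.

```python
import torch
import torch.nn as nn


class KANHeadSketch(nn.Module):
    """Pooling plus a three-layer head with the sizes listed above.

    `layer_cls` is a stand-in: pass the repository's KANLinear class here.
    nn.Linear is only a placeholder default so the sketch runs on its own.
    """

    def __init__(self, num_classes, layer_cls=nn.Linear):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((7, 7))
        self.kan1 = layer_cls(512 * 7 * 7, 512)   # 25088 -> 512
        self.kan2 = layer_cls(512, 1024)          # 512 -> 1024
        self.kan3 = layer_cls(1024, num_classes)  # 1024 -> number of classes

    def forward(self, feature_maps):
        x = self.pool(feature_maps)   # (N, 512, H, W) -> (N, 512, 7, 7)
        x = torch.flatten(x, 1)       # (N, 512, 7, 7) -> (N, 25088)
        x = self.kan1(x)
        x = self.kan2(x)
        return self.kan3(x)


# Fake VGG19 feature maps: batch of 2, 512 channels, arbitrary spatial size.
head = KANHeadSketch(num_classes=10)
feature_maps = torch.randn(2, 512, 14, 14)
logits = head(feature_maps)
print(logits.shape)  # torch.Size([2, 10])
```

In the KAN design each layer carries its own learnable non-linearity, which is why the sketch adds no separate activation between layers. With the `nn.Linear` placeholder the head is only useful for checking shapes.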
The model is trained with CrossEntropyLoss and optimized using AdamW.
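A minimal sketch of that loss and optimizer setup is shown here. The tiny `nn.Linear` model is a stand-in so the example runs on its own, and the learning rate and weight decay are illustrative values, not the repository's settings.

```python
import torch
import torch.nn as nn

# Tiny stand-in model so the sketch runs on its own; use the VGG19KAN model instead.
model = nn.Linear(8, 3)

criterion = nn.CrossEntropyLoss()
# lr and weight_decay are illustrative values, not the repository's settings.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

# One optimization step on random data.
inputs = torch.randn(4, 8)
targets = torch.tensor([0, 2, 1, 2])      # integer class indices, not one-hot vectors
loss = criterion(model(inputs), targets)  # CrossEntropyLoss expects raw logits
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
```

`CrossEntropyLoss` applies log-softmax internally, so the final layer should output raw logits rather than probabilities.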
The training loop includes:
- Training with backpropagation
- Validation phase for monitoring overfitting
- Test evaluation with accuracy, precision, recall, and F1-score
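The three phases above might be organized as one training function and one shared evaluation function, sketched below. The function names `train_one_epoch` and `evaluate` are hypothetical and are not the repository's `run` function. Macro averaging for precision, recall, and F1 is an assumption, since the averaging mode is not stated here.

```python
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass over the training data and return the mean loss."""
    model.train()
    total_loss, n_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        n_samples += labels.size(0)
    return total_loss / n_samples


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Return mean loss, accuracy, and macro-averaged precision, recall, F1."""
    model.eval()
    total_loss, y_true, y_pred = 0.0, [], []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item() * labels.size(0)
        y_true.extend(labels.cpu().tolist())
        y_pred.extend(logits.argmax(dim=1).cpu().tolist())
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0
    )
    accuracy = accuracy_score(y_true, y_pred)
    return total_loss / len(y_true), accuracy, precision, recall, f1


# Usage on random data with a tiny stand-in model.
device = torch.device("cpu")
model = nn.Linear(8, 3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
data = TensorDataset(torch.randn(32, 8), torch.randint(0, 3, (32,)))
loader = DataLoader(data, batch_size=8)

train_loss = train_one_epoch(model, loader, criterion, optimizer, device)
val_loss, acc, precision, recall, f1 = evaluate(model, loader, criterion, device)
print(train_loss, val_loss, acc, precision, recall, f1)
```

The same `evaluate` function can serve both the validation and the test loader, which keeps the two sets of metrics comparable.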
Metrics are tracked per epoch and can be visualized using matplotlib.
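One way to visualize a per-epoch metric is sketched below. `plot_history` is a hypothetical helper, not part of the repository, and the numbers passed to it are made up purely to exercise the function.

```python
import os

import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line inside a notebook
import matplotlib.pyplot as plt


def plot_history(train_values, val_values, ylabel, path):
    """Plot one per-epoch metric for the training and validation splits."""
    epochs = range(1, len(train_values) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_values, label="train")
    ax.plot(epochs, val_values, label="validation")
    ax.set_xlabel("epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(path)
    plt.close(fig)  # release the figure's memory
    return path


# Made-up loss values for three epochs.
saved = plot_history([0.9, 0.6, 0.4], [1.0, 0.7, 0.6], "loss", "loss_curve.png")
print(os.path.exists(saved))  # True
```

A validation curve that rises while the training curve keeps falling is the usual sign of overfitting.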
train_accuracies, val_accuracies, test_accuracies, train_losses, val_losses, test_losses, test_precisions, test_recalls, test_f1_scores = run(
model, criterion, optimizer, train_loader, val_loader, test_loader
)
- Prepare your dataset with training, validation, and test splits in separate folders, each containing one subfolder per class (the layout ImageFolder expects).
- Update dataset paths in the code:
train_dataset = datasets.ImageFolder(root='PATH_TO_TRAIN', transform=train_transform)
val_dataset = datasets.ImageFolder(root='PATH_TO_VAL', transform=val_transform)
test_dataset = datasets.ImageFolder(root='PATH_TO_TEST', transform=test_transform)
- Run training:
jupyter nbconvert --to notebook --execute VGG19KAN.ipynb
- Save the trained model:
torch.save(model.state_dict(), "vgg19_kan.pth")
- VGG19 PyTorch Implementation
- Efficient-KAN: Kolmogorov-Arnold Network
- KAN: Kolmogorov-Arnold Networks
- Refining Crop Pest Recognition Performance through Dynamically Adaptable Activation Patterns of Kolmogorov Arnold Networks
- GPU is recommended for training.
- KANLinear layers replace standard fully connected layers for better representation learning.
- Data augmentation improves model robustness and generalization.
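The GPU recommendation and the saved state dict can be combined when reloading the model, as in the sketch below. The tiny `nn.Linear` model stands in for VGG19KAN, and the file name is hypothetical; in practice the real model class must be instantiated before its weights are loaded.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tiny stand-in for the VGG19KAN model so the sketch runs on its own.
model = nn.Linear(8, 3).to(device)

# Save only the weights (state dict), as in the usage step above.
path = os.path.join(tempfile.gettempdir(), "vgg19_kan_demo.pth")  # hypothetical file name
torch.save(model.state_dict(), path)

# Rebuild the same architecture, then load the weights onto the current device.
restored = nn.Linear(8, 3)
restored.load_state_dict(torch.load(path, map_location=device))
restored.to(device).eval()  # eval() disables dropout and batch-norm updates

x = torch.randn(2, 8, device=device)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)  # True
```

Passing `map_location` lets weights saved on a GPU machine be loaded on a CPU-only machine.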