TurbaNet is a lightweight, user-friendly API wrapper for the JAX library, designed to simplify and accelerate the setup of swarm-based training, evaluation, and simulation of small neural networks. It is based on the approach presented by Will Whitney in his 2021 blog post.[^1]
- Simplified API: Provides an intuitive interface for configuring and managing swarm-based neural network tasks.
- Efficiency: Leverages JAX's just-in-time compilation and vectorized computation to accelerate training and evaluation.
- Flexibility: Supports various configurations, allowing users to tailor the swarm behavior to specific needs.
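The underlying idea, roughly following Whitney's post, is to stack the parameters of many small networks and vectorize a single update step over that stack with jax.vmap. The sketch below illustrates the technique in plain JAX; it is not TurbaNet's actual implementation, and all names and sizes in it are hypothetical.

import jax
import jax.numpy as jnp

def init_params(key):
    # One tiny network: a single dense layer with hypothetical sizes.
    w_key, _ = jax.random.split(key)
    return {"w": jax.random.normal(w_key, (10, 1)), "b": jnp.zeros((1,))}

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def swarm_step(params, x, y, lr=0.01):
    # Map the gradient computation over the leading "swarm" axis of every parameter array.
    grads = jax.vmap(jax.grad(loss_fn), in_axes=(0, None, None))(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

keys = jax.random.split(jax.random.PRNGKey(0), 10)  # one key per network
swarm_params = jax.vmap(init_params)(keys)          # stacked parameters for 10 networks
x = jnp.ones((4, 10))
y = jnp.zeros((4, 1))
swarm_params = swarm_step(swarm_params, x, y)       # one parallel update step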
To install TurbaNet, ensure that you have Python and pip installed. Then, run:
pip install turbanet
TurbaNet train states require models and optimizers from Flax and Optax, which can be installed with:
pip install flax optax
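As a quick sanity check that the dependencies resolved correctly, the imports below should succeed (TurbaTrainState and mse are the names used in the example later in this README):

import jax
import flax
import optax
from turbanet import TurbaTrainState, mse

# Print dependency versions to confirm the environment is set up.
print("jax", jax.__version__, "| flax", flax.__version__, "| optax", optax.__version__)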
TurbaNet is built on top of JAX and fully supports GPU acceleration out of the box. To take advantage of GPU-based training, ensure your environment is configured correctly with the appropriate CUDA and cuDNN versions supported by JAX.
Refer to the official JAX installation guide for detailed instructions and compatibility tables, including supported CUDA versions and platform-specific recommendations.
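For example, at the time of writing, JAX's CUDA-enabled wheels for NVIDIA GPUs can typically be installed with the command below; check the JAX guide for the command matching your CUDA version and platform.

pip install -U "jax[cuda12]"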
If JAX detects a supported GPU and the correct environment configuration, TurbaNet will automatically execute on the GPU without any additional setup.
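Conversely, if you want to pin execution to the CPU (for example, while debugging), one option is to set the JAX_PLATFORMS environment variable before JAX is first imported; a minimal sketch:

import os
os.environ["JAX_PLATFORMS"] = "cpu"  # must be set before the first import of jax

import jax
print(jax.devices())  # expected to list only CPU devices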
You can verify whether GPU support is enabled by running:
from jax.extend.backend import get_backend
print(get_backend().platform)  # Should return 'gpu' if GPU is active

Here's a basic example demonstrating how to initialize and use TurbaNet:
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import linen as nn
from turbanet import TurbaTrainState, mse
# Set numpy random seed for reproducible results
np.random.seed(0)
# Sample input
X_data = np.random.randint(0, 2, (10, 10)).astype(float)
y_data = np.random.randint(0, 2, (10, 1)).astype(float)
# Define model for the swarm
class MyModel(nn.Module):
    hidden_dim: int = 32

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        x = nn.sigmoid(x)
        return x
# Define optimizer
optimizer = optax.adam(learning_rate=0.01)
# Define the size of the swarm
swarm_size = 10
# Initialize the swarm with desired parameters
swarm = TurbaTrainState.swarm(MyModel(), optimizer, swarm_size, X_data[0].reshape(1, -1))
# Train the swarm on your dataset
epochs = 100
losses = np.zeros((epochs, swarm_size))
for epoch in range(epochs):
    # Tile the dataset across the swarm: shapes become (swarm_size, n_samples, n_features)
    X = np.expand_dims(X_data, 0).repeat(len(swarm), axis=0)
    y = np.expand_dims(y_data, 0).repeat(len(swarm), axis=0)
    swarm, losses[epoch], predictions = swarm.train(X, y, mse)
# Plot the loss curves from training
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
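# Optional: overlay the swarm-average loss curve (plain numpy/matplotlib, not a TurbaNet API).
plt.plot(losses.mean(axis=1), color="black", linewidth=2, label="Swarm mean")
plt.legend()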
plt.show()

For more detailed tutorials and examples, please refer to the documentation.
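As a small follow-up sketch: the train() call in the loop above also returns the swarm's predictions from the final step, so an ensemble prediction can be formed with plain numpy. The assumed prediction shape of (swarm_size, n_samples, 1) is an inference from the example above, not documented API behavior.

# Average the per-network outputs from the last training step into an ensemble prediction.
ensemble = np.asarray(predictions).mean(axis=0)
accuracy = ((ensemble > 0.5) == y_data).mean()
print(f"Ensemble training accuracy: {accuracy:.2f}")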
We welcome contributions to TurbaNet! If you'd like to contribute, please follow these steps:
1. Fork the repository: Click the "Fork" button at the top right of the GitHub page.
2. Clone your fork:
   `git clone https://github.com/your-username/TurbaNet.git`
3. Create a new branch:
   `git checkout -b feature/your-feature-name`
4. Make your changes: Implement your feature or fix the identified issue.
5. Commit your changes:
   `git commit -m "Description of your changes"`
6. Push to your fork:
   `git push origin feature/your-feature-name`
7. Submit a Pull Request: Navigate to the original repository and click "New Pull Request" to submit your changes for review.
This project is licensed under the MIT License. See the LICENSE file for more details.
[^1]: Whitney, W. (2021). Parallelizing neural networks on one GPU with JAX. Will Whitney's Blog. https://willwhitney.com/parallel-training-jax.html