Welcome to the Gemma Fine-Tuning Project, designed to adapt the Gemma model for generating classical Chinese poetry. This repository documents a technical deep dive into fine-tuning Gemma using LoRA (Low-Rank Adaptation), optimizing for computational efficiency and stylistic accuracy.
The primary goal of this project is to fine-tune the Gemma model to generate poetry in a classical Chinese style. This involves:
- Preparing a tokenized dataset of classical Chinese poetry.
- Employing the LoRA framework for lightweight fine-tuning.
- Evaluating model performance against stylistic and linguistic metrics.
- Low Computational Overhead: Leveraging LoRA for efficient model adaptation with minimal hardware requirements.
- Cultural Nuance: Tailoring outputs to reflect classical Chinese literary traditions.
- Scalable Workflow: Designed for extension to other languages or stylistic domains.
Ensure you have a compatible Python environment (Python 3.8+ recommended). Install the required libraries by executing:
pip install -q -U keras-nlp datasets
pip install -q -U kerasDependencies include:
keras-nlp: For natural language processing tasks.datasets: For efficient dataset loading and management.keras: For model definition and training.
The dataset used in this project is a curated collection of classical Chinese poetry. It has been processed into a tokenized format compatible with the fine-tuning pipeline.
- Data Source: Classical Chinese poetry collections from publicly available repositories.
- Preprocessing Steps:
- Tokenization using a pre-trained tokenizer.
- Splitting into training, validation, and test sets.
- File Structure:
train.txt: Tokenized training data.val.txt: Tokenized validation data.test.txt: Tokenized test data for evaluation.
- Number of Poems: ~10,000
- Average Length: 100 characters per poem
- Format: UTF-8 encoded text files
The Gemma model is based on a transformer architecture with the following characteristics:
- Number of Layers: 24
- Hidden Size: 1024
- Attention Heads: 16
- Parameter Count: ~350M
The LoRA approach modifies only a subset of parameters, significantly reducing the computational resources required for fine-tuning.
- Tokenize input data using the Gemma tokenizer.
- Create a PyTorch dataset and dataloader for efficient batching.
- Optimizer: AdamW with learning rate scheduling.
- Learning Rate: 5e-5 with linear decay.
- Batch Size: 16
- Epochs: 3
- Loss Function: Cross-entropy loss.
- Fine-tune only specific layers of the transformer to reduce resource usage.
- Inject trainable low-rank matrices into the attention mechanism.
Use the following command to start the fine-tuning process:
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="steps",
learning_rate=5e-5,
per_device_train_batch_size=16,
num_train_epochs=3,
weight_decay=0.01,
logging_dir="./logs",
save_steps=500,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer,
)
trainer.train()- Perplexity: Measures the model's predictive performance.
- BLEU Score: Evaluates similarity to reference poems.
- Human Evaluation: Subjective assessment of stylistic fidelity.
- Perplexity: Achieved a score of 15.2 on the test set.
- BLEU Score: Average BLEU-4 of 0.48.
- Human Feedback: Rated highly on maintaining classical style.
This workflow can be extended to:
- Other Languages: Adapting Gemma for poetry generation in different linguistic styles.
- Stylistic Domains: Training on prose, modern poetry, or technical texts.
- Model Variants: Experimenting with larger or smaller variants of Gemma.
Clone this repository and execute the following steps:
-
Clone the repository:
git clone https://github.com/your-username/gemma-chinese-poetry.git cd gemma-chinese-poetry -
Install dependencies:
pip install -r requirements.txt
-
Run the notebook:
jupyter notebook gemma-google-chinese.ipynb
-
Fine-tune the model following the notebook instructions.
- Dataset Bias: Limited to classical Chinese poetry styles.
- Computational Requirements: Requires a GPU for efficient training.
- Generative Quality: May require manual curation of outputs.
We welcome contributions from the community! Here are ways to get involved:
- Bug Reports: File an issue for any bugs or inconsistencies.
- Feature Requests: Suggest extensions or improvements.
- Pull Requests: Submit code or documentation improvements.
This project is licensed under the MIT License. See the LICENSE file for details.
- Google: For organizing the competition and supporting innovation.
- LoRA Framework: For enabling efficient fine-tuning of large models.
- Open-Source Community: For tools and datasets that made this project possible.