
Implementation of the following parameter-efficient fine-tuning methods on GPT-2 for summarization: soft prompt tuning (optimizing a set of prefix embeddings), LoRA (low-rank adapters), and fine-tuning only the final classifier layer while keeping the rest of the model frozen.


GnanaPrakashSG2004/Parameter_Efficient_Fine_Tuning


File Structure:

.
├── ANLP_Assignment_3.pdf
├── README.md
├── Report.pdf
└── src
    ├── dataset
    │   ├── dataset.py
    │   └── __init__.py
    ├── __init__.py
    └── models
        ├── classifier_ft
        │   ├── __init__.py
        │   ├── model.py
        │   ├── test.py
        │   └── train.py
        ├── __init__.py
        ├── lora
        │   ├── __init__.py
        │   ├── model.py
        │   ├── test.py
        │   └── train.py
        └── soft_prompts
            ├── __init__.py
            ├── model.py
            ├── test.py
            └── train.py

Instructions to run the code:

  • All scripts should be run from the repository root directory.

Soft Prompts:

Fine-tuning the model:

  • Run the following command to fine-tune the model:
python -m src.models.soft_prompts.train
  • This fine-tunes the model and saves it in the weights directory with the file name soft_prompts.pt.
  • To change the hyperparameters, you can pass them as arguments to the above command. For example:
python -m src.models.soft_prompts.train --epochs 10 --batch_size 4
  • To see all the hyperparameters, you can run:
python -m src.models.soft_prompts.train --help

Restoring the model:

  • The following Python code snippet can be used to restore the model:
import torch

from src.models.soft_prompts.model import GPTSoftPromptTuning

# hard_prompt, device, cache_dir, and model_path are assumed to be defined by the caller
model = GPTSoftPromptTuning(hard_prompt, device=device, cache_dir=cache_dir)
model.load_state_dict(torch.load(model_path, map_location=device)["model_state_dict"])

Testing the model:

  • Run the following command to test the model:
python -m src.models.soft_prompts.test

Model checkpoint link:

  • The model checkpoint can be downloaded from here
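For reference, the core idea behind soft prompt tuning can be sketched as follows: a small matrix of learnable prefix embeddings is prepended to the (frozen) token embeddings before they enter the transformer, and only those prefix embeddings receive gradients. This is a minimal illustration, not the repo's actual API; the class and parameter names (SoftPromptPrefix, n_prompt_tokens, embed_dim) are made up for the example:

```python
import torch
import torch.nn as nn

class SoftPromptPrefix(nn.Module):
    """Learnable prefix embeddings prepended to frozen token embeddings."""

    def __init__(self, n_prompt_tokens: int, embed_dim: int):
        super().__init__()
        # The only trainable parameters: one embedding per prompt position.
        self.prompt = nn.Parameter(torch.randn(n_prompt_tokens, embed_dim) * 0.02)

    def forward(self, input_embeds: torch.Tensor) -> torch.Tensor:
        # input_embeds: (batch, seq_len, embed_dim) from the frozen embedding layer
        batch = input_embeds.size(0)
        prefix = self.prompt.unsqueeze(0).expand(batch, -1, -1)
        # Output has seq_len + n_prompt_tokens positions; the rest of the
        # (frozen) GPT-2 forward pass runs on this extended sequence.
        return torch.cat([prefix, input_embeds], dim=1)

prefix = SoftPromptPrefix(n_prompt_tokens=10, embed_dim=768)  # 768 = GPT-2 hidden size
x = torch.randn(2, 5, 768)   # stand-in for a batch of token embeddings
out = prefix(x)
print(out.shape)             # torch.Size([2, 15, 768])
```

During training, only `prefix.parameters()` would be passed to the optimizer while the GPT-2 weights stay frozen.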

LoRA:

Fine-tuning the model:

  • Run the following command to fine-tune the model:
python -m src.models.lora.train
  • This fine-tunes the model and saves it in the weights directory with the file name lora.pt.
  • The script also saves the LoRA adapter weights on their own in the lora_weights subdirectory.
  • To change the hyperparameters, you can pass them as arguments to the above command. For example:
python -m src.models.lora.train --epochs 10 --batch_size 4
  • To see all the hyperparameters, you can run:
python -m src.models.lora.train --help

Restoring the model:

  • The following Python code snippet can be used to restore the model:
from src.models.lora.model import GPTLoRA

# lora_config, device, cache_dir, and model_path are assumed to be defined by the caller
model = GPTLoRA(lora_config=lora_config, device=device, cache_dir=cache_dir)
model.load_lora_weights(model_path)

Testing the model:

  • Run the following command to test the model:
python -m src.models.lora.test

Model checkpoint link:

  • The model with the loaded LoRA adapter weights can be downloaded from here
  • To download only the LoRA adapter weights, click here
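For reference, the mechanism behind a LoRA adapter can be sketched as follows: a frozen pretrained linear layer W is augmented with a low-rank update B·A scaled by alpha/r, and only A and B are trained. This is a minimal illustration under standard LoRA conventions, not the repo's actual implementation; the names (LoRALinear, r, alpha) are illustrative:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear with a trainable low-rank update."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # freeze the pretrained weight
        # A is small random, B is zero, so the adapter starts as a no-op.
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base output plus the low-rank correction x @ A^T @ B^T, scaled.
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

layer = LoRALinear(nn.Linear(768, 768), r=8, alpha=16)
x = torch.randn(4, 768)
y = layer(x)
print(y.shape)  # torch.Size([4, 768])
```

Because B is initialized to zero, the wrapped layer initially reproduces the base model exactly; only the small A and B matrices (2 × r × 768 values here, versus 768 × 768 in the base weight) need to be saved as adapter weights.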

Classifier Fine-tuning:

Fine-tuning the model:

  • Run the following command to fine-tune the model:
python -m src.models.classifier_ft.train
  • This fine-tunes the model and saves it in the weights directory with the file name classifier_ft.pt.
  • To change the hyperparameters, you can pass them as arguments to the above command. For example:
python -m src.models.classifier_ft.train --epochs 10 --batch_size 4
  • To see all the hyperparameters, you can run:
python -m src.models.classifier_ft.train --help

Restoring the model:

  • The following Python code snippet can be used to restore the model:
import torch

from src.models.classifier_ft.model import GPTFineTune

# device, cache_dir, and model_path are assumed to be defined by the caller
model = GPTFineTune(device=device, cache_dir=cache_dir)
model.load_state_dict(torch.load(model_path, map_location=device)["model_state_dict"])

Testing the model:

  • Run the following command to test the model:
python -m src.models.classifier_ft.test

Model checkpoint link:

  • The model checkpoint can be downloaded from here
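For reference, classifier-only fine-tuning amounts to freezing every parameter except the final output layer and giving the optimizer only the unfrozen parameters. The toy model below (a stand-in backbone plus a head, not GPT-2 and not the repo's code) shows the pattern:

```python
import torch
import torch.nn as nn

# Stand-in for a frozen pretrained backbone; sizes are arbitrary.
backbone = nn.Sequential(nn.Embedding(100, 64), nn.Linear(64, 64), nn.Tanh())
head = nn.Linear(64, 100)  # the final "classifier" layer, kept trainable

for p in backbone.parameters():
    p.requires_grad = False  # freeze everything except the head

# Only the head's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(head.parameters(), lr=1e-4)

trainable = sum(p.numel() for p in head.parameters())
total = trainable + sum(p.numel() for p in backbone.parameters())
print(f"trainable: {trainable} / {total}")
```

The same pattern applies to GPT-2: freeze the transformer body and train only the head, which is why this baseline updates a small fraction of the total parameter count.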

All checkpoints:

  • All the weights can be downloaded from here
