Geneformer Fine-Tuning for Cell Annotation Application¶
Note 1: this notebook requires a specific Python version (<=3.10; 3.11 is not compatible) and torch version (torch 2.x does not work). You can create a virtual environment and install the packages as follows:
conda create -n py310 python=3.10 anaconda
conda activate py310
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
Note 2: conda fails when installing scanpy with the commonly given instruction conda install -c bioconda scanpy; use conda install -c bioconda -c conda-forge scanpy instead.
Note 3: check your matplotlib version: if it is >3.7 it is not compatible with scanpy. If needed, downgrade with pip install 'matplotlib == 3.6'.
Note 4: also run the following commands to install the remaining dependencies: pip install transformers accelerate, pip install git+https://github.com/huggingface/accelerate, pip install transformers==4.28.0.
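Note 5: the geneformer package itself also needs to be importable in this environment. The original notes do not cover this, but one common option is to install it from a local clone of the Geneformer repository: git clone https://huggingface.co/ctheodoris/Geneformer, then pip install . inside the clone.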
import os
GPU_NUMBER = [i for i in range(1)] # keep a single GPU: multi-GPU parallel training does not appear to work on hpc3
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
print(os.environ["CUDA_VISIBLE_DEVICES"])
os.environ["NCCL_DEBUG"] = "INFO"
os.environ['CUDA_LAUNCH_BLOCKING'] ='1'
os.environ["NCCL_P2P_DISABLE"] = '1'
0
To verify the environment, run the following check; if your Python/torch setup is not compatible, the reported device count will be 0.
import torch
print(torch.cuda.device_count())
print(torch.__version__)
1
2.0.1
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
Prepare training and evaluation datasets¶
# load the tokenized cell type dataset (here the pbmc dataset)
train_dataset=load_from_disk("../output/pbmc.dataset/")
The following code block is optional. Run it if you want to fine-tune on a subsample and hold out a test set; otherwise skip it.
x = train_dataset.train_test_split(test_size=0.8, seed=42)
train_dataset = x["train"] # use 20% of the data to train
test_dataset = x["test"]
test_dataset.save_to_disk("./test_dataset.80.dataset")
# print("OK")
Saving the dataset (0/1 shards): 0%| | 0/134407 [00:00<?, ? examples/s]
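If you later want to evaluate on this held-out split, it can be reloaded the same way the training data was loaded; a minimal example, using the path saved in the cell above:
# reload the held-out split saved above (uses the path from the previous cell)
test_dataset = load_from_disk("./test_dataset.80.dataset")
print(test_dataset)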
The following code block performs the dataset preprocessing. If you want to train on your own dataset, complete these 3 tasks first (a quick column sanity check is sketched right after this list):
1. Make sure there is a column called 'organ_major' in your .loom data.
2. Make sure you have already converted the data to the .dataset format with tokenizing_scRNAseq_data.ipynb.
3. Make sure you have changed every occurrence of major.type in the code block below to your own preferred feature.
Then just run all of the code below to train your own model; it will be saved under ../models/.
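As a quick sanity check before running the preprocessing loop (this snippet is an illustration and not part of the original notebook), you can confirm that the tokenized dataset exposes the columns the loop expects; the names below are the ones used in this notebook ('organ_major' for the tissue and 'condition' for the cell type label), so substitute your own if they differ:
# sanity check: confirm the tokenized dataset has the columns the loop below expects
# (column names follow this notebook; adjust them if your own dataset differs)
print(train_dataset.column_names)
assert "organ_major" in train_dataset.column_names, "missing 'organ_major' column"
assert "condition" in train_dataset.column_names, "missing cell type label column ('condition' here)"
print(Counter(train_dataset["organ_major"]))  # how many cells per organ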
dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []
for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]
    print(organ)
    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)
    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["condition"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005 * total_cells)]
    def if_not_rare_celltype(example):
        return example["condition"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)
    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("condition", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")
    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]
    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0, round(len(labeled_trainset) * 0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset) * 0.8), len(labeled_trainset))])
    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)
    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]
A
trainset_dict = dict(zip(organ_list,dataset_list))
traintargetdict_dict = dict(zip(organ_list,target_dict_list))
evalset_dict = dict(zip(organ_list,evalset_list))
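Before fine-tuning, it can help to inspect what the preprocessing produced; this is just a convenience check (not in the original notebook) that uses the dictionaries built above:
# inspect each organ's split sizes and its cell type -> label id mapping
for organ in organ_list:
    print(organ, "train cells:", len(trainset_dict[organ]), "eval cells:", len(evalset_dict[organ]))
    print(traintargetdict_dict[organ])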
Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance¶
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
        'accuracy': acc,
        'macro_f1': macro_f1
    }
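As a quick, purely illustrative sanity check of compute_metrics, you can pass it a mock object with the same label_ids/predictions fields the Trainer supplies:
# mock prediction object mimicking what the Trainer passes to compute_metrics
# (illustrative only; the values are made up)
import numpy as np
from types import SimpleNamespace
mock_pred = SimpleNamespace(
    label_ids=np.array([0, 1, 1, 0]),
    predictions=np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]]),
)
print(compute_metrics(mock_pred))  # accuracy should be 0.75 on this mock batch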
Please note that, as usual with deep learning models, we highly recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the "hyperparam_optimiz_for_disease_classifier" script for an example of how to tune hyperparameters for downstream applications.¶
# set model parameters
# max input size
max_input_size = 2 ** 11 # 2048
# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 5 # for an RTX 3090, a batch size of 5 is the maximum that fits
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamw"
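Note that freeze_layers is defined above but never applied in the training loop below. If you set it above 0, the first layers can be frozen after loading the model; a minimal sketch, assuming the BERT-style layout of BertForSequenceClassification (this helper is not part of the original notebook):
# sketch: freeze the embedding layer and the first n encoder layers so they are
# excluded from fine-tuning (assumes a BertForSequenceClassification model)
def freeze_bert_layers(model, n_layers_to_freeze):
    if n_layers_to_freeze > 0:
        modules_to_freeze = [model.bert.embeddings] + list(model.bert.encoder.layer[:n_layers_to_freeze])
        for module in modules_to_freeze:
            for param in module.parameters():
                param.requires_grad = False
    return model
# possible usage inside the loop below, right after from_pretrained(...):
# model = freeze_bert_layers(model, freeze_layers)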
# os.environ['CUDA_LAUNCH_BLOCKING'] ='1'
for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]
    # set logging steps
    logging_steps = round(len(organ_trainset) / geneformer_batch_size / 10)
    # reload pretrained model
    model = BertForSequenceClassification.from_pretrained(
        "../geneformer-12L-30M/",
        # num_labels=len(organ_label_dict.keys()),
        num_labels=2,
        output_attentions=False,
        output_hidden_states=False,
    ).to("cuda")
    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"../models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")
    # make output directory
    subprocess.call(f'mkdir {output_dir}', shell=True)
    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }
    training_args_init = TrainingArguments(**training_args)
    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics,
    )
    # train the cell type classifier
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    with open(f"{output_dir}predictions.pickle", "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)
A
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M/ and are newly initialized: ['bert.pooler.dense.weight', 'classifier.weight', 'classifier.bias', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
mkdir: cannot create directory ‘../models/230907_geneformer_CellClassifier_A_L2048_B5_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/’: File exists
/mnt/disk/mwang/conda3/lib/python3.11/site-packages/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
| Epoch | Training Loss | Validation Loss | Accuracy | Macro F1 |
|---|---|---|---|---|
| 1 | 0.478000 | 0.446698 | 0.802143 | 0.797032 |
| 2 | 0.360400 | 0.520274 | 0.831269 | 0.830752 |
| 3 | 0.335900 | 0.584502 | 0.841647 | 0.841416 |
| 4 | 0.240000 | 0.727501 | 0.847673 | 0.847564 |
| 5 | 0.128000 | 0.736757 | 0.873786 | 0.872867 |
| 6 | 0.110300 | 0.900706 | 0.873452 | 0.871618 |
| 7 | 0.057700 | 1.003406 | 0.876130 | 0.874294 |
| 8 | 0.012700 | 1.131091 | 0.878473 | 0.877133 |
| 9 | 0.011200 | 1.234100 | 0.876465 | 0.874432 |
| 10 | 0.008600 | 1.168336 | 0.878473 | 0.877268 |
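The saved predictions.pickle holds the Trainer's prediction output for the eval set. A small follow-up snippet (illustrative, reusing output_dir, organ, and the label dictionaries from the cells above) can map the predicted label ids back to cell type names:
# load the saved predictions and translate label ids back to cell type names
# (illustrative follow-up; reuses variables left over from the training loop above)
with open(f"{output_dir}predictions.pickle", "rb") as fp:
    predictions = pickle.load(fp)
pred_ids = predictions.predictions.argmax(-1)
id_to_name = {v: k for k, v in traintargetdict_dict[organ].items()}  # invert celltype -> id dict
print([id_to_name.get(i, "unknown") for i in pred_ids[:10]])  # first 10 predicted cell types
print(predictions.metrics)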