Geneformer Fine-Tuning for Cell Annotation Application¶
Note 1: This notebook requires a specific Python version (<=3.10; 3.11 is not compatible) and torch version (torch 2 is not OK!). You can create a virtual environment and install the packages as follows:
conda create -n py310 python=3.10 anaconda
conda activate py310
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
Note 2: conda fails to install scanpy with the suggested command conda install -c bioconda scanpy; use conda install -c bioconda -c conda-forge scanpy instead.
Note 3: Check your matplotlib version: versions >3.7 are not compatible with scanpy. Use pip install 'matplotlib==3.6'.
Note 4: Also run the following to install the remaining dependencies: pip install transformers accelerate, pip install git+https://github.com/huggingface/accelerate, and pip install transformers==4.28.0.
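As a quick sanity check of the environment (a minimal sketch; the expected versions are the ones pinned in the notes above), you can confirm which versions are actually being imported:

import sys
import matplotlib
import torch
import transformers

print(sys.version)               # expect 3.10.x
print(torch.__version__)         # expect 1.12.1
print(transformers.__version__)  # expect 4.28.0
print(matplotlib.__version__)    # expect 3.6.x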
import os
GPU_NUMBER = [i for i in range(1)]  # please don't change this! Parallel training does not seem to be supported on hpc3.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
print(os.environ["CUDA_VISIBLE_DEVICES"])
os.environ["NCCL_DEBUG"] = "INFO"
os.environ['CUDA_LAUNCH_BLOCKING'] ='1'
os.environ["NCCL_P2P_DISABLE"] = '1'
0
If your Python version is not compatible, the device count printed by the following code will be 0.
import torch
print(torch.cuda.device_count())
print(torch.__version__)
1 2.0.1
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
Prepare training and evaluation datasets¶
# load cell type dataset (includes all tissues)
train_dataset=load_from_disk("../output/pbmc.dataset/")
The following code block is optional. Run it if you want to fine-tune your model on a subset of the data; otherwise skip it.
x = train_dataset.train_test_split(test_size=0.8, seed=42)
train_dataset = x["train"]  # keep 20% of the data for training
test_dataset = x["test"]    # hold out the remaining 80%
test_dataset.save_to_disk("./test_dataset.80.dataset")
# print("OK")
Saving the dataset (0/1 shards): 0%| | 0/134407 [00:00<?, ? examples/s]
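If you need the held-out split later (for example, for inference after fine-tuning), it can be reloaded from the path saved above. This is a minimal sketch, not part of the original workflow:

from datasets import load_from_disk

test_dataset = load_from_disk("./test_dataset.80.dataset")
print(len(test_dataset))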
The following code block is the dataset preprocessing step. If you want to train on your own dataset, complete the following 3 tasks first (a quick sanity check is sketched after this list):
1. Make sure your .loom data contains a column called 'organ_major'.
2. Make sure you have already converted the data to .dataset format with tokenizing_scRNAseq_data.ipynb.
3. Make sure you have changed every occurrence of the label column in the code block below (major.type in the original notebook; this copy uses condition) to your own preferred feature.
Then just run all of the code below to get your own model. It will be saved under ../models/.
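The sketch below is one way to do that sanity check (the column names are assumptions based on the notes above): it confirms the tokenized .dataset has the expected columns and that the label column matches what the code uses.

ds = load_from_disk("../output/pbmc.dataset/")
print(ds.column_names)  # should include "organ_major" and your label column

# if your label column is still called "major.type", rename it to the name
# the code below expects ("condition" in this copy)
if "major.type" in ds.column_names:
    ds = ds.rename_column("major.type", "condition")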
dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []
for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]
    print(organ)
    # filter datasets for given organ
    def if_organ(example):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=16)
    # per scDeepsort published method, drop cell types representing <0.5% of cells
    celltype_counter = Counter(trainset_organ["condition"])
    total_cells = sum(celltype_counter.values())
    cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005 * total_cells)]
    def if_not_rare_celltype(example):
        return example["condition"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)
    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("condition", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")
    # create dictionary of cell types : label ids
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]
    # change labels to numerical ids
    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)
    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0, round(len(labeled_trainset) * 0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset) * 0.8), len(labeled_trainset))])
    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)
    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]
A
trainset_dict = dict(zip(organ_list,dataset_list))
traintargetdict_dict = dict(zip(organ_list,target_dict_list))
evalset_dict = dict(zip(organ_list,evalset_list))
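Optionally, inspect what was built above: one entry per organ, and a dictionary mapping each cell-type string to an integer id.

for organ in organ_list:
    print(organ,
          len(trainset_dict[organ]), "train cells,",
          len(evalset_dict[organ]), "eval cells,",
          traintargetdict_dict[organ])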
Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance¶
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's functions
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
        'accuracy': acc,
        'macro_f1': macro_f1
    }
Please note that, as usual with deep learning models, we highly recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the "hyperparam_optimiz_for_disease_classifier" script for an example of how to tune hyperparameters for downstream applications.¶
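For reference, the sketch below shows one way such a search could be wired up with Trainer.hyperparameter_search and the optuna backend (pip install optuna). It is illustrative only and is not the repo's hyperparam_optimiz_for_disease_classifier script; the output path, search space, and trial count are assumptions.

example_organ = organ_list[0]

def model_init():
    # re-instantiate the pretrained model for every trial so weights are reset
    return BertForSequenceClassification.from_pretrained(
        "../geneformer-12L-30M/", num_labels=2)

def hp_space(trial):
    # illustrative search space; adjust the ranges to your compute budget
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 10),
        "warmup_steps": trial.suggest_int("warmup_steps", 100, 1000),
    }

hp_trainer = Trainer(
    model_init=model_init,
    args=TrainingArguments(output_dir="../models/hp_search",  # assumed path
                           evaluation_strategy="epoch",
                           save_strategy="no",
                           per_device_train_batch_size=5,
                           per_device_eval_batch_size=5),
    data_collator=DataCollatorForCellClassification(),
    train_dataset=trainset_dict[example_organ],
    eval_dataset=evalset_dict[example_organ],
    compute_metrics=compute_metrics,
)

best_run = hp_trainer.hyperparameter_search(direction="maximize",
                                            backend="optuna",
                                            hp_space=hp_space,
                                            n_trials=10)
print(best_run)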
# set model parameters
# max input size
max_input_size = 2 ** 11 # 2048
# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 16
# batch size for training and eval
geneformer_batch_size = 5  # for an RTX 3090, a batch size of 5 is the maximum
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamw"
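Note that freeze_layers is set above but is not applied anywhere in the training loop below (it only appears in the output directory name). If you do want to freeze the first N pretrained encoder layers, a minimal sketch (assuming the standard BertForSequenceClassification module layout; call it right after from_pretrained inside the loop) is:

def freeze_encoder_layers(model, n_layers=freeze_layers):
    # disable gradients for the first n_layers BERT encoder layers
    if n_layers > 0:
        for module in model.bert.encoder.layer[:n_layers]:
            for param in module.parameters():
                param.requires_grad = False
    return model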
# os.environ['CUDA_LAUNCH_BLOCKING'] ='1'
for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]
    # set logging steps
    logging_steps = round(len(organ_trainset) / geneformer_batch_size / 10)
    # reload pretrained model
    model = BertForSequenceClassification.from_pretrained("../geneformer-12L-30M/",
                                                          # num_labels=len(organ_label_dict.keys()),
                                                          num_labels=2,
                                                          output_attentions=False,
                                                          output_hidden_states=False).to("cuda")
    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"../models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")
    # make output directory
    subprocess.call(f'mkdir {output_dir}', shell=True)
    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }
    training_args_init = TrainingArguments(**training_args)
    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics
    )
    # train the cell type classifier
    trainer.train()
    predictions = trainer.predict(organ_evalset)
    with open(f"{output_dir}predictions.pickle", "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)
A
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../geneformer-12L-30M/ and are newly initialized: ['bert.pooler.dense.weight', 'classifier.weight', 'classifier.bias', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
mkdir: cannot create directory ‘../models/230907_geneformer_CellClassifier_A_L2048_B5_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/’: File exists
/mnt/disk/mwang/conda3/lib/python3.11/site-packages/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
Epoch | Training Loss | Validation Loss | Accuracy | Macro F1
---|---|---|---|---
1 | 0.478000 | 0.446698 | 0.802143 | 0.797032
2 | 0.360400 | 0.520274 | 0.831269 | 0.830752
3 | 0.335900 | 0.584502 | 0.841647 | 0.841416
4 | 0.240000 | 0.727501 | 0.847673 | 0.847564
5 | 0.128000 | 0.736757 | 0.873786 | 0.872867
6 | 0.110300 | 0.900706 | 0.873452 | 0.871618
7 | 0.057700 | 1.003406 | 0.876130 | 0.874294
8 | 0.012700 | 1.131091 | 0.878473 | 0.877133
9 | 0.011200 | 1.234100 | 0.876465 | 0.874432
10 | 0.008600 | 1.168336 | 0.878473 | 0.877268
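After the loop finishes, the fine-tuned checkpoint saved in output_dir can be reloaded to annotate cells. The sketch below is not part of the original notebook: it reuses the already-preprocessed organ_evalset and the label dictionary from the loop above, so it assumes those variables are still defined; a fresh dataset (such as the test_dataset.80.dataset saved earlier) would first need the same filtering, renaming, and label-encoding steps.

fine_tuned = BertForSequenceClassification.from_pretrained(output_dir).to("cuda")

annotator = Trainer(
    model=fine_tuned,
    args=TrainingArguments(output_dir=output_dir,
                           per_device_eval_batch_size=geneformer_batch_size),
    data_collator=DataCollatorForCellClassification(),
)
pred = annotator.predict(organ_evalset)
pred_ids = pred.predictions.argmax(-1)

# invert the cell type -> id dictionary built during preprocessing
id_to_name = {v: k for k, v in traintargetdict_dict[organ].items()}
predicted_celltypes = [id_to_name.get(int(i), "unknown") for i in pred_ids]
print(predicted_celltypes[:10])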
Unused variant (kept for reference)¶
# GPU_NUMBER = [i for i in range(2)] # please!!!!!!!! don't change this !!!!!!!! seems that parallel is not supported in hpc3.
# os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
# epochs=4
# for organ in organ_list:
# print(organ)
# organ_trainset = trainset_dict[organ]
# organ_evalset = evalset_dict[organ]
# organ_label_dict = traintargetdict_dict[organ]
# # set logging steps
# logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)
# # reload pretrained model
# model = BertForSequenceClassification.from_pretrained("../geneformer-12L-30M/",
# # num_labels=len(organ_label_dict.keys()),
# num_labels=2,
# output_attentions = False,
# output_hidden_states = False).to("cuda")
# # define output directory path
# current_date = datetime.datetime.now()
# datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
# output_dir = f"../models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
# # ensure not overwriting previously saved model
# saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
# if os.path.isfile(saved_model_test) == True:
# raise Exception("Model already saved to this directory.")
# # make output directory
# subprocess.call(f'mkdir {output_dir}', shell=True)
# # set training arguments
# training_args = {
# "learning_rate": max_lr,
# "do_train": True,
# "do_eval": True,
# "evaluation_strategy": "epoch",
# "save_strategy": "epoch",
# "logging_steps": logging_steps,
# "group_by_length": True,
# "length_column_name": "length",
# "disable_tqdm": False,
# "lr_scheduler_type": lr_schedule_fn,
# "warmup_steps": warmup_steps,
# "weight_decay": 0.001,
# "per_device_train_batch_size": geneformer_batch_size,
# "per_device_eval_batch_size": geneformer_batch_size,
# "num_train_epochs": epochs,
# "load_best_model_at_end": True,
# "output_dir": output_dir,
# }
# training_args_init = TrainingArguments(**training_args)
# # create the trainer
# trainer = Trainer(
# model=model,
# args=training_args_init,
# data_collator=DataCollatorForCellClassification(),
# train_dataset=organ_trainset,
# eval_dataset=organ_evalset,
# compute_metrics=compute_metrics
# )
# # train the cell type classifier
# trainer.train()
# predictions = trainer.predict(organ_evalset)
# with open(f"{output_dir}predictions.pickle", "wb") as fp:
# pickle.dump(predictions, fp)
# trainer.save_metrics("eval",predictions.metrics)
# trainer.save_model(output_dir)