Geneformer | Gene Classification

Gene Classification

The previous post covered cell classification with Geneformer; this one documents the gene classification workflow with Geneformer, for example predicting whether a gene is a dosage-sensitive TF.

First, download the data needed for gene classification:

cd Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/dataset.arrow
 
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/dataset_info.json

wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/state.json

The authors provide a set of cardiomyopathy scRNA-seq data containing cells from non-failing (nf), hypertrophic (hcm), and dilated (dcm) hearts, together with a gene list labeling transcription factors as dosage sensitive or insensitive. We fine-tune on these data and then predict whether a gene is a dosage-sensitive transcription factor.

Fine-tuning data: scRNA-seq data and gene labels;

Downstream task: predicting the dosage sensitivity of TFs.

Import Modules

import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
# imports
import datetime
import subprocess
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_from_disk
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve
from sklearn.model_selection import StratifiedKFold
import torch
from transformers import BertForTokenClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from tqdm.notebook import tqdm

from geneformer import DataCollatorForGeneClassification
from geneformer.pretrainer import token_dictionary
e:\miniconda3\envs\geneformer\lib\site-packages\loompy\bus_file.py: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details. (emitted three times, for twobit_to_dna, dna_to_twobit, and twobit_1hamming)

Load Gene Attribute Information

Read in the gene information table provided by the authors, which contains the Ensembl ID, gene name, and gene type of each gene. These attributes are then packed into three dictionaries (gene_id_type_dict, gene_name_id_dict, gene_id_name_dict).

# table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
gene_info = pd.read_csv("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/gene_info_table.csv", index_col=0)

# create dictionaries for corresponding attributes
gene_id_type_dict = dict(zip(gene_info["ensembl_id"],gene_info["gene_type"]))
gene_name_id_dict = dict(zip(gene_info["gene_name"],gene_info["ensembl_id"]))
gene_id_name_dict = {v: k for k,v in gene_name_id_dict.items()}

# first 5 key:value pairs
{k: gene_id_name_dict[k] for k in list(gene_id_name_dict)[:5]}
{'ENSG00000000003': 'TSPAN6',
 'ENSG00000000005': 'TNMD',
 'ENSG00000000419': 'DPM1',
 'ENSG00000000457': 'SCYL3',
 'ENSG00000000460': 'C1orf112'}

Load Training Data and Class Labels

Next, read in the datasets used for fine-tuning: the cardiomyopathy scRNA-seq data ("human_dcm_hcm_nf.dataset") and the gene list labeling TFs as dosage sensitive or insensitive ("dosage_sens_tf_labels.csv").

To process the dosage_sens_tf_labels read in here, the function prep_inputs converts the input gene IDs into token IDs and generates labels matching the lengths of genegroup1 and genegroup2 (group1 is labeled 0, group2 is labeled 1).

token_dictionary defines the mapping between Ensembl IDs and tokens.

# function for preparing targets and labels
def prep_inputs(genegroup1, genegroup2, id_type):
    if id_type == "gene_name":
        targets1 = [gene_name_id_dict[gene] for gene in genegroup1 if gene_name_id_dict.get(gene) in token_dictionary]
        targets2 = [gene_name_id_dict[gene] for gene in genegroup2 if gene_name_id_dict.get(gene) in token_dictionary]
    elif id_type == "ensembl_id":
        targets1 = [gene for gene in genegroup1 if gene in token_dictionary]
        targets2 = [gene for gene in genegroup2 if gene in token_dictionary]
            
    targets1_id = [token_dictionary[gene] for gene in targets1]
    targets2_id = [token_dictionary[gene] for gene in targets2]
    
    targets = np.array(targets1_id + targets2_id)
    labels = np.array([0]*len(targets1_id) + [1]*len(targets2_id))
    nsplits = min(5, min(len(targets1_id), len(targets2_id))-1)
    assert nsplits > 2
    print(f"# targets1: {len(targets1_id)}\n# targets2: {len(targets2_id)}\n# splits: {nsplits}")
    return targets, labels, nsplits
{k: token_dictionary[k] for k in list(token_dictionary)[:5]}
{'<pad>': 0,
 '<mask>': 1,
 'ENSG00000000003': 2,
 'ENSG00000000005': 3,
 'ENSG00000000419': 4}

Read in the dosage-sensitive TF list provided by the authors, which contains 122 dosage-sensitive TFs (label 0) and 368 insensitive TFs (label 1). prep_inputs converts the TF gene IDs into tokens and determines the number of splits (5) for the subsequent 5-fold cross-validation.

from collections import Counter

# preparing targets and labels for dosage sensitive vs insensitive TFs
dosage_tfs = pd.read_csv("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv", header=0)
sensitive = dosage_tfs["dosage_sensitive"].dropna()
insensitive = dosage_tfs["dosage_insensitive"].dropna()
targets, labels, nsplits = prep_inputs(sensitive, insensitive, "ensembl_id")
print(targets[0:5])
print(Counter(labels))
# targets1: 122
# targets2: 368
# splits: 5
[208 223 275 295 487]
Counter({1: 368, 0: 122})

Read in the cardiomyopathy scRNA-seq data provided by the authors for fine-tuning; it contains 579,159 cells across 21 cell types.

Three disease subgroups: 1. NF (non-failing), 2. HCM (hypertrophic cardiomyopathy), and 3. DCM (dilated cardiomyopathy).

After shuffling the cells, 50,000 cells are randomly sampled as the training set.

# load training dataset
train_dataset=load_from_disk("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset")
shuffled_train_dataset = train_dataset.shuffle(seed=42)
subsampled_train_dataset = shuffled_train_dataset.select([i for i in range(50_000)])

Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\disease_classification\human_dcm_hcm_nf.dataset\cache-54b519f110fa07f1.arrow
#import pandas as pd
print(train_dataset)
print("\nCelltype: ")
print(Counter(train_dataset['cell_type']))
print("\nSubgroups: ")
print(Counter(train_dataset['disease']))

print(subsampled_train_dataset)
Dataset({
    features: ['input_ids', 'length', 'cell_type', 'individual', 'age', 'sex', 'disease', 'lvef'],
    num_rows: 579159
})

Celltype: 
Counter({'Fibroblast1': 141725, 'Cardiomyocyte1': 136167, 'Endothelial1': 78375, 'Pericyte1': 67600, 'Macrophage': 54714, 'Endothelial2': 18394, 'VSMC': 18137, 'Lymphocyte': 16246, 'Endocardial': 6489, 'Cardiomyocyte2': 5445, 'Adipocyte': 5298, 'ActivatedFibroblast': 5210, 'LymphaticEndothelial': 5181, 'Endothelial3': 4538, 'MastCell': 4465, 'Neuronal': 4292, 'Cardiomyocyte3': 3350, 'Pericyte2': 1704, 'ProliferatingMacrophage': 1276, 'Fibroblast2': 284, 'Epicardial': 269})

Subgroups: 
Counter({'hcm': 230652, 'nf': 182317, 'dcm': 166190})

Define Functions for Training and Cross-Validating Classifier

Geneformer converts each cell's gene expression into a rank value encoding, and the length of this encoding differs between cells, whereas the downstream model requires input tensors of equal length. The function preprocess_classifier_batch therefore pads inputs of different lengths with the <pad> token up to a common length.

classifier_predict splits the input dataset into batches of size forward_batch_size and runs the fine-tuned model to predict whether each gene is dosage sensitive or insensitive. It then computes the corresponding evaluation metrics (e.g., FPR, TPR) from the predicted and true labels.

Note: if you train on a GPU with limited memory, lower forward_batch_size accordingly; here I set forward_batch_size=20.

def preprocess_classifier_batch(cell_batch, max_len):
    if max_len == None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])
    def pad_label_example(example):
        example["labels"] = np.pad(example["labels"], 
                                   (0, max_len-len(example["input_ids"])), 
                                   mode='constant', constant_values=-100)
        example["input_ids"] = np.pad(example["input_ids"], 
                                      (0, max_len-len(example["input_ids"])), 
                                      mode='constant', constant_values=token_dictionary.get("<pad>"))
        example["attention_mask"] = (example["input_ids"] != token_dictionary.get("<pad>")).astype(int)
        return example
    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch

# forward batch size is batch size for model inference (e.g. 200)
def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
    predict_logits = []
    predict_labels = []
    model.eval()
    
    # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible
    
    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
    
    for i in range(0, evalset_len, forward_batch_size):
        max_range = min(i+forward_batch_size, evalset_len)
        batch_evalset = evalset.select([i for i in range(i, max_range)])
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        padded_batch.set_format(type="torch")
        
        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids = input_data_batch.to("cuda"), 
                attention_mask = attn_msk_batch.to("cuda"), 
                labels = label_batch.to("cuda"), 
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]
            
    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    # probability of class 1
    y_score = [py_softmax(item)[1] for item in logits_list]
    conf_mat = confusion_matrix(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_score)
    # plot roc_curve for this split
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC')
    plt.show()
    # interpolate to graph
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    return fpr, tpr, interp_tpr, conf_mat 

def vote(logit_pair):
    a, b = logit_pair
    if a > b:
        return 0
    elif b > a:
        return 1
    elif a == b:
        return "tie"
    
def py_softmax(vector):
    e = np.exp(vector)
    return e / e.sum()
    
# get cross-validated mean and sd metrics
def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
    wts = [count/sum(all_tpr_wt) for count in all_tpr_wt]
    print(wts)
    all_weighted_tpr = [a*b for a,b in zip(all_tpr, wts)]
    mean_tpr = np.sum(all_weighted_tpr, axis=0)
    mean_tpr[-1] = 1.0
    all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]
    roc_auc = np.sum(all_weighted_roc_auc)
    roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))
    return mean_tpr, roc_auc, roc_auc_sd

# Function to find the largest number smaller
# than or equal to N that is divisible by k
def find_largest_div(N, K):
    rem = N % K
    if(rem == 0):
        return N
    else:
        return N - rem

The function cross_validate wraps the data splitting (80% training set, 10% evaluation set, 10% hold-out evaluation set), training, and prediction steps.

The part that loads the pretrained model needs to point to your local copy of Geneformer or to the Hugging Face repository name ("ctheodoris/Geneformer"):

        # load model
        model = BertForTokenClassification.from_pretrained(
            "D:/jupyterNote/Geneformer", # change to local path to the model
            num_labels=2,
            output_attentions = False,
            output_hidden_states = False
        )

Next, this block fine-tunes the model with the defined training arguments:

        # add output directory to training args and initiate
        training_args["output_dir"] = ksplit_output_dir
        training_args_init = TrainingArguments(**training_args)
        
        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=DataCollatorForGeneClassification(),
            train_dataset=trainset_labeled,
            eval_dataset=evalset_train_labeled
        )

        # train the gene classifier
        trainer.train()

This block runs prediction and evaluation with the fine-tuned model on the out-of-sample dataset (evalset_oos_labeled).

Remember to adjust forward_batch_size here to fit your hardware.

        # evaluate model
        fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 20, mean_fpr) # forward_batch_size: 20
        
        # append to tpr and roc lists
        confusion = confusion + conf_mat
        all_tpr.append(interp_tpr)
        all_roc_auc.append(auc(fpr, tpr))
        # append number of eval examples by which to weight tpr in averaged graphs
        all_tpr_wt.append(len(tpr))
  
# cross-validate gene classifier
def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc):
    # check if output directory already written to
    # ensure not overwriting previously saved model
    model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
    if os.path.isfile(model_dir_test) == True:
        raise Exception("Model already saved to this directory.")
    
    # initiate eval metrics to return
    num_classes = len(set(labels))
    mean_fpr = np.linspace(0, 1, 100)
    all_tpr = []
    all_roc_auc = []
    all_tpr_wt = []
    label_dicts = []
    confusion = np.zeros((num_classes,num_classes))
    
    # set up cross-validation splits
    skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
    # train and evaluate
    iteration_num = 0
    for train_index, eval_index in tqdm(skf.split(targets, labels)):
        if len(labels) > 500:
            print("early stopping activated due to large # of training examples")
            nsplits = 3
            if iteration_num == 3:
                break
        print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")
        # generate cross-validation splits
        targets_train, targets_eval = targets[train_index], targets[eval_index]
        labels_train, labels_eval = labels[train_index], labels[eval_index]
        label_dict_train = dict(zip(targets_train, labels_train))
        label_dict_eval = dict(zip(targets_eval, labels_eval))
        label_dicts += (iteration_num, targets_train, targets_eval, labels_train, labels_eval)
        
        # function to filter by whether contains train or eval labels
        def if_contains_train_label(example):
            a = label_dict_train.keys()
            b = example['input_ids']
            return not set(a).isdisjoint(b)

        def if_contains_eval_label(example):
            a = label_dict_eval.keys()
            b = example['input_ids']
            return not set(a).isdisjoint(b)
        
        # filter dataset for examples containing classes for this split
        print(f"Filtering training data")
        trainset = data.filter(if_contains_train_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
        print(f"Filtering evalation data")
        evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")

        # minimize to smaller training sample
        training_size = min(subsample_size, len(trainset))
        trainset_min = trainset.select([i for i in range(training_size)])
        eval_size = min(training_size, len(evalset))
        half_training_size = round(eval_size/2)
        evalset_train_min = evalset.select([i for i in range(half_training_size)])
        evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
        
        # label conversion functions
        def generate_train_labels(example):
            example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        def generate_eval_labels(example):
            example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
            return example
        
        # label datasets 
        print(f"Labeling training data")
        trainset_labeled = trainset_min.map(generate_train_labels)
        print(f"Labeling evaluation data")
        evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
        print(f"Labeling evaluation OOS data")
        evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
        
        # create output directories

        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
        ksplit_model_dir = os.path.join(ksplit_output_dir, "models/") 
        
        # ensure not overwriting previously saved model
        model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
        if os.path.isfile(model_output_file) == True:
            raise Exception("Model already saved to this directory.")

        # make training and model output directories
        subprocess.call(f'mkdir {ksplit_output_dir}', shell=True)
        subprocess.call(f'mkdir {ksplit_model_dir}', shell=True)
        
        # load model
        model = BertForTokenClassification.from_pretrained(
            "D:/jupyterNote/Geneformer", # change as the path to the model
            num_labels=2,
            output_attentions = False,
            output_hidden_states = False
        )
        if freeze_layers is not None:
            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False
                
        model = model.to("cuda:0")
        
        # add output directory to training args and initiate
        training_args["output_dir"] = ksplit_output_dir
        training_args_init = TrainingArguments(**training_args)
        
        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=DataCollatorForGeneClassification(),
            train_dataset=trainset_labeled,
            eval_dataset=evalset_train_labeled
        )

        # train the gene classifier
        trainer.train()
        
        # save model
        trainer.save_model(ksplit_model_dir)
        
        # evaluate model
        fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 20, mean_fpr) # forward_batch_size: 20
        
        # append to tpr and roc lists
        confusion = confusion + conf_mat
        all_tpr.append(interp_tpr)
        all_roc_auc.append(auc(fpr, tpr))
        # append number of eval examples by which to weight tpr in averaged graphs
        all_tpr_wt.append(len(tpr))
        
        iteration_num = iteration_num + 1
        
    # get overall metrics for cross-validation
    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
    return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts

Define Functions for Plotting Results

Define plot_ROC for plotting ROC curves and plot_confusion_matrix for plotting confusion matrices.

# plot ROC curve
def plot_ROC(bundled_data, title):
    plt.figure()
    lw = 2
    for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
        plt.plot(mean_fpr, mean_tpr, color=color,
                 lw=lw, label="{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd))
    plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc="lower right")
    plt.show()
    
# plot confusion matrix
def plot_confusion_matrix(classes_list, conf_mat, title):
    display_labels = []
    i = 0
    for label in classes_list:
        display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:,i]))]
        i = i + 1
    display = ConfusionMatrixDisplay(confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"), 
                                     display_labels=display_labels)
    display.plot(cmap="Blues",values_format=".2g")
    plt.title(title)

Fine-Tune With Gene Classification Learning Objective and Quantify Predictive Performance

Define the model fine-tuning parameters; as before, adjust num_gpus, num_proc, and geneformer_batch_size to your hardware. The remaining hyperparameters keep the preset values, although in principle they could be tuned further.

Regarding the choice of freeze_layers, the authors note that the more similar the downstream task is to the pretraining objective, the more layers can be frozen, i.e., the more of the pretrained weights are kept fixed:

Generally, in our experience, applications that are more relevant to the pretraining objective benefit from more layers being frozen to prevent overfitting to the limited task-specific data, whereas applications that are more distant from the pretraining objective benefit from fine-tuning of more layers to optimize performance on the new task.
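As a quick sanity check (my own addition, not part of the original notebook), you can count how many parameters remain trainable after the first freeze_layers encoder layers are frozen, assuming a model loaded and frozen as in cross_validate below:

# report trainable vs. total parameter counts for a (partially frozen) model
def count_trainable(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} / total: {total:,} ({trainable / total:.1%})")

# usage, e.g. right after the freeze loop inside cross_validate:
# count_trainable(model)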

# set model parameters
# max input size
max_input_size = 2 ** 11  # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 4
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 6
# batch size for training and eval
geneformer_batch_size = 2
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 1
# optimizer
optimizer = "adamw"
# set training arguments
subsample_size = 10_000
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "evaluation_strategy": "no",
    "save_strategy": "epoch",
    "logging_steps": 100,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": 0.001,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
}
# define output directory path
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
training_output_dir = f"D:\\jupyterNote\\Geneformer\\examples\\gene_class_test\\{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}\\"

# ensure not overwriting previously saved model
ksplit_model_test = os.path.join(training_output_dir, "ksplit0/models/pytorch_model.bin")
if os.path.isfile(ksplit_model_test) == True:
    raise Exception("Model already saved to this directory.")

# make output directory
subprocess.call(f'mkdir {training_output_dir}', shell=True)
0
# clear GPU memory after pytorch training 
import torch
torch.cuda.empty_cache()
# this did not work for me
#!set 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512' # limit each allocation split to 512 MB

We fine-tune on subsampled_train_dataset, which contains 50,000 cells; each CV split trains on at most 10,000 cells (subsample_size), for a total of 5 splits (nsplits=5). Likewise, the input targets and labels are split in each fold into an 80% training set (n = 392) and a 20% evaluation set (n = 98); a stratified split is used, so each fold preserves the class proportions (and the training sets of different folds therefore overlap).

These split targets and labels are stored in label_dicts, in groups of five elements: iteration_num, targets_train, targets_eval, labels_train, labels_eval.
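As a small illustration (my own helper, assuming the flat layout described above), the entries of a given split can be pulled out like this:

# each split appends 5 consecutive elements to label_dicts
def get_split_entries(label_dicts, split_num):
    start = split_num * 5
    # iteration_num, targets_train, targets_eval, labels_train, labels_eval
    return label_dicts[start:start + 5]

# e.g. rebuild the evaluation label dictionary of split 0 (used again further below)
# _, _, targets_eval0, _, labels_eval0 = get_split_entries(label_dicts, 0)
# label_dict_eval0 = dict(zip(targets_eval0, labels_eval0))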

cross_validate prints training information for each split, including the training loss, learning_rate, epoch, and the ROC curve.

# cross-validate gene classifier
all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts \
    = cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1)
0it [00:00, ?it/s]


Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\disease_classification\human_dcm_hcm_nf.dataset\cache-509acb05b140c462.arrow


****** Crossval split: 0/4 ******

Filtering training data
Filtered 0%; 49994 remain

Filtering evaluation data
Split 0 training info...
****** Crossval split: 1/4 ******

Filtering training data
Filtered 0%; 49992 remain

Filtering evaluation data
Filtered 4%; 47913 remain

Labeling training data
Split 1 training info...
****** Crossval split: 2/4 ******

Filtering training data
Filtered 0%; 49993 remain

Filtering evaluation data
Filtered 4%; 47886 remain

Labeling training data
Split 2 training info...
****** Crossval split: 3/4 ******

Filtering training data
Filtered 0%; 49991 remain

Filtering evaluation data
Filtered 4%; 48025 remain

Labeling training data
Split 3 training info...
****** Crossval split: 4/4 ******

Filtering training data
Filtered 0%; 49977 remain

Filtering evaluation data
Filtered 2%; 48951 remain

Labeling training data
Split 4 training info...
[0.25172310458495656, 0.18719408650484468, 0.1628708420737189, 0.2369393666966337, 0.16127260013984618]
# bundle data for plotting
bundled_data = []
bundled_data += [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")]
# plot ROC curve
plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')
# plot confusion matrix
classes_list = ["Dosage Sensitive", "Dosage Insensitive"]
plot_confusion_matrix(classes_list, confusion, "Geneformer")

The above are the 5-fold CV results. Next, we take the model fine-tuned on 10,000 cells from one split and use it for gene classification on its corresponding out-of-sample evaluation set.

We first reload the fine-tuned model of the first split and move it to the GPU. The model has out_features=2, i.e., it performs binary classification.

# reload fine-tuned model
ft_model = BertForTokenClassification.from_pretrained("gene_class_test/230724_geneformer_GeneClassifier_dosageTF_L2048_B2_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/ksplit0/models/")

ft_model.to('cuda:0')
print(ft_model)
BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=256, out_features=512, bias=True)
            (intermediate_act_fn): ReLU()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=256, bias=True)
            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.02, inplace=False)
          )
        )
      )
    )
  )
  (dropout): Dropout(p=0.02, inplace=False)
  (classifier): Linear(in_features=256, out_features=2, bias=True)
)

We take the evaluation targets and labels corresponding to the first split and extract the matching evaluation set (evalset_oos_labeled).

# out-of-sample evaluation set
# for set 0
label_dict_eval = dict(zip(label_dicts[2], label_dicts[4]))

def if_contains_eval_label(example, label_dict):
    a = label_dict.keys()
    b = example['input_ids']
    return not set(a).isdisjoint(b)

evalset0 = subsampled_train_dataset.filter(if_contains_eval_label, num_proc=2, fn_kwargs={"label_dict": label_dict_eval})
eval_size0 = min(10000, len(evalset0))
half_training_size = round(eval_size0/2)
evalset_oos_min = evalset0.select([i for i in range(half_training_size, eval_size0)])

def generate_eval_labels(example, label_dict):
    example["labels"] = [label_dict.get(token_id, -100) for token_id in example["input_ids"]]
    return example

evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels, fn_kwargs={"label_dict": label_dict_eval})
evalset_oos_labeled
Dataset({
    features: ['input_ids', 'length', 'cell_type', 'individual', 'age', 'sex', 'disease', 'lvef', 'labels'],
    num_rows: 5000
})

Here we modify the original classifier_predict so that it returns the fine-tuned model's predicted labels (y_pred), the true labels (y_true), the model's prediction values (logits_list), the cell IDs (cell_id), and the TF tokens of each cell (token_id_dict).

# return prediction results
def get_classifier_predict(model, evalset, forward_batch_size):
    predict_logits = []
    predict_labels = []
    model.eval()
    cell_id = []
    token_id_dict = {}
    
    # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible
    
    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
    
    for i in range(0, evalset_len, forward_batch_size):
        max_range = min(i+forward_batch_size, evalset_len)
        batch_evalset = evalset.select([i for i in range(i, max_range)])
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        padded_batch.set_format(type="torch")
        
        # cell id
        cell_id += [i for i in range(i, max_range)]
        # store token id by cell j
        for j, tokens in enumerate(batch_evalset['input_ids']):
            cell_idx = range(i, max_range)[j]
            token_id_dict[cell_idx] = [tki for k, tki in enumerate(tokens) if batch_evalset['labels'][j][k] > -1]
        
        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids = input_data_batch.to("cuda"), 
                attention_mask = attn_msk_batch.to("cuda"), 
                labels = label_batch.to("cuda"), 
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]
            
    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    return y_pred, y_true, logits_list, cell_id, token_id_dict 
eval_pred, eval_label, eval_logits, cell_id, token_id = get_classifier_predict(model=ft_model, evalset=evalset_oos_labeled, forward_batch_size=20)
Model prediction info...

The model outputs a prediction value (logit) for each of the two classes, and a gene's label is assigned to the class with the larger value. A prediction is made here for every labeled TF occurrence in every cell (n = 27,939).

print(Counter(eval_pred))
print(Counter(eval_label))

print(eval_logits[0:3])
print(eval_pred[0:3])
Counter({1: 16492, 0: 11447})
Counter({0: 14673, 1: 13266})
[[4.6540117263793945, -4.643155574798584], [5.055752277374268, -4.894111156463623], [0.701909065246582, -0.6132677793502808]]
[0, 0, 0]
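For instance, taking the first logit pair printed above, the decision is simply an argmax; this is my own quick check, reusing the py_softmax and vote helpers defined earlier:

import numpy as np

logit_pair = eval_logits[0]               # e.g. [4.654..., -4.643...]
prob_class1 = py_softmax(logit_pair)[1]   # probability of class 1 (dosage insensitive)
pred = int(np.argmax(logit_pair))         # same as vote(logit_pair) when there is no tie
print(pred, round(prob_class1, 4))        # -> 0 0.0001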

Next, we count how often each transcription factor token occurs.

# # numbers of tfs (genes with 0/1 label) in out-of-sample evaluation set
# tf_num = [len([v for v in i if v >= 0]) for i in evalset_oos_labeled['labels']]
# sum(tf_num)

# frequencies of tokens
token_freq = Counter()

for tks in token_id.values():
    token_freq.update(tks)

token_freq
Counter({1636: 2636,
         9061: 2755,
         6754: 475,
         16718: 204,
         275: 1445,
         15866: 600,
         5084: 805,
         3361: 272,
         2410: 108,
         1757: 550,
         18597: 82,
         10422: 305,
         14481: 197,
         8218: 766,
         16619: 138,
         4071: 434,
         6931: 1052,
         14023: 468,
         7445: 699,
         4445: 157,
         17672: 983,
         3982: 547,
         5944: 552,
         5357: 359,
         20144: 237,
         6257: 137,
         6456: 185,
         16597: 437,
         2774: 216,
         15781: 553,
         20018: 386,
         23967: 427,
         21561: 218,
         12006: 116,
         20989: 339,
         15753: 199,
         487: 387,
         16016: 530,
         998: 496,
         8972: 382,
         6492: 269,
         14410: 180,
         14286: 228,
         12961: 228,
         8725: 26,
         2707: 82,
         17085: 262,
         15375: 72,
         13606: 313,
         10804: 317,
         12959: 527,
         12435: 202,
         16713: 359,
         12674: 184,
         20959: 88,
         16535: 348,
         21035: 131,
         11880: 34,
         23100: 347,
         21079: 114,
         20581: 284,
         15553: 249,
         14677: 63,
         954: 171,
         17147: 47,
         12995: 51,
         20962: 74,
         12165: 46,
         17092: 66,
         15717: 54,
         9024: 118,
         16555: 67,
         7705: 78,
         13722: 44,
         18778: 100,
         9831: 41,
         5789: 40,
         14124: 59,
         13954: 31,
         10534: 50,
         16425: 6,
         20787: 3,
         9367: 44,
         14578: 1,
         15180: 1,
         12243: 4,
         11443: 1,
         13066: 1})

Here we spot-check two genes to see whether their predicted classes are correct. Gene (token) 9061 is predicted correctly: all of its occurrences are predicted as label 1, matching the true label (dosage insensitive). Gene (token) 16425 is predicted incorrectly: its predicted label does not match the true label.

# append all tokens into one list
token_id_list = [tk for tks in token_id.values() for tk in tks]

# successful prediction
# get predictions of gene (token = 9061) at every position where it occurs
target_pred1 = [eval_pred[i] for i, tk in enumerate(token_id_list) if tk == 9061]
print("Predicted label of gene 9061: ")
print(Counter(target_pred1))
target_label1 = [eval_label[i] for i, tk in enumerate(token_id_list) if tk == 9061]
print("True label of gene 9061: ")
print(Counter(target_label1))

# failed prediction
# get predictions of gene (token = 16425)
target_pred2 = [eval_pred[i] for i, tk in enumerate(token_id_list) if tk == 16425]
print("Predicted label of gene 16425: ")
print(Counter(target_pred2))
target_label2 = [eval_label[i] for i, tk in enumerate(token_id_list) if tk == 16425]
print("True label of gene 16425: ")
print(Counter(target_label2))
Predicted label of gene 9061: 
Counter({1: 2755})
True label of gene 9061: 
Counter({1: 2755})
Predicted label of gene 16425: 
Counter({0: 6})
True label of gene 16425: 
Counter({1: 6})
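More generally, the per-position predictions can be aggregated for every TF token and the majority vote compared against the true label; this is my own sketch built on token_id_list, eval_pred, and eval_label from above:

from collections import Counter, defaultdict

# collect per-gene prediction counts and the (constant) true label of each gene
per_gene_pred = defaultdict(Counter)
per_gene_true = {}
for pos, tk in enumerate(token_id_list):
    per_gene_pred[tk][eval_pred[pos]] += 1
    per_gene_true[tk] = eval_label[pos]

# fraction of genes whose majority-vote prediction matches the true label
correct = sum(
    per_gene_pred[tk].most_common(1)[0][0] == per_gene_true[tk]
    for tk in per_gene_pred
)
print(f"{correct}/{len(per_gene_pred)} genes classified correctly by majority vote")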

Summary

For gene-classification fine-tuning, we need to (a condensed sketch follows this list):

  1. Obtain a fine-tuning dataset together with gene label information, e.g., whether a TF is dosage sensitive, whether a gene is a drug target, and so on;

    Regarding dataset size, among the examples the authors provide, the smallest case uses 884 cells, while the other downstream tasks all use more than 10k cells.

  2. Load the pretrained model via BertForTokenClassification and set num_labels to the number of classes;

  3. Train on the fine-tuning dataset, with the final task-specific output layer added on top, and evaluate the predictive performance of the fine-tuned model;

  4. Apply the fine-tuned model to new datasets for prediction.
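Putting these steps together, here is a condensed sketch (paths, output directory, and hyperparameters are placeholders; the full cross-validated version is the cross_validate function above):

from transformers import BertForTokenClassification, Trainer, TrainingArguments
from geneformer import DataCollatorForGeneClassification

# 1. labeled fine-tuning data: trainset_labeled prepared as above (token-level labels, -100 for unlabeled genes)
# 2. load the pretrained model with a token-classification head for 2 classes
model = BertForTokenClassification.from_pretrained(
    "ctheodoris/Geneformer",   # or a local path to the pretrained model
    num_labels=2,
)
# 3. fine-tune with the gene-classification data collator
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="gene_classifier_out", num_train_epochs=1),
    data_collator=DataCollatorForGeneClassification(),
    train_dataset=trainset_labeled,
)
trainer.train()
# 4. apply the fine-tuned model to new data, e.g. with classifier_predict() defined above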

In addition, the authors have recently uploaded a model fine-tuned on the cardiomyopathy single-cell data (https://huggingface.co/ctheodoris/Geneformer/tree/main/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224). You can also download that model and use it directly.
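If you want to try that checkpoint, it should load directly from the Hub subfolder. Note that it is a cell classifier (disease-subtype prediction), not a gene classifier, so the snippet below is my assumption about how it would be loaded rather than part of the original workflow:

from transformers import BertForSequenceClassification

# load the authors' cardiomyopathy cell classifier from its subfolder on the Hub
cell_clf = BertForSequenceClassification.from_pretrained(
    "ctheodoris/Geneformer",
    subfolder="fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224",
)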

Ref:

https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset

Transfer learning enables predictions in network biology: https://doi.org/10.1038/s41586-023-06139-9
