import scanpy as sc
import pandas as pd
import matplotlib.pyplot as pl
import seaborn as sns
import numpy as np
# Module-wide matplotlib defaults applied to every figure this module creates.
fontsize = 15
# Every text-size setting shares the same font size ...
params = dict.fromkeys(
    (
        'legend.fontsize',
        'axes.labelsize',
        'axes.titlesize',
        'xtick.labelsize',
        'ytick.labelsize',
    ),
    fontsize,
)
# ... while figure geometry is set explicitly.
params.update({'figure.figsize': (8, 8), 'figure.dpi': 150})
pl.rcParams.update(params)
# Alternative seaborn-based setup, currently disabled:
# sns.set(style='white', rc={'figure.figsize':(5,5), 'figure.dpi':150})
def run_embedding(adata, path="./", n_neighbors=15, init=None, method='umap', resolution=1.0, min_dist=0.5):
    """\
    Build a kNN graph on the VIPCCA embedding, Louvain-cluster it and compute 2-D embeddings.

    Parameters
    ----------
    adata: AnnData
        Annotated data; must contain ``adata.obsm['X_vipcca']``.
    path: string, optional (default: './')
        Unused; kept for backward compatibility.
    n_neighbors: int, optional (default: 15)
        Number of neighbors used to build the kNN graph.
    init: optional (default: None)
        Unused; kept for backward compatibility.
    method: {'umap', 'tsne', 'all'}, optional (default: 'umap')
        Which 2-D embedding(s) to compute.
    resolution: float, optional (default: 1.0)
        Louvain clustering resolution.
    min_dist: float, optional (default: 0.5)
        ``min_dist`` passed to UMAP.

    Results are written onto ``adata`` by the scanpy tools; nothing is returned.
    """
    # n_neighbors used to be accepted but never forwarded, so it was silently ignored.
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep='X_vipcca')
    sc.tl.louvain(adata, directed=False, resolution=resolution)
    if method in ("umap", "all"):
        sc.tl.umap(adata, min_dist=min_dist)
    if method in ("tsne", "all"):
        # t-SNE runs directly on the VIPCCA representation with 20 worker threads.
        sc.tl.tsne(adata, use_rep="X_vipcca", n_jobs=20)
def plotEmbedding(adata,path,group_by="batch",ncol=5, method="umap",legend_loc=None,frameon=True,legend_fontsize=6,title=""):
filename=path+"2dplot_"+group_by+"_"+method+".png"
if method in ("umap","all"):
sc.pl.umap(adata,frameon =frameon,color=group_by,show=False,use_raw=False,legend_loc=legend_loc,legend_fontsize=legend_fontsize)
pl.title(title)
pl.savefig(filename)
pl.close()
if method in ("tsne","all"):
sc.pl.tsne(adata,frameon =frameon,color=group_by,show=False,legend_fontsize=legend_fontsize)
pl.title("")
pl.legend(loc=3,fontsize=legend_fontsize,mode="expand",bbox_to_anchor=(0.0, 1.01, 1, 0.2),ncol=ncol)
pl.savefig(filename)
pl.close()
if method in ("location","all"):
sc.pl.scatter(adata, x='xcoord',y='ycoord',frameon =frameon,color=group_by,show=False,use_raw=False,legend_loc="right margin",legend_fontsize=legend_fontsize)
pl.title("")
pl.savefig(filename)
pl.close()
def plotDEG(adata,path,group_by="louvain",method="wilcoxon"):
current_path=path+"deg_"+group_by+"_"+method+"_"
sc.tl.pca(adata)
sc.tl.dendrogram(adata,groupby=group_by,use_rep ="X_pca")
sc.tl.rank_genes_groups(adata,groupby=group_by,use_raw=False,n_genes=100,method=method,rankby_abs=False,corr_method="benjamini-hochberg")
sc.tl.filter_rank_genes_groups(adata, min_fold_change=3)
sc.pl.rank_genes_groups(adata,show=False)
pl.savefig(current_path+"rank_genes_groups.png")
sc.pl.rank_genes_groups(adata, key='rank_genes_groups_filtered',show=False)
pl.savefig(current_path+"rank_genes_groups_filtered.png")
sc.pl.rank_genes_groups_dotplot(adata, key='rank_genes_groups_filtered')
pl.savefig(current_path+"rank_genes_groups_filtered_dotplot.png")
top_ranked_genes=pd.DataFrame(adata.uns['rank_genes_groups']['names'][range(1)])
top_ranked_genes_index = pd.Index(top_ranked_genes.values.flatten()).drop_duplicates(keep='first')
sc.pl.stacked_violin(adata,top_ranked_genes_index,groupby=group_by,use_raw=False,show=False)
pl.savefig(current_path+"stacked_violin.png")
sc.pl.heatmap(adata,top_ranked_genes_index,groupby=group_by,use_raw=False,show=False,swap_axes=False)
pl.savefig(current_path+"headmap.png")
pl.close()
def plotDEG2(adata,path,key_batch="batch",key_celltype="celltype",method="wilcoxon", mode=None):
celltypes=adata.obs[key_celltype].cat.categories.values
# sc.pp.scale(adata)
for ct in celltypes:
if np.sum(adata.obs[key_celltype]==ct)<10:
continue
adata_sub = adata[adata.obs[key_celltype]==ct]
sc.tl.rank_genes_groups(adata_sub, groupby=key_batch, use_raw=False, n_genes=adata.shape[1], rankby_abs=False)
catname=adata_sub.obs[key_batch].cat.categories.values[0]
pvals_adj=adata_sub.uns['rank_genes_groups']['pvals_adj'][catname]
logfc=adata_sub.uns['rank_genes_groups']['logfoldchanges'][catname]
t=np.isnan(logfc)
logfc=logfc[~t]
pvals_adj=pvals_adj[~t]
# f=np.absolute(logfc)<10
# logfc=logfc[f]
# pvals_adj=pvals_adj[f]
indicator_neg=np.logical_and(logfc<0, pvals_adj<1e-50)
indicator_pos=np.logical_and(logfc>0, pvals_adj<1e-50)
ndeg=np.sum(indicator_neg)
pdeg=np.sum(indicator_pos)
cv=np.repeat('k', len(logfc))
cv[indicator_pos]='r'
cv[indicator_neg]='b'
pl.scatter(logfc, -np.log10(pvals_adj+1e-300), c=cv, s=1)
pl.axvline(x=0, linewidth=0.5, color='c')
from adjustText import adjust_text
text1=pl.text(-7.5, 305, '%d significant genes'%np.int(ndeg))
text2=pl.text(2.5, 305, '%d significant genes'%np.int(pdeg))
adjust_text([text1,text2])
if mode is not None:
tt="{} ({})".format(mode,ct)
pl.title(tt)
pl.xlabel(r'$log_2(FC)\ (ctr/stim)$')
pl.ylabel(r'$-log_{10}(FDR\ adjusted\ p-value)$')
pl.savefig(path+"dge_%s.png"%ct.replace('+','p'))
pl.close()
def runGeoSketch(adata, N=10000, use_rep="X_pca"):
    """\
    Subsample ``adata`` to ``N`` cells using geometric sketching.

    PCA is (re)computed on ``adata`` first. The sketch is then drawn without
    replacement from ``adata.obsm[use_rep]``. The names of the selected cells are
    stored in ``adata.uns['geosketch']``, and ``adata`` indexed by those names is
    returned.
    """
    from geosketch import gs

    sc.tl.pca(adata)
    picked = gs(adata.obsm[use_rep], N, replace=False)
    picked_names = adata.obs.index[picked]
    adata.uns['geosketch'] = picked_names
    return adata[picked_names]
def plotQQdeg(adata,path,groupby="batch",method="wilcoxon"):
sc.tl.rank_genes_groups(adata, groupby=groupby, method=method,n_genes=adata.shape[1],rankby_abs=True,use_raw=False)
result = adata.uns['rank_genes_groups']
groups = result['names'].dtype.names
df = pd.DataFrame({group + '_' + key[:1]: result[key][group] for group in groups for key in ['names', 'pvals_adj']})
df.to_csv(path+"markers_"+groupby+"_"+method+".csv")
pvals=adata.uns['rank_genes_groups']['pvals_adj']['0']+1e-260
def plotQQdeg2(adata,path,groupby="batch",method="wilcoxon"):
for c in adata.obs.cell_type.cat.categories.values:
adatatemp= adata[adata.obs.cell_type==c,:]
patht=path+"pvals"+c.replace("/","_")
plotQQdeg(adatatemp,patht)
def plotPrediction(err,result_path):
err=err[err<1000]
x=range(len(err))
pl.scatter(x,err,c='r',s=1)
pl.savefig(result_path+"square_error.png")
pl.close()
[docs]def plotCorrelation(y,y_pred, save=True, result_path='./', show=True, rnum=1e4, lim=20):
"""\
Plot correlation between original data and corrected data
Parameters
----------
y: matrix or csr_matrix
The original data matrix.
y_pred: matrix or csr_matrix
The data matrix integrated by vipcca.
save: bool, optional (default: True)
If True, save the figure into result_path.
result_path: string, optional (default: './')
The path for saving the figure.
show: bool, optional (default: True)
If True, show the figure.
rnum: double, optional (default: 1e4)
The number of points you want to sample randomly in the matrix.
lim: int, optional (default: 20)
the right parameter of matplotlib.pyplot.xlim(left, right)
"""
from scipy.sparse import csr_matrix
if (isinstance(y, csr_matrix)):
y = y.toarray()
rx = np.random.choice(y.shape[0], np.int(rnum), replace=True)
ry = np.random.choice(y.shape[1], np.int(rnum), replace=True)
pl.rcParams['figure.figsize'] = (8, 7)
pl.scatter(y[rx, ry], y_pred[rx, ry], c='r', s=1)
pl.xlim(-1, lim)
pl.ylim(-1, lim)
pl.xlabel('uncorrected_x')
pl.ylabel('corrected_x')
if show:
pl.show()
if save:
pl.savefig(result_path+"correlation.png")