-
Notifications
You must be signed in to change notification settings - Fork 2
/
correlation_across_storage.py
88 lines (76 loc) · 2.82 KB
/
correlation_across_storage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def _melt_table(
abundance_table: pd.DataFrame,
metadata: pd.DataFrame
) -> pd.DataFrame:
# melt the table so that we can merge it with metadata later
melt = abundance_table.unstack()
melt_table = pd.DataFrame(melt)
melt_table.reset_index(inplace=True)
melt_table.rename(columns={
melt_table.columns[0]:'sample',
melt_table.columns[1]:'taxa',
melt_table.columns[2]:'counts'}, inplace=True)
# merge metadata and abundance data
return pd.merge(metadata, melt_table, left_index=True, right_on='sample')
def _validate_columns(
table: pd.DataFrame,
taxa_column: str,
subject_column: str,
pivot_column: str,
) -> None:
if 'correlation_column' in table.columns:
raise Exception("table cannot already has 'correlation_column'")
if taxa_column not in table.columns:
raise Exception('table does not have {c}'.format(c=taxa_column))
if subject_column not in table.columns:
raise Exception('table does not have {c}'.format(c=subject_column))
if pivot_column not in table.columns:
raise Exception('table does not have {c}'.format(c=pivot_column))
def correlation_plot(
    abundance_table: pd.DataFrame,
    metadata: pd.DataFrame,
    taxa_column: str,
    subject_column: str,
    pivot_column: str,
    base_value: str,
    save_loc: str,
) -> pd.DataFrame:
    """Plot and return Pearson correlations of each pivot level to a baseline.

    The abundance table is melted and merged with ``metadata``, then pivoted
    so that each distinct value of ``pivot_column`` becomes a column of counts
    and each (taxon, subject) pair becomes a row. Correlations between those
    columns are computed once over all rows (returned) and once per subject
    (drawn as a bar plot, whose error bars come from the per-subject spread).

    Parameters
    ----------
    abundance_table
        Wide abundance table passed to ``_melt_table``.
    metadata
        Metadata frame passed to ``_melt_table``.
    taxa_column, subject_column, pivot_column
        Columns of the merged long table identifying the taxon, the subject
        and the variable whose levels are compared.
    base_value
        Level of ``pivot_column`` that every other level is correlated to.
    save_loc
        Path handed to ``Figure.savefig``.

    Returns
    -------
    pd.DataFrame
        Columns ``pivot_column`` and ``base_value``: the all-rows Pearson
        correlation of every non-baseline level with the baseline.

    Raises
    ------
    KeyError
        If ``base_value`` is not a level of ``pivot_column``.

    Notes
    -----
    The y-axis label is fixed to 'Pearson correlation to frozen swab' and the
    y-axis is limited to [0, 1], so negative correlations are not visible.
    The figure is left open after saving; callers producing many plots may
    want to call ``plt.close`` themselves.
    """
    table = _melt_table(abundance_table, metadata)
    _validate_columns(
        table, taxa_column, subject_column, pivot_column)

    # Cast to str so non-string IDs (e.g. integer subject IDs) concatenate
    # instead of raising or being summed numerically.
    table['correlation_column'] = (
        table[taxa_column].astype(str) + table[subject_column].astype(str)
    )
    pivot_table = table.pivot(
        index='correlation_column', columns=pivot_column, values='counts')

    if base_value not in pivot_table.columns:
        raise KeyError(
            '{b!r} is not a value of column {c!r}'.format(
                b=base_value, c=pivot_column))

    correlation_table = pivot_table.corr()
    correlation_table = correlation_table[[base_value]].reset_index()
    correlation_table = correlation_table[
        correlation_table[pivot_column] != base_value]

    # Per-subject correlations supply the spread for the error bars.
    # The subject of each pivot row is looked up from the long table rather
    # than sliced out of the concatenated key, so IDs of any length work, and
    # it is kept out of ``pivot_table`` so ``corr()`` only sees count columns.
    subject_lookup = (
        table.drop_duplicates('correlation_column')
        .set_index('correlation_column')[subject_column]
    )
    row_subjects = subject_lookup.reindex(pivot_table.index).values
    pearsons = [
        block.corr()[[base_value]]
        for _, block in pivot_table.groupby(row_subjects)
    ]
    pearsons_table = pd.concat(pearsons).reset_index()
    pearsons_table = pearsons_table[[pivot_column, base_value]]
    pearsons_table = pearsons_table[
        pearsons_table[pivot_column] != base_value]

    sns.set_context('talk')
    fig, ax = plt.subplots()
    fig.set_size_inches(5, 5)
    # NOTE(review): newer seaborn releases deprecate ``errwidth`` in favour
    # of ``err_kws`` - confirm against the pinned seaborn version.
    sns.barplot(x=pivot_column, y=base_value, data=pearsons_table, ax=ax,
                errwidth=2, capsize=.2, edgecolor='black')
    ax.set_ylabel('Pearson correlation to frozen swab', fontsize=16)
    ax.set_ylim(0, 1)
    ax.set_xlabel('collection protocol')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
    fig.savefig(save_loc)
    return correlation_table
# Example usage:
#   meta: pandas DataFrame of sample metadata, indexed by sample ID
#   at:   pandas DataFrame abundance table (taxa as rows, samples as columns)
#   correlation_plot(at, meta, 'taxa', 'host_subject_id', 'volume_ml', 'swbFroz', 'correlation_fig.png')