-
Notifications
You must be signed in to change notification settings - Fork 452
/
rebuild_data.py
103 lines (83 loc) · 3.1 KB
/
rebuild_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import sys
import json
import re
# Directory containing the train/valid/test JSON splits; taken from the first
# command-line argument (raises IndexError if the script is run without one).
data_dir=sys.argv[1]
def unify_ent2id(ent2id, method='max'):
    """Keep a single surface name per entity id.

    Args:
        ent2id: Mapping from entity name to entity id. Several names may
            share the same id.
        method: ``'min'`` keeps the shortest name for each id; any other
            value keeps the longest. On equal length the name seen first
            wins.

    Returns:
        A ``(ent2id, id2ent)`` pair: the reduced name -> id mapping and its
        inverse id -> name mapping.
    """
    prefer_shorter = method == 'min'
    id2ent = {}
    for name, ent_id in ent2id.items():
        if ent_id not in id2ent:
            id2ent[ent_id] = name
            continue
        incumbent = id2ent[ent_id]
        if prefer_shorter:
            replace = len(name) < len(incumbent)
        else:
            replace = len(name) > len(incumbent)
        if replace:
            id2ent[ent_id] = name
    unified = {name: ent_id for ent_id, name in id2ent.items()}
    return unified, id2ent
def sort_triples(triples, text):
    """Return the triples ordered by where their chemical first occurs in text.

    Args:
        triples: Iterable of mappings, each with a ``'chemical'`` string.
        text: String searched with ``str.find`` for each chemical.

    Returns:
        A new list; the input is not modified. Chemicals absent from
        ``text`` get position -1 from ``str.find`` and therefore sort first.
        Triples with equal positions keep their input order.
    """
    def first_offset(triple):
        return text.find(triple['chemical'])

    return sorted(triples, key=first_offset)
def build_target_seq_svo(relations, id2chem, id2disease):
    """Verbalise relations as "<chemical> correlates with <disease>" clauses.

    Args:
        relations: Iterable of mappings with ``"chemical"`` and ``"disease"``
            entries holding entity ids.
        id2chem: Mapping from chemical id to chemical name.
        id2disease: Mapping from disease id to disease name.

    Returns:
        The clauses joined by ``"; "`` and terminated by ``"."``. An empty
        ``relations`` yields just ``"."``.

    Raises:
        KeyError: If a relation refers to an id missing from the mappings.
    """
    clauses = []
    for rel in relations:
        chemical = id2chem[rel["chemical"]]
        disease = id2disease[rel["disease"]]
        clauses.append(f"{chemical} correlates with {disease}")
    # join() replaces the former "+=" accumulation and trailing-"; " slicing.
    return "; ".join(clauses) + "."
def build_target_seq_relis(relations, id2chem, id2disease):
    """Verbalise relations as "the relation between X and Y exists" clauses.

    Args:
        relations: Iterable of mappings with ``"chemical"`` and ``"disease"``
            entries holding entity ids.
        id2chem: Mapping from chemical id to chemical name.
        id2disease: Mapping from disease id to disease name.

    Returns:
        The clauses joined by ``"; "`` and terminated by ``"."``. An empty
        ``relations`` yields just ``"."``.

    Raises:
        KeyError: If a relation refers to an id missing from the mappings.
    """
    clauses = []
    for rel in relations:
        chemical = id2chem[rel["chemical"]]
        disease = id2disease[rel["disease"]]
        clauses.append(f"the relation between {chemical} and {disease} exists")
    # join() replaces the former "+=" accumulation and trailing-"; " slicing.
    return "; ".join(clauses) + "."
def loader(fname, fn):
    """Read one JSON file and build (pmid, content, answer) samples.

    The file must hold a JSON object mapping each PMID to a record with
    ``"title"``, ``"abstract"`` and ``"relations"`` entries; records that
    have relations must also carry ``"chemical2id"`` and ``"disease2id"``.

    Args:
        fname: Path of the UTF-8 encoded JSON file.
        fn: Callable ``fn(relations, id2chemical, id2disease)`` returning the
            target string for one record.

    Returns:
        A list of ``(pmid, content, answer)`` tuples, where ``content`` is
        the lower-cased title and abstract. Records whose ``"relations"`` is
        ``None`` or empty are skipped; their PMIDs are printed to stdout.
    """
    ret = []
    null_cnt = 0
    with open(fname, "r", encoding="utf8") as fr:
        data = json.load(fr)
    for pmid, v in data.items():
        title = v["title"]
        # A title already ending in a non-word character (e.g. punctuation)
        # is joined to the abstract with a space; otherwise ". " is inserted.
        separator = " " if re.search(r"\W$", title) else ". "
        content = (title + separator + v["abstract"]).lower()
        if not v["relations"]:
            # Print the header once, before the first skipped PMID.
            if null_cnt == 0:
                print(f"Following PMID in {fname} has no extracted relations:")
            print(f"{pmid} ", end="")
            null_cnt += 1
            continue
        # Only the id -> name direction is needed to verbalise relations.
        _, id2chemical = unify_ent2id(v["chemical2id"], method='max')
        _, id2disease = unify_ent2id(v["disease2id"], method='max')
        answer = fn(v["relations"], id2chemical, id2disease)
        ret.append((pmid, content, answer))
    if null_cnt:
        # Terminate the line of space-separated PMIDs printed above.
        print("")
    print(f"{len(data)} samples in {fname} has been processed with {null_cnt} samples has no relations extracted.")
    return ret
def dumper(content_list, prefix):
    """Write samples to three parallel, line-aligned text files.

    Creates ``prefix + ".pmid"``, ``prefix + ".x"`` and ``prefix + ".y"``,
    writing element 0, 1 and 2 of every sample to them respectively, one
    sample per line.

    Args:
        content_list: Iterable of indexable samples with at least three
            elements each (``(pmid, content, answer)`` as built by
            ``loader``).
        prefix: Output path without extension.
    """
    # Context managers guarantee the files are closed even if a write fails;
    # the explicit encoding matches the UTF-8 used when reading the input.
    with open(prefix + ".pmid", "w", encoding="utf8") as fw_pmid, \
            open(prefix + ".x", "w", encoding="utf8") as fw_content, \
            open(prefix + ".y", "w", encoding="utf8") as fw_label:
        for ele in content_list:
            print(ele[0], file=fw_pmid)
            print(ele[1], file=fw_content)
            print(ele[2], file=fw_label)
def worker(fname, prefix, fn):
    """Convert one JSON file into the three text files written by ``dumper``.

    Args:
        fname: Path of the input JSON file, passed to ``loader``.
        prefix: Output path prefix, passed to ``dumper``.
        fn: Target-sequence builder, passed to ``loader``.
    """
    dumper(loader(fname, fn), prefix)
# Convert every split found in data_dir, writing relis_<split>.{pmid,x,y}
# next to the input using the "relation ... exists" target format.
for split in ('train', 'valid', 'test'):
    worker(
        os.path.join(data_dir, f"{split}.json"),
        os.path.join(data_dir, f"relis_{split}"),
        build_target_seq_relis,
    )