diff --git a/SentEval/senteval/sts.py b/SentEval/senteval/sts.py index a4df049..8c0e415 100644 --- a/SentEval/senteval/sts.py +++ b/SentEval/senteval/sts.py @@ -39,8 +39,8 @@ def loadFile(self, fpath): not_empty_idx = raw_scores != '' gs_scores = [float(x) for x in raw_scores[not_empty_idx]] - sent1 = np.array([s.split() for s in sent1])[not_empty_idx] - sent2 = np.array([s.split() for s in sent2])[not_empty_idx] + sent1 = np.array([s.split() for s in sent1], dtype=object)[not_empty_idx] + sent2 = np.array([s.split() for s in sent2], dtype=object)[not_empty_idx] # sort data by length to minimize padding in batcher sorted_data = sorted(zip(sent1, sent2, gs_scores), key=lambda z: (len(z[0]), len(z[1]), z[2]))