-
Notifications
You must be signed in to change notification settings - Fork 3
/
lm.py
45 lines (33 loc) · 1.25 KB
/
lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class LanguageModel:
    '''
    Abstraction for prefix language models
    p(sentence) = p(next token|prefix) p(prefix)

    This base class only stores the vocabulary and the special-token
    indices; p_next_token and perplexity raise NotImplementedError here
    and are meant to be overridden.
    '''

    def __init__(self, vocabulary, SOS_ind, EOS_ind):
        '''Store the vocabulary and build the token -> index lookup.

        vocabulary: iterable of tokens that also supports len(); a token's
            position in iteration order is used as its index.
        SOS_ind, EOS_ind: stored unchanged, without validation
            (presumably the start-/end-of-sentence token indices --
            confirm against callers).
        '''
        self.SOS_ind = SOS_ind
        self.EOS_ind = EOS_ind
        # Mapping from indices to tokens
        self.vocabulary = vocabulary
        # Inverse map from tokens to indices.  If a token occurs more than
        # once, the last occurrence's index is the one kept.
        self.vocabulary_index = {
            token: ind for ind, token in enumerate(self.vocabulary)
        }
        self.vocabulary_size = len(self.vocabulary)

    def p_next_token(self, prefix):
        '''Returns the distribution over the next token given the prefix
        represented by a list of indices.

        Not implemented in the base class: always raises
        NotImplementedError.
        '''
        raise NotImplementedError()

    def perplexity(self, sentence):
        '''Returns -log p(sentence)

        Not implemented in the base class: always raises
        NotImplementedError.

        NOTE(review): -log p(sentence) is a negative log-likelihood, which
        is not the usual definition of perplexity -- confirm which
        quantity implementations are expected to return.
        '''
        raise NotImplementedError()
class NgramLanguageModel(LanguageModel):
    '''N-gram language model (placeholder: no model logic is implemented yet).'''

    def __init__(self, vocabulary, SOS_ind, EOS_ind):
        '''Forward all arguments unchanged to LanguageModel.__init__.'''
        super().__init__(vocabulary, SOS_ind, EOS_ind)

    def p_next_token(self, prefix):
        '''Returns the distribution over the next token given the prefix.

        Not implemented yet.  Raises NotImplementedError, like the base
        class, rather than silently returning None in place of a
        distribution.
        '''
        # TODO: implement the n-gram estimate of the next-token distribution.
        raise NotImplementedError()
class StatefulLanguageModel(LanguageModel):
    '''Language model that additionally carries a prefix_to_state dict.'''

    def __init__(self, vocabulary, SOS_ind, EOS_ind):
        '''Initialise the base-class fields and an empty prefix_to_state.

        vocabulary, SOS_ind, EOS_ind: forwarded unchanged to
            LanguageModel.__init__.
        '''
        # Without this call the instance would lack vocabulary,
        # vocabulary_index, vocabulary_size, SOS_ind and EOS_ind.
        super().__init__(vocabulary, SOS_ind, EOS_ind)
        # Starts empty; nothing in this class populates it.
        # NOTE(review): presumably caches per-prefix model state keyed by
        # prefix -- confirm key/value types against the subclasses.
        self.prefix_to_state = {}
class GptLanguageModel(StatefulLanguageModel):
    '''GPT-2 as a language model.

    Placeholder: adds nothing to StatefulLanguageModel yet, so
    p_next_token and perplexity are whatever the base classes provide.
    '''
if __name__ == '__main__':
    # Nothing is run yet when the module is executed as a script.
    pass