Skip to content

Commit

Permalink
Updated random state
Browse files — browse the repository at this point in the history
  • Loading branch information
dmmiller612 committed Oct 25, 2019
1 parent 7295415 commit 47a402c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
10 changes: 6 additions & 4 deletions summarizer/ClusterFeatures.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,7 +12,8 @@ def __init__(
self,
features: ndarray,
algorithm: str = 'kmeans',
pca_k: int = None
pca_k: int = None,
random_state: int = 12345
):

if pca_k:
Expand All @@ -22,11 +23,12 @@ def __init__(

self.algorithm = algorithm
self.pca_k = pca_k
self.random_state = random_state

def __get_model(self, k: int, random_state=12345):
def __get_model(self, k: int):
if self.algorithm == 'gmm':
return GaussianMixture(n_components=k, random_state=random_state)
return KMeans(n_clusters=k, random_state=random_state)
return GaussianMixture(n_components=k, random_state=self.random_state)
return KMeans(n_clusters=k, random_state=self.random_state)

def __get_centroids(self, model):
if self.algorithm == 'gmm':
Expand Down
19 changes: 13 additions & 6 deletions summarizer/model_processors.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import abstractmethod
import neuralcoref
from spacy.lang.en import English
import numpy as np


class ModelProcessor(object):
Expand All @@ -14,12 +15,15 @@ def __init__(
hidden: int=-2,
reduce_option: str = 'mean',
greedyness: float=0.45,
language=English
language=English,
random_state: int = 12345
):
np.random.seed(random_state)
self.model = BertParent(model)
self.hidden = hidden
self.reduce_option = reduce_option
self.nlp = language()
self.random_state = random_state
self.nlp.add_pipe(self.nlp.create_pipe('sentencizer'))
neuralcoref.add_to_pipe(self.nlp, greedyness=greedyness)

Expand Down Expand Up @@ -64,13 +68,15 @@ def __init__(
hidden: int=-2,
reduce_option: str = 'mean',
greedyness: float=0.45,
language=English
language=English,
random_state: int=12345
):
super(SingleModel, self).__init__(model, hidden, reduce_option, greedyness, language=language)
super(SingleModel, self).__init__(model, hidden, reduce_option,
greedyness, language=language, random_state=random_state)

def run_clusters(self, content: List[str], ratio=0.2, algorithm='kmeans', use_first: bool= True) -> List[str]:
hidden = self.model(content, self.hidden, self.reduce_option)
hidden_args = ClusterFeatures(hidden, algorithm).cluster(ratio)
hidden_args = ClusterFeatures(hidden, algorithm, random_state=self.random_state).cluster(ratio)

if use_first:
if hidden_args[0] != 0:
Expand All @@ -87,6 +93,7 @@ def __init__(
hidden: int=-2,
reduce_option: str = 'mean',
greedyness: float=0.45,
language=English
language=English,
random_state: int=12345
):
super(Summarizer, self).__init__(model, hidden, reduce_option, greedyness, language)
super(Summarizer, self).__init__(model, hidden, reduce_option, greedyness, language, random_state)

0 comments on commit 47a402c

Please sign in to comment.