-
Notifications
You must be signed in to change notification settings - Fork 162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: Upgrade the models to use keras 3.0 #1138
Changes from all commits
318b0b0
af9f275
4491b97
70c8d85
25861c8
f2f93cf
3467d62
15ac395
cd32b7c
d5667d7
799cfe4
5db5118
b1edcec
a88593a
5916460
30c8207
062355e
4adc8e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
global-exclude .DS_Store | ||
global-exclude */__pycache__/* | ||
|
||
include *.txt | ||
include CODEOWNERS | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -237,7 +237,8 @@ def _construct_model(self) -> None: | |
model_loc = self._parameters["model_path"] | ||
|
||
self._model: tf.keras.Model = tf.keras.models.load_model(model_loc) | ||
softmax_output_layer_name = self._model.outputs[0].name.split("/")[0] | ||
self._model = tf.keras.Model(self._model.inputs, self._model.outputs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Required for the function to have |
||
softmax_output_layer_name = self._model.output_names[0] | ||
softmax_layer_ind = cast( | ||
int, | ||
labeler_utils.get_tf_layer_index_from_name( | ||
|
@@ -252,21 +253,28 @@ def _construct_model(self) -> None: | |
num_labels, activation="softmax", name="softmax_output" | ||
)(self._model.layers[softmax_layer_ind - 1].output) | ||
|
||
# Output the model into a .pb file for TensorFlow | ||
argmax_layer = tf.keras.backend.argmax(new_softmax_layer) | ||
# Add argmax layer to get labels directly as an output | ||
argmax_layer = tf.keras.ops.argmax(new_softmax_layer, axis=2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. keras v3 method |
||
|
||
argmax_outputs = [new_softmax_layer, argmax_layer] | ||
self._model = tf.keras.Model(self._model.inputs, argmax_outputs) | ||
self._model = tf.keras.Model(self._model.inputs, self._model.outputs) | ||
|
||
# Compile the model w/ metrics | ||
softmax_output_layer_name = self._model.outputs[0].name.split("/")[0] | ||
softmax_output_layer_name = self._model.output_names[0] | ||
losses = {softmax_output_layer_name: "categorical_crossentropy"} | ||
|
||
# use f1 score metric | ||
f1_score_training = labeler_utils.F1Score( | ||
num_classes=num_labels, average="micro" | ||
) | ||
metrics = {softmax_output_layer_name: ["acc", f1_score_training]} | ||
metrics = { | ||
softmax_output_layer_name: [ | ||
"categorical_crossentropy", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Keras v3 requires the loss to be specified explicitly, while v2 did not. |
||
"acc", | ||
f1_score_training, | ||
] | ||
} | ||
|
||
self._model.compile(loss=losses, optimizer="adam", metrics=metrics) | ||
|
||
|
@@ -294,30 +302,33 @@ def _reconstruct_model(self) -> None: | |
num_labels = self.num_labels | ||
default_ind = self.label_mapping[self._parameters["default_label"]] | ||
|
||
# Remove the 2 output layers ('softmax', 'tf_op_layer_ArgMax') | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Popping layers does nothing in Keras v3. |
||
for _ in range(2): | ||
self._model.layers.pop() | ||
|
||
# Add the final Softmax layer to the previous spot | ||
# self._model.layers[-2] to skip: original softmax | ||
final_softmax_layer = tf.keras.layers.Dense( | ||
num_labels, activation="softmax", name="softmax_output" | ||
)(self._model.layers[-4].output) | ||
)(self._model.layers[-2].output) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The argmax op does not show up as a layer anymore. |
||
|
||
# Output the model into a .pb file for TensorFlow | ||
argmax_layer = tf.keras.backend.argmax(final_softmax_layer) | ||
# Add argmax layer to get labels directly as an output | ||
argmax_layer = tf.keras.ops.argmax(final_softmax_layer, axis=2) | ||
|
||
argmax_outputs = [final_softmax_layer, argmax_layer] | ||
self._model = tf.keras.Model(self._model.inputs, argmax_outputs) | ||
|
||
# Compile the model | ||
softmax_output_layer_name = self._model.outputs[0].name.split("/")[0] | ||
softmax_output_layer_name = self._model.output_names[0] | ||
losses = {softmax_output_layer_name: "categorical_crossentropy"} | ||
|
||
# use f1 score metric | ||
f1_score_training = labeler_utils.F1Score( | ||
num_classes=num_labels, average="micro" | ||
) | ||
metrics = {softmax_output_layer_name: ["acc", f1_score_training]} | ||
metrics = { | ||
softmax_output_layer_name: [ | ||
"categorical_crossentropy", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Keras v3 requires the loss to be specified explicitly, while v2 did not. |
||
"acc", | ||
f1_score_training, | ||
] | ||
} | ||
|
||
self._model.compile(loss=losses, optimizer="adam", metrics=metrics) | ||
|
||
|
@@ -370,7 +381,7 @@ def fit( | |
f1_report: dict = {} | ||
|
||
self._model.reset_metrics() | ||
softmax_output_layer_name = self._model.outputs[0].name.split("/")[0] | ||
softmax_output_layer_name = self._model.output_names[0] | ||
|
||
start_time = time.time() | ||
batch_id = 0 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice -- and once this is merged, #1090 can add 3.11.