Skip to content

Commit

Permalink
Merge pull request #1265 from dmorrill10:fix-rcfr-for-keras-3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665410450
Change-Id: I03d29c0f92429bc0a9329ad869bd4668d27713e2
Committed by lanctot on Aug 20, 2024 — merge commit 5a1f76f with 2 parents, 0be2e9a and 9164f75.
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
3 changes: 1 addition & 2 deletions open_spiel/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,7 @@ if (OPEN_SPIEL_ENABLE_TENSORFLOW)
algorithms/nfsp_test.py
algorithms/policy_gradient_test.py
algorithms/psro_v2/strategy_selectors_test.py
# Broken in Python 3.12. Must port to Keras 3. https://github.com/google-deepmind/open_spiel/issues/1207.
# algorithms/rcfr_test.py
algorithms/rcfr_test.py
)
if (OPEN_SPIEL_ENABLE_PYTHON_MISC)
set(PYTHON_TESTS ${PYTHON_TESTS}
Expand Down
41 changes: 28 additions & 13 deletions open_spiel/python/algorithms/rcfr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def _new_model():
_GAME,
num_hidden_layers=1,
num_hidden_units=13,
num_hidden_factors=1,
use_skip_connections=True)


Expand Down Expand Up @@ -476,12 +475,18 @@ def test_rcfr_functions(self):
data = data.batch(12)
data = data.repeat(num_epochs)

optimizer = tf.keras.optimizers.Adam(lr=0.005, amsgrad=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.005, amsgrad=True)

model = models[regret_player]
for x, y in data:
optimizer.minimize(
lambda: tf.losses.huber_loss(y, models[regret_player](x)), # pylint: disable=cell-var-from-loop
models[regret_player].trainable_variables)
with tf.GradientTape() as tape:
loss = tf.losses.huber_loss(y, model(x))
optimizer.apply_gradients(
zip(
tape.gradient(loss, model.trainable_variables),
model.trainable_variables,
)
)

regret_player = reach_weights_player

Expand All @@ -504,12 +509,17 @@ def _train(model, data):
data = data.batch(12)
data = data.repeat(num_epochs)

optimizer = tf.keras.optimizers.Adam(lr=0.005, amsgrad=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.005, amsgrad=True)

for x, y in data:
optimizer.minimize(
lambda: tf.losses.huber_loss(y, model(x)), # pylint: disable=cell-var-from-loop
model.trainable_variables)
with tf.GradientTape() as tape:
loss = tf.losses.huber_loss(y, model(x))
optimizer.apply_gradients(
zip(
tape.gradient(loss, model.trainable_variables),
model.trainable_variables,
)
)

average_policy = patient.average_policy()
self.assertGreater(pyspiel.nash_conv(_GAME, average_policy), 0.91)
Expand Down Expand Up @@ -565,12 +575,17 @@ def _train(model, data):
data = data.batch(12)
data = data.repeat(num_epochs)

optimizer = tf.keras.optimizers.Adam(lr=0.005, amsgrad=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.005, amsgrad=True)

for x, y in data:
optimizer.minimize(
lambda: tf.losses.huber_loss(y, model(x)), # pylint: disable=cell-var-from-loop
model.trainable_variables)
with tf.GradientTape() as tape:
loss = tf.losses.huber_loss(y, model(x))
optimizer.apply_gradients(
zip(
tape.gradient(loss, model.trainable_variables),
model.trainable_variables,
)
)

average_policy = patient.average_policy()
self.assertGreater(pyspiel.nash_conv(_GAME, average_policy), 0.91)
Expand Down
15 changes: 11 additions & 4 deletions open_spiel/python/examples/rcfr_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,21 @@ def _train_fn(model, data):
data = data.batch(FLAGS.batch_size)
data = data.repeat(FLAGS.num_epochs)

optimizer = tf.keras.optimizers.Adam(lr=FLAGS.step_size, amsgrad=True)
optimizer = tf.keras.optimizers.Adam(
learning_rate=FLAGS.step_size, amsgrad=True
)

@tf.function
def _train():
for x, y in data:
optimizer.minimize(
lambda: tf.losses.huber_loss(y, model(x), delta=0.01), # pylint: disable=cell-var-from-loop
model.trainable_variables)
with tf.GradientTape() as tape:
loss = tf.losses.huber_loss(y, model(x), delta=0.01)
optimizer.apply_gradients(
zip(
tape.gradient(loss, model.trainable_variables),
model.trainable_variables,
)
)

_train()

Expand Down

0 comments on commit 5a1f76f

Please sign in to comment.