Bug in neuraldual.py with respect to pos_weights #460

Open
lucaeyring opened this issue Nov 9, 2023 · 0 comments
Labels: bug (Something isn't working)

Comments

@lucaeyring (Contributor)

In `neuraldual.py`, when `pos_weights=False`, the weights of network `f` should be clipped and the weights of network `g` should be penalized in the loss. The clipping behaves correctly, but the penalization is added under `if self.pos_weights:` (https://github.com/ott-jax/ott/blob/main/src/ott/solvers/nn/neuraldual.py#L459), which I think should be changed to `if not self.pos_weights:`.
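A minimal sketch of the suspected inversion and the proposed fix; note that `loss`, `penalty_weight`, and `weight_penalty` below are illustrative stand-ins, not the exact names used in `neuraldual.py`:

```python
# Simplified sketch of the branch around neuraldual.py#L459.
# All names here are hypothetical stand-ins for the actual
# variables/helpers in ott.

# Current behaviour: the penalty on g's weights is added when
# pos_weights=True, i.e. when the weights are already constrained
# to be non-negative.
if self.pos_weights:
    loss += penalty_weight * weight_penalty(params_g)

# Suggested fix: penalize only when pos_weights=False, mirroring
# how clipping is applied to f in the same setting.
if not self.pos_weights:
    loss += penalty_weight * weight_penalty(params_g)
```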

Additionally, in the current implementation:

  • f represents the forward transport map, and g the inverse;
  • f (forward) is clipped and optimized in the outer iteration, while g is optimized in the inner iteration and its weights are penalized.

This is the opposite of the original Makkuva et al. paper (https://arxiv.org/pdf/1908.10962.pdf), where the forward transport is optimized in the inner iteration with its weights penalized, while the inverse transport is optimized in the outer iteration with its weights clipped (see the sketch below).
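For concreteness, here is a rough sketch contrasting the two alternating-training arrangements. The `update_f`, `update_g`, and `clip_weights` helpers are hypothetical; this shows only the structure described above, not the actual ott-jax training loop:

```python
# Current ott-jax arrangement (as described above):
for outer_step in range(num_outer_iters):
    for inner_step in range(num_inner_iters):
        # inner: g (inverse), weights penalized in the loss
        params_g = update_g(params_f, params_g, batch)
    # outer: f (forward), weights clipped after the update
    params_f = update_f(params_f, params_g, batch)
    params_f = clip_weights(params_f)

# Makkuva et al. arrangement:
for outer_step in range(num_outer_iters):
    for inner_step in range(num_inner_iters):
        # inner: forward transport, weights penalized in the loss
        params_f = update_f(params_f, params_g, batch)
    # outer: inverse transport, weights clipped after the update
    params_g = update_g(params_f, params_g, batch)
    params_g = clip_weights(params_g)
```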

In my experience, the network trained in the inner iteration, where the weights are penalized, is learned better than the one optimized in the outer iteration. Since one is most often interested in the forward map, maybe this should be changed to match the arrangement described in the Makkuva paper.

@michalk8 michalk8 added the bug Something isn't working label Nov 20, 2023