In `neuraldual.py`, when `pos_weights=False`, the weights of network `f` should be clipped and the weights of network `g` should be penalized via a loss term. The clipping behaves correctly, but the penalty is currently added if `pos_weights=True` (https://github.com/ott-jax/ott/blob/main/src/ott/solvers/nn/neuraldual.py#L459); I think the condition should be changed to `if not self.pos_weights`.
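For concreteness, here is a minimal, self-contained sketch of the intended condition. The helper `penalize_weights` is a stand-in for the actual penalty used in the file (its name and form here are assumptions); the flipped boolean is the only point:

```python
import jax.numpy as jnp
from jax.nn import relu

def penalize_weights(kernel: jnp.ndarray) -> jnp.ndarray:
    """Soft positivity constraint: norm of the negative part of the kernel.
    Stand-in for the penalty helper in neuraldual.py (name assumed)."""
    return jnp.linalg.norm(relu(-kernel))

def g_loss(base_loss: float, kernel_g: jnp.ndarray, pos_weights: bool) -> jnp.ndarray:
    # Proposed: penalize g's weights exactly when pos_weights is False,
    # mirroring the (correct) clipping branch for f.
    # The current code tests `if self.pos_weights`, which is the bug.
    if not pos_weights:
        return base_loss + penalize_weights(kernel_g)
    return jnp.asarray(base_loss)

# e.g. g_loss(1.0, jnp.array([-0.5, 0.2]), pos_weights=False) adds a penalty,
# while pos_weights=True leaves the loss untouched (clipping handles f).
```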
Additionally, in the current implementation:

- `f` represents the forward transport and `g` the inverse;
- `f` (forward) is clipped and optimized in the outer iteration, while `g` is optimized in the inner iteration and penalized.
This contrasts with the original Makkuva et al. paper (https://arxiv.org/pdf/1908.10962.pdf), where the forward transport is optimized in the inner iteration with its weights penalized, while the inverse transport is optimized in the outer iteration with its weights clipped.
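To make the ordering concrete, here is a toy sketch of the alternating scheme; the dummy objectives and all names are illustrative, not the ott implementation:

```python
import jax
import jax.numpy as jnp

# Toy contrast of the two alternating schemes; the update rules below are
# placeholders, only the inner/outer roles and constraints matter.

def penalty(w):
    # soft positivity constraint, applied to the inner network's weights
    return jnp.sum(jax.nn.relu(-w) ** 2)

def clip_weights(w):
    # hard positivity constraint, applied to the outer network's weights
    return jnp.maximum(w, 0.0)

def inner_step(w_in, w_out, lr=1e-2):
    # inner network: gradient step on a dummy objective plus the penalty
    g = jax.grad(lambda w: jnp.sum((w - w_out) ** 2) + penalty(w))(w_in)
    return w_in - lr * g

def outer_step(w_out, w_in, lr=1e-2):
    # outer network: gradient step followed by weight clipping
    g = jax.grad(lambda w: jnp.sum((w + w_in) ** 2))(w_out)
    return clip_weights(w_out - lr * g)

# Current ott ordering: g is the inner (penalized) network,
# f the outer (clipped) one.
w_f, w_g = jnp.ones(3), -jnp.ones(3)
for _ in range(10):            # outer iterations (f)
    for _ in range(5):         # inner iterations (g)
        w_g = inner_step(w_g, w_f)
    w_f = outer_step(w_f, w_g)

# The Makkuva et al. ordering swaps the roles: the forward map is trained
# in the inner loop (penalized), the inverse map in the outer loop (clipped).
```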
In my experience, the network trained in the inner iteration, where the weights are penalized, is learned better than the one optimized in the outer iteration. Since one is usually interested mainly in the forward map, this should perhaps be changed to match the Makkuva paper.