-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor/Fix: WassersteinSolver constructor now throws TypeError when an unrecognized argument is given #579
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #579 +/- ##
==========================================
- Coverage 87.83% 87.81% -0.03%
==========================================
Files 73 73
Lines 7826 7845 +19
Branches 1127 1133 +6
==========================================
+ Hits 6874 6889 +15
- Misses 799 801 +2
- Partials 153 155 +2
|
thanks @selmanozleyen for the PR! i will defer to @michalk8 on this, but it feels that if we implement this for this particular solver, we would need to implement it for all solvers, no? What was the use case that revealed the problem? |
For linear solvers there is no need as their base class In moscot we don't want to ignore any unrecognized arguments since there are many arguments, and with some typo etc. it can lead to some well hidden bugs. Here is the PR for it:theislab/moscot#748 We use many (if not all) solvers in our case and from my tests this PR should be enough to cover the constructors for linear and quadratic solvers. I am not sure about other methods such as |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@selmanozleyen in the comment above, I think we should rather explicitly pass the linear_ot_solver
to WassersteinSolver
instead.
Lmk if you prefer to do it or not. I can also take a look at it, as there might be many places in tests/
that need a change.
.github/workflows/lint.yml
Outdated
@@ -33,8 +33,11 @@ jobs: | |||
if: ${{ matrix.lint-kind == 'code' }} | |||
with: | |||
path: ~/.cache/pre-commit | |||
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }} | |||
|
|||
key: pre-commit-${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is necessary, as the cache key will be search on the PR's target branch if it's not on the feature branch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did this as a temporary solution as ci's failed. the reformatting was also because of ci for some reason. will undo this and the reformatting
@@ -11,6 +11,7 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import inspect |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, I think there's a slightly better solution rather than inspecting the signature of the linear solvers,
I'd rather make linear_ot_solver
a required argument and remove the construction of the solver in __init__
altogether - this will require some changes, esp. in tests, in ott/solvers/quadratic/_solve.py
, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
then this means we can also remove kwargs
right? I'd also prefer this to inspect
but didn't want to change the interface in case you had other plans.
tests/solvers/quadratic/fgw_test.py
Outdated
@@ -60,7 +59,12 @@ def test_gradient_marginals_fgw_solver(self, jit: bool): | |||
|
|||
def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): | |||
prob = quadratic_problem.QuadraticProblem( | |||
geom_x, geom_y, geom_xy, fused_penalty=self.fused_penalty, a=a, b=b |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure why it was reformatted, but would prefer to undo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done!
@michalk8 since the interface is going to change I think it would be better if you did it. I already resolved other pre-commit and formatting issues you mentioned |
Ok, thanks! I will then close this PR and open tomorrow a new one. |
hi @michalk8, just wanted to remind you on this. I think many test cases and stuff might have to change since the API also changes. So maybe I can help a bit |
hi,
A user can give an argument by typo or any other misunderstanding and the solver class would work without them noticing. To prevent such cases I made some modifications. I also added tests that asserts that the raises are thrown properly.
Note: I am not sure about why the linting fails, it
tox -e lint-code
passes locally for me. Note: I also modified the caching in CI's because it didn't work on my pr for some reasonRelated: theislab/moscot#748
ping: @MUCDK