(Experimental) Integrate Metal PjRt plugin #1504
Draft
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This integrates the PjRt plugin from the
jax-metal
for running on the Apple GPU. To test it, one can setclient: :mps
on EXLA backend/compiler. Since the plugin is loaded as a separate dynamic library, it can be tested without any changes to XLA (just make sure to remove thecache/
directory).Certain computations can already be run, but the plugin is still very much incomplete. This PR is a room for experimentation and is meant to track the plugin progress. I reported a number of issues upstream, comments in the code point to those. In a few places I applied workarounds as temporary solutions or just to avoid VM crashes (segfaults), those are marked with a TODO.
Issues
For tracking purposes, here is a list of the Metal plugin issues reported upstream:
Crucial
jax.lax.while_loop
jax-ml/jax#21552Not implemented
Edge cases
jax.lax.dynamic_update_slice
jax-ml/jax#21392All issues: link.
Note: this PR is against the jk-s32 branch, which changes the default integer precision to 32 bits. This is a planned change (#1491), but it's not integrated yet to avoid conflicts with other work in progress.