Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Remove overrides for NNC (#1488)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1488

Previously, we had to override [`simple_ts_compile` in functorch](https://github.com/pytorch/functorch/blob/7fc24aa8bf5b7b3fc912cb0636e0c066c53f44ef/functorch/_src/compilers.py#L248-L251) to use `torch.jit.trace` instead of `torch.jit.script` because of a bug in TorchScript. The issue was fixed recently and we should no longer need to override this function. More importantly, it also becomes possible to run NNC with Buck with the change in this diff.

One thing to note is that functorch 0.2.0 requires CPU-only PyTorch at the moment, whereas `pip install torch` on Linux will install PyTorch with CUDA 10.2 support. This wouldn't be an issue for most users unless they want to invoke `nnc_jit` -- in which case Linux users will have to install CPU PyTorch manually with `pip install torch --extra-index-url https://download.pytorch.org/whl/cpu`.

Reviewed By: yucenli

Differential Revision: D37026266

fbshipit-source-id: b079b20b2bc8748c6b1f57e2c7f277cb07a614a2
  • Loading branch information
horizon-blue authored and facebook-github-bot committed Jul 5, 2022
1 parent 4e093f4 commit 287a0a5
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ jobs:
run: pip list

- name: Run nightly tests
run: pytest -o python_files="*_nightly.py" --durations=0
run: pytest -o python_files="*_nightly.py"
6 changes: 5 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ jobs:
conda install -c conda-forge -y boost-cpp eigen=3.4.0
python -m pip install --upgrade pip
- name: Install CPU PyTorch (only for Linux)
if: matrix.os == 'ubuntu-latest'
run: pip install torch --extra-index-url https://download.pytorch.org/whl/cpu

- name: Install Bean Machine in editable mode
run: pip install -v -e .[dev]

- name: Print out package info to help with debug
run: pip list

- name: Run unit tests with pytest
run: pytest --cov=. --cov-report term-missing
run: pytest .
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"botorch>=0.5.1",
"gpytorch>=1.3.0",
"graphviz>=0.17",
"functorch>=0.1.0; platform_system!='Windows'",
"functorch>=0.2.0",
"netCDF4<=1.5.8; python_version<'3.8'",
"numpy>=1.18.1",
"pandas>=0.24.2",
Expand Down
21 changes: 10 additions & 11 deletions src/beanmachine/ppl/experimental/nnc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
import logging
from typing import Callable, Optional, Tuple, TypeVar

from typing_extensions import ParamSpec

logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")

Expand All @@ -23,16 +25,13 @@ def nnc_jit(
try:
# The setup code in `nnc.utils` will only be executed once in a Python session
from beanmachine.ppl.experimental.nnc.utils import nnc_jit as raw_nnc_jit
except ImportError as e:
if sys.platform.startswith("win"):
message = "functorch is not available on Windows."
else:
message = (
"Fails to initialize NNC. This is likely caused by version mismatch "
"between PyTorch and functorch. Please checkout the functorch project "
"for installation guide (https://github.com/pytorch/functorch)."
)
raise RuntimeError(message) from e
except Exception as e:
logger.warn(
f"Fails to initialize NNC due to the following error: {str(e)}\n"
"Falling back to default inference engine."
)
# return original function without change
return f

return raw_nnc_jit(f, static_argnums)

Expand Down
59 changes: 2 additions & 57 deletions src/beanmachine/ppl/experimental/nnc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,11 @@

import warnings

import functorch
import torch
import torch.jit
import torch.utils._pytree as pytree
from functorch.compile import (
aot_function,
decomposition_table,
nop,
register_decomposition,
)
from functorch.compile import nnc_jit


# the warning will only be shown to user once when this module is imported
warnings.warn(
Expand All @@ -27,56 +22,6 @@
# pyre-fixme[16]: Module `_C` has no attribute `_jit_set_texpr_reductions_enabled`.
torch._C._jit_set_texpr_reductions_enabled(True)

# override the usage of torch.jit.script, which has a bit of issue handling
# empty lists (functorch#440)
def simple_ts_compile(fx_g, example_inps):
f = torch.jit.trace(fx_g, example_inps, strict=False)
f = torch.jit.freeze(f.eval())
torch._C._jit_pass_remove_mutation(f.graph)

return f


# Overrides decomposition rules for some operators
aten = torch.ops.aten
decompositions = [aten.detach]
bm_decompositions = {
k: v for k, v in decomposition_table.items() if k in decompositions
}


@register_decomposition(aten.mv, bm_decompositions)
def mv(a, b):
return (a * b).sum(dim=-1)


@register_decomposition(aten.dot, bm_decompositions)
def dot(a, b):
return (a * b).sum(dim=-1)


@register_decomposition(aten.zeros_like, bm_decompositions)
def zeros_like(a, **kwargs):
return a * 0


@register_decomposition(aten.ones_like, bm_decompositions)
def ones_like(a, **kwargs):
return a * 0 + 1


def nnc_jit(f, static_argnums=None):
return aot_function(
f,
simple_ts_compile,
nop,
static_argnums=static_argnums,
decompositions=bm_decompositions,
)


functorch._src.compilers.simple_ts_compile = simple_ts_compile


# override default dict flatten (which requires keys to be sortable)
def _dict_flatten(d):
Expand Down
12 changes: 5 additions & 7 deletions src/beanmachine/ppl/experimental/tests/nnc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import warnings

import beanmachine.ppl as bm
import pytest
import torch
import torch.distributions as dist

if sys.platform.startswith("win"):
pytest.skip("functorch is not available on Windows", allow_module_level=True)

if os.environ.get("SANDCASTLE") is not None:
pytest.skip("NNC does not work with Buck yet", allow_module_level=True)
try:
import functorch # noqa
except Exception as e:
# skipping the NNC-related test if users don't have compatible functorch installed
pytest.skip(str(e), allow_module_level=True)


class SampleModel:
Expand Down

0 comments on commit 287a0a5

Please sign in to comment.