diff --git a/backend/src/packages/chaiNNer_onnx/__init__.py b/backend/src/packages/chaiNNer_onnx/__init__.py
index e99d17407..a57ca7b83 100644
--- a/backend/src/packages/chaiNNer_onnx/__init__.py
+++ b/backend/src/packages/chaiNNer_onnx/__init__.py
@@ -2,7 +2,7 @@
 from api import KB, MB, Dependency, add_package
 from gpu import nvidia
-from system import is_arm_mac
+from system import is_arm_mac, is_windows
 
 general = "ONNX uses .onnx models to upscale images."
 conversion = "It also helps to convert between PyTorch and NCNN."
@@ -15,7 +15,7 @@
         f"{general} {conversion} It is fastest when CUDA is supported. If TensorRT is"
         " installed on the system, it can also be configured to use that."
     )
-    inst_hint = f"{general} It does not support AMD GPUs."
+    inst_hint = f"{general} It does not support AMD GPUs on Linux."
 
 
 def get_onnx_runtime():
@@ -28,6 +28,13 @@ def get_onnx_runtime():
             import_name="onnxruntime",
             extra_index_url="https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/",
         )
+    elif is_windows:
+        return Dependency(
+            display_name="ONNX Runtime (DirectML)",
+            pypi_name="onnxruntime-directml",
+            version="1.17.1",
+            size_estimate=15 * MB,
+        )
     else:
         return Dependency(
             display_name="ONNX Runtime",