diff --git a/docs/nncase_studio.md b/docs/nncase_studio.md
new file mode 100644
index 0000000000..1c24b1b649
--- /dev/null
+++ b/docs/nncase_studio.md
@@ -0,0 +1,129 @@
+# Download
+
+Nncase Studio shares its version number with nncase, and every Studio update is published in the assets of the corresponding release. To download the latest build, open the latest release and find the matching download link in its assets.
+
+Take 2.6.0 as an example:
+
+![Untitled](nncase_studio_asset/Untitled.png)
+
+You can find the assets of the latest release at this link:
+
+https://github.com/kendryte/nncase/releases
+
+After downloading the archive, extract it, locate the Nncase.Studio file in the folder, and start it by double-clicking it or by running ./Nncase.Studio from a terminal in that directory.
+
+![Untitled](nncase_studio_asset/Untitled%201.png)
+
+# Compile Configuration
+
+After Nncase Studio starts, you will see the following initial screen.
+
+![Untitled](nncase_studio_asset/Untitled%202.png)
+
+The screen has three main buttons:
+
+1. The first time you open the Studio, choose **常规编译** (regular compilation) directly.
+2. If you already have a kmodel and input files, **直接推理** (direct inference) is all you need.
+3. If you previously saved a configuration file with 导出编译配置 (export compile configuration) at the bottom, you can load it with the **导入配置** (import configuration) button. At any point in the process you can save the current settings with 导出编译配置, and next time load all the modified entries directly from that file.
+
+Since this is our first compilation, click 常规编译 to enter the next page.
+
+![Untitled](nncase_studio_asset/Untitled%203.png)
+
+nncase currently supports three model formats: ncnn, onnx, and tflite. Click Import to open a file dialog, select the model, and click Open to import the model file.
+
+![Untitled](nncase_studio_asset/Untitled%204.png)
+
+You will then be taken automatically to the basic compile options page. On this page and the ones that follow, the fields are filled in almost exactly as on the command line; the string formats and the meaning of each field are not repeated every time, so if anything is unclear, refer to the following link:
+
+[https://github.com/kendryte/nncase/blob/master/docs/USAGE_v2.md](https://github.com/kendryte/nncase/blob/master/docs/USAGE_v2.md)
+
+![Untitled](nncase_studio_asset/Untitled%205.png)
+
+To configure the dump options, click the button in the red box and select the information to dump from the drop-down menu; multiple selections are allowed. Different options dump information generated at different stages, so choose according to your needs. If you do not need this information, you can ignore the option; it does not affect the normal flow.
+
+![Untitled](nncase_studio_asset/Untitled%206.png)
+
+The dump path defaults to the nncase_dump folder under the directory of the executable. For example, if the executable is in /home/homura/Downloads/nncase-studio/, the dump path is /home/homura/Downloads/nncase-studio/nncase-dump. You can change it to any path you prefer.
+
+Enable 前后处理 (pre/post-processing) / 量化 (quantization) / ShapeBucket as needed. Each of these options adds its own page later in the flow; if an option is disabled, the corresponding page is skipped. When done, click 下一步 (Next) to enter the next page.
+
+# Pre/Post-processing
+
+If you enabled pre/post-processing, you will enter the following page.
+
+![Untitled](nncase_studio_asset/Untitled%207.png)
+
+The layout field takes the same input as the command line: it supports the letter form such as NCHW, and likewise the digit form such as 0231.
+
+InputShape is a list of dimensions separated by commas.
+
+After filling in the parameters, click Next to switch to the next page.
+
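What the two string formats denote can be illustrated with a short sketch. This is a plain-Python illustration only; the helper names below are made up for this example and are not part of nncase.

```python
def parse_input_shape(s: str) -> list[int]:
    # InputShape is a comma-separated list of dimensions, e.g. "1,3,224,224".
    return [int(d) for d in s.split(",")]

def layout_to_perm(layout: str, src: str = "NCHW") -> list[int]:
    # The digit form ("0231") is already a permutation of the source axes;
    # the letter form ("NHWC") is interpreted relative to the source layout.
    if layout.isdigit():
        return [int(c) for c in layout]
    return [src.index(c) for c in layout]

shape = parse_input_shape("1,3,224,224")
perm = layout_to_perm("0231")            # same permutation as "NHWC" from NCHW
new_shape = [shape[p] for p in perm]
print(new_shape)                          # [1, 224, 224, 3]
```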
+As for the order of the pre-processing steps, you can click 显示前处理顺序 (show pre-processing order) to see the whole pipeline. Currently only a fixed pre/post-processing pipeline is supported.
+
+![Untitled](nncase_studio_asset/Untitled%208.png)
+
+# Quantization
+
+## Non-mixed Quantization
+
+If quantization is enabled, the page after pre/post-processing is the quantization configuration. If mixed quantization is not used, you will enter this page.
+
+![Untitled](nncase_studio_asset/Untitled%209.png)
+
+First select the folder containing the calibration dataset. The input data must carry datatype and shape information, so currently only npy files are supported as dataset inputs.
+
+In addition, the calibration files placed in the dataset folder must follow a specific naming convention so that they can be parsed correctly.
+
+The format is "<dataset group index>_<input index>_<file name>.bin", with all groups of files placed in the same folder, for example as in the following image.
+
+![Untitled](nncase_studio_asset/Untitled%2010.png)
+
+## Mixed Quantization
+
+If mixed quantization is used, you will enter a page like this.
+
+![Untitled](nncase_studio_asset/Untitled%2011.png)
+
+The mixed-quantization configuration is basically exported by nncase, but it may also be a configuration exported by other internal tools; here you only need to select the corresponding json file.
+
+After the quantization options are set, click Next.
+
+# ShapeBucket
+
+![Untitled](nncase_studio_asset/Untitled%2012.png)
+
+The ShapeBucket page takes information similar to the command line, except that here it must be written as strings. Both FixVarMap and VarRangeInfo take multiple entries separated by commas; the exact syntax of each entry is shown as the light-colored placeholder text in the image.
+
+When everything is filled in, click Next.
+
+# Compile
+
+![Untitled](nncase_studio_asset/Untitled%2013.png)
+
+The default kmodel path is under the dump directory configured earlier; you can also enter any path you prefer. Once the kmodel path is set, compilation can start. If anything goes wrong after compilation has started, you can end it immediately with the 停止 (stop) button. If all goes well, a dialog pops up when compilation finishes, showing the path of the generated kmodel; the path can be selected and copied, and the Studio automatically moves on to the next page.
+
+![Untitled](nncase_studio_asset/Untitled%2014.png)
+
+# Simulate
+
+![Untitled](nncase_studio_asset/Untitled%2015.png)
+
+Before running inference you need to configure the input files, as well as the location where the results are stored (by default they are stored under the dump folder). When selecting input files, their names must start with an index followed by an underscore; otherwise the files cannot be matched to the correct inputs.
+
+Click this button to see the input files that have been added so far.
+
+![Untitled](nncase_studio_asset/Untitled%2016.png)
+
+Once all paths are set, inference can start. Currently inference is only supported on the simulator. After clicking 开始推理 (start inference) you may need to wait a while: the simulator emulates the behaviour of the hardware, so it cannot return results as instantly as running on the board. A rolling progress bar and a timer are shown while inference runs.
+
+![Untitled](nncase_studio_asset/Untitled%2017.png)
+
+After inference succeeds, a dialog like this pops up. At this point all outputs have been written to the configured result directory, with file names of the form nncase_result_*.npy.
+
+![Untitled](nncase_studio_asset/Untitled%2018.png)
+
+In the corresponding dump folder you can then see all the generated files, including the dump files, the kmodel, and the result files.
+
+![Untitled](nncase_studio_asset/Untitled%2019.png)
diff --git a/docs/nncase_studio_asset/Untitled 1.png b/docs/nncase_studio_asset/Untitled 1.png
new file mode 100644
index 0000000000..e4609d85b3
Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 1.png differ
diff --git a/docs/nncase_studio_asset/Untitled 10.png
b/docs/nncase_studio_asset/Untitled 10.png new file mode 100644 index 0000000000..5ad3b88ef9 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 10.png differ diff --git a/docs/nncase_studio_asset/Untitled 11.png b/docs/nncase_studio_asset/Untitled 11.png new file mode 100644 index 0000000000..f44cb75636 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 11.png differ diff --git a/docs/nncase_studio_asset/Untitled 12.png b/docs/nncase_studio_asset/Untitled 12.png new file mode 100644 index 0000000000..a364689cd4 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 12.png differ diff --git a/docs/nncase_studio_asset/Untitled 13.png b/docs/nncase_studio_asset/Untitled 13.png new file mode 100644 index 0000000000..ba35c6ae78 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 13.png differ diff --git a/docs/nncase_studio_asset/Untitled 14.png b/docs/nncase_studio_asset/Untitled 14.png new file mode 100644 index 0000000000..3bf1055475 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 14.png differ diff --git a/docs/nncase_studio_asset/Untitled 15.png b/docs/nncase_studio_asset/Untitled 15.png new file mode 100644 index 0000000000..0312d86186 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 15.png differ diff --git a/docs/nncase_studio_asset/Untitled 16.png b/docs/nncase_studio_asset/Untitled 16.png new file mode 100644 index 0000000000..e4317981ce Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 16.png differ diff --git a/docs/nncase_studio_asset/Untitled 17.png b/docs/nncase_studio_asset/Untitled 17.png new file mode 100644 index 0000000000..37f2ed7928 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 17.png differ diff --git a/docs/nncase_studio_asset/Untitled 18.png b/docs/nncase_studio_asset/Untitled 18.png new file mode 100644 index 0000000000..ba7a52d0a7 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 18.png differ diff --git 
a/docs/nncase_studio_asset/Untitled 19.png b/docs/nncase_studio_asset/Untitled 19.png new file mode 100644 index 0000000000..407dbac628 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 19.png differ diff --git a/docs/nncase_studio_asset/Untitled 2.png b/docs/nncase_studio_asset/Untitled 2.png new file mode 100644 index 0000000000..acc629af83 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 2.png differ diff --git a/docs/nncase_studio_asset/Untitled 3.png b/docs/nncase_studio_asset/Untitled 3.png new file mode 100644 index 0000000000..8219915d0e Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 3.png differ diff --git a/docs/nncase_studio_asset/Untitled 4.png b/docs/nncase_studio_asset/Untitled 4.png new file mode 100644 index 0000000000..b2e3517d8f Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 4.png differ diff --git a/docs/nncase_studio_asset/Untitled 5.png b/docs/nncase_studio_asset/Untitled 5.png new file mode 100644 index 0000000000..30a3e26c45 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 5.png differ diff --git a/docs/nncase_studio_asset/Untitled 6.png b/docs/nncase_studio_asset/Untitled 6.png new file mode 100644 index 0000000000..01ac3398f7 Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 6.png differ diff --git a/docs/nncase_studio_asset/Untitled 7.png b/docs/nncase_studio_asset/Untitled 7.png new file mode 100644 index 0000000000..28c0984dbe Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 7.png differ diff --git a/docs/nncase_studio_asset/Untitled 8.png b/docs/nncase_studio_asset/Untitled 8.png new file mode 100644 index 0000000000..23909db05c Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 8.png differ diff --git a/docs/nncase_studio_asset/Untitled 9.png b/docs/nncase_studio_asset/Untitled 9.png new file mode 100644 index 0000000000..1403169e3d Binary files /dev/null and b/docs/nncase_studio_asset/Untitled 9.png differ diff --git 
a/docs/nncase_studio_asset/Untitled.png b/docs/nncase_studio_asset/Untitled.png new file mode 100644 index 0000000000..407b169aee Binary files /dev/null and b/docs/nncase_studio_asset/Untitled.png differ diff --git a/src/Native/include/nncase/compiler.h b/src/Native/include/nncase/compiler.h index 7339d0b546..1ef12f990d 100644 --- a/src/Native/include/nncase/compiler.h +++ b/src/Native/include/nncase/compiler.h @@ -109,6 +109,7 @@ typedef struct { clr_object_handle_t (*calibration_dataset_provider_create)( clr_object_handle_t dataset, size_t samplesCount, clr_object_handle_t fn_params); + void (*handle_dispose)(clr_object_handle_t handle); void (*handle_free)(clr_object_handle_t handle); clr_object_handle_t (*compile_options_create)(); void (*compile_options_set_input_file)(clr_object_handle_t compile_options, @@ -478,6 +479,8 @@ class cstream : public clr_object_base { cstream(const nncase_stream_mt_t *mt, void *handle) { obj_ = nncase_clr_api()->stream_create(mt, handle); } + + ~cstream() { nncase_clr_api()->handle_dispose(obj_.get()); } }; class compile_options : public clr_object_base { diff --git a/src/Native/include/nncase/runtime/allocator.h b/src/Native/include/nncase/runtime/allocator.h index 295c170597..b81353a72b 100644 --- a/src/Native/include/nncase/runtime/allocator.h +++ b/src/Native/include/nncase/runtime/allocator.h @@ -43,6 +43,7 @@ class NNCASE_API buffer_allocator { const buffer_attach_options &options) = 0; static buffer_allocator &host(); + virtual void shrink_memory_pool() = 0; }; END_NS_NNCASE_RUNTIME diff --git a/src/Native/include/nncase/runtime/util.h b/src/Native/include/nncase/runtime/util.h index 68bd808025..056b8fa98a 100644 --- a/src/Native/include/nncase/runtime/util.h +++ b/src/Native/include/nncase/runtime/util.h @@ -587,4 +587,8 @@ inline dims_t to_4d(dims_t in_a_shape) { return in_a_shape; } +inline void shrink_memory_pool() { + buffer_allocator::host().shrink_memory_pool(); +} + END_NS_NNCASE_RUNTIME \ No newline at end of file 
diff --git a/src/Native/src/runtime/host_allocator.cpp b/src/Native/src/runtime/host_allocator.cpp index 7ab12092da..00468afd0d 100644 --- a/src/Native/src/runtime/host_allocator.cpp +++ b/src/Native/src/runtime/host_allocator.cpp @@ -122,6 +122,8 @@ class host_buffer_allocator : public buffer_allocator { []([[maybe_unused]] gsl::byte *p) {}, paddr, *this, host_sync_status_t::valid)); } + + void shrink_memory_pool() override {} }; host_buffer_allocator host_allocator; diff --git a/src/Nncase.Compiler/Interop/CApi.cs b/src/Nncase.Compiler/Interop/CApi.cs index 69bacc8397..2c0451d14e 100644 --- a/src/Nncase.Compiler/Interop/CApi.cs +++ b/src/Nncase.Compiler/Interop/CApi.cs @@ -41,6 +41,7 @@ public unsafe struct CApiMT public delegate* unmanaged ArrayGetItemPtr; public delegate* unmanaged ArrayGetLengthPtr; public delegate* unmanaged CalibrationDatasetProviderCreatePtr; + public delegate* unmanaged ClrHandleDisposePtr; public delegate* unmanaged ClrHandleFreePtr; public delegate* unmanaged CompileOptionsCreatePtr; public delegate* unmanaged CompileOptionsSetInputFilePtr; @@ -112,6 +113,7 @@ public static void Initialize(CApiMT* mt) mt->ArrayGetItemPtr = &ArrayGetItem; mt->ArrayGetLengthPtr = &ArrayGetLength; mt->CalibrationDatasetProviderCreatePtr = &CalibrationDatasetProviderCreate; + mt->ClrHandleDisposePtr = &ClrHandleDispose; mt->ClrHandleFreePtr = &ClrHandleFree; mt->CompileOptionsCreatePtr = &CompileOptionsCreate; mt->CompileOptionsSetInputFilePtr = &CompileOptionsSetInputFile; @@ -225,6 +227,12 @@ private static IntPtr CalibrationDatasetProviderCreate(IntPtr datasetHandle, nui return GCHandle.ToIntPtr(GCHandle.Alloc(new CCalibrationDatasetProvider(samples, (int)samplesCount))); } + [UnmanagedCallersOnly] + private static void ClrHandleDispose(IntPtr handle) + { + Get(handle).Dispose(); + } + [UnmanagedCallersOnly] private static void ClrHandleFree(IntPtr handle) { diff --git a/src/Nncase.Core/CompilerServices.cs b/src/Nncase.Core/CompilerServices.cs index 
9c666bcd21..2f731a7fdf 100644 --- a/src/Nncase.Core/CompilerServices.cs +++ b/src/Nncase.Core/CompilerServices.cs @@ -208,6 +208,15 @@ public interface ICompilerServicesProvider /// Options. /// Rewrited expression. Expr ERewrite(Expr expr, IEnumerable rules, RunPassContext options); + + /// + /// Using EGraph rewrite expression. + /// + /// Expression. + /// Rewrite rules. + /// Options. + /// Rewrited expression. + IEGraph ERewrite(IEGraph expr, IEnumerable rules, RunPassContext options); } internal interface ICompilerServicesProviderInternal @@ -409,6 +418,18 @@ public static Expr ERewrite(Expr expr, IEnumerable rules, RunPassC return Provider.ERewrite(expr, rules, options); } + /// + /// Using EGraph rewrite expression. + /// + /// Expression. + /// Rewrite rules. + /// Options. + /// Rewrited expression. + public static IEGraph ERewrite(IEGraph graph, IEnumerable rules, RunPassContext options) + { + return Provider.ERewrite(graph, rules, options); + } + /// /// Match enodes as root. /// @@ -677,4 +698,9 @@ public Expr ERewrite(Expr expr, IEnumerable rules, RunPassContext { return _eGraphrewriteProvider.ERewrite(expr, rules, options); } + + public IEGraph ERewrite(IEGraph graph, IEnumerable rules, RunPassContext options) + { + return _eGraphrewriteProvider.ERewrite(graph, rules, options); + } } diff --git a/src/Nncase.Core/Enum/BinaryOp.cs b/src/Nncase.Core/Enum/BinaryOp.cs index 45afba68cc..fa4092fa48 100644 --- a/src/Nncase.Core/Enum/BinaryOp.cs +++ b/src/Nncase.Core/Enum/BinaryOp.cs @@ -93,4 +93,14 @@ public enum BinaryOp : byte /// Right Shift. /// RightShift, + + /// + /// Floor Div. + /// + FloorDiv, + + /// + /// Ceil Div. 
+ /// + CeilDiv, } diff --git a/src/Nncase.Core/IR/Buffers/Allocate.cs b/src/Nncase.Core/IR/Buffers/Allocate.cs index 14b44010ec..ff7bdd13c0 100644 --- a/src/Nncase.Core/IR/Buffers/Allocate.cs +++ b/src/Nncase.Core/IR/Buffers/Allocate.cs @@ -13,5 +13,20 @@ namespace Nncase.IR.Buffers; /// public sealed partial class Allocate : Op { - public TensorType ElemType { get; } + /// + /// Get the input parameter. + /// + public static readonly ParameterInfo Size = new(typeof(Allocate), 0, "size", TypePatternUtility.IsIntegralScalar()); + + /// + /// Gets the alloacted buffer type. + /// + public DataType ElemType { get; } + + public TIR.MemoryLocation Location { get; } + + /// + public override bool CanFoldConstCall => false; + + public override string DisplayProperty() => $"{ElemType}, {Location}"; } diff --git a/src/Nncase.Core/IR/Buffers/Functional.cs b/src/Nncase.Core/IR/Buffers/Functional.cs index a2e3507a5f..463c4f1e2c 100644 --- a/src/Nncase.Core/IR/Buffers/Functional.cs +++ b/src/Nncase.Core/IR/Buffers/Functional.cs @@ -42,4 +42,6 @@ public static Call BaseMentOf(Expr input) => /// create the uninitialized buffer. /// public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape); + + public static Call Allocate(Expr size, DataType dataType, TIR.MemoryLocation location) => new Call(new Allocate(dataType, location), size); } diff --git a/src/Nncase.Core/IR/TensorConst.cs b/src/Nncase.Core/IR/TensorConst.cs index 9e651978ed..07dccfbcc3 100644 --- a/src/Nncase.Core/IR/TensorConst.cs +++ b/src/Nncase.Core/IR/TensorConst.cs @@ -20,12 +20,18 @@ public TensorConst(Tensor tensor) Value = tensor; } + public TensorConst(Tensor tensor, IRArray ndsbp, Placement placement) + : base(new DistributedType(new TensorType(tensor.ElementType, tensor.Shape), ndsbp, placement)) + { + Value = tensor; + } + public Tensor Value { get; } /// /// Gets value type. 
/// - public new TensorType ValueType => (TensorType)base.ValueType; + public new IRType ValueType => base.ValueType; /// /// Create TensorConstant from a . @@ -122,25 +128,43 @@ public TensorConst(Tensor tensor) public static bool operator !=(TensorConst? left, TensorConst? right) => !(left == right); /// - public override string ToString() => ValueType switch + public override string ToString() { - var x when x.IsScalar => - x.DType switch - { - var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar().ToString(), - var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar().ToString(), - var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar().ToString(), - var dtype when dtype == DataTypes.Boolean => Value.ToScalar().ToString(), - _ => $"{x.DType.GetDisplayName()} {x.Shape}", - }, - _ => $"{ValueType.DType.GetDisplayName()} {ValueType.Shape}", - }; + var type = ValueType switch + { + DistributedType dt => dt.TensorType, + TensorType tt => tt, + _ => throw new NotSupportedException("Not supported const type: " + ValueType), + }; + + return type switch + { + var x when x.IsScalar => + x.DType switch + { + var dtype when DataTypes.IsIntegral(dtype) => Value.ToScalar().ToString(), + var dtype when DataTypes.IsFloat(dtype) => Value.ToScalar().ToString(), + var dtype when DataTypes.IsPointer(dtype) => Value.ToScalar().ToString(), + var dtype when dtype == DataTypes.Boolean => Value.ToScalar().ToString(), + _ => $"{x.DType.GetDisplayName()} {x.Shape}", + }, + _ => $"{type.DType.GetDisplayName()} {type.Shape}", + }; + } /// public override TExprResult Accept(ExprFunctor functor, TContext context) => functor.VisitTensorConst(this, context); - public TensorConst With(Tensor? value = null) => new TensorConst(value ?? Value); + public TensorConst With(Tensor? value = null) + { + if (value is null && ValueType is DistributedType dt) + { + return new TensorConst(Value, dt.NdSBP, dt.Placement); + } + + return new TensorConst(value ?? 
Value); + } /// public override bool Equals(object? obj) => Equals(obj as TensorConst); diff --git a/src/Nncase.Core/IR/Tensors/Where.cs b/src/Nncase.Core/IR/Tensors/Where.cs index 5506b6afea..d64e2472ad 100644 --- a/src/Nncase.Core/IR/Tensors/Where.cs +++ b/src/Nncase.Core/IR/Tensors/Where.cs @@ -21,17 +21,22 @@ public sealed partial class Where : Op /// /// Gets condition. /// - public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond"); + public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond", ParameterKind.Input); /// /// Gets x. /// - public static readonly ParameterInfo X = new(typeof(Where), 1, "x"); + public static readonly ParameterInfo X = new(typeof(Where), 1, "x", ParameterKind.Input); /// /// Gets y. /// - public static readonly ParameterInfo Y = new(typeof(Where), 2, "y"); + public static readonly ParameterInfo Y = new(typeof(Where), 2, "y", ParameterKind.Input); public bool IsTfWhere { get; } + + public override string DisplayProperty() + { + return $"IsTfWhere: {IsTfWhere}"; + } } diff --git a/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs b/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs index 7559954a49..1a64b55917 100644 --- a/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs +++ b/src/Nncase.Core/Passes/Mutators/UnFoldBlock.cs @@ -23,6 +23,37 @@ protected override Expr RewriteLeafBlock(Block expr) { if (predicate) { + if (expr.AllocBuffers.Length > 0) + { + var lets = expr.AllocBuffers.ToArray().Select(b => (T.Let(out var v, b.MemSpan.Start, b.Name + "_ptr"), v)).ToArray(); + for (int i = 0; i < lets.Length - 1; i++) + { + lets[i].Item1.Body(lets[i + 1].Item1); + } + + var map = new Dictionary(ReferenceEqualityComparer.Instance); + for (int i = 0; i < expr.AllocBuffers.Length; i++) + { + map.Add(expr.AllocBuffers[i].MemSpan.Start, lets[i].v); + } + + var mutator = new Substitutor(e => + { + if (map.TryGetValue(e, out var r)) + { + return r; + } + + return null; + }); + + var initBody = mutator.Visit(expr.InitBody, 
Unit.Default); + var body = mutator.Visit(expr.Body, Unit.Default); + + lets[^1].Item1.Body(initBody, body); + return lets[0].Item1.Build(); + } + return T.Sequential(expr.InitBody, expr.Body); } else diff --git a/src/Nncase.Core/PatternMatch/ConstPattern.cs b/src/Nncase.Core/PatternMatch/ConstPattern.cs index cb9a52aebe..bca5ec63a1 100644 --- a/src/Nncase.Core/PatternMatch/ConstPattern.cs +++ b/src/Nncase.Core/PatternMatch/ConstPattern.cs @@ -66,9 +66,9 @@ public static partial class Utility public static TensorConstPattern IsConst(string? name, Func cond) => new( x => { - if (DataTypes.IsFloat(x.ValueType.DType)) + if (DataTypes.IsFloat(x.CheckedDataType)) { - if (x.ValueType.IsScalar) + if (x.CheckedShape.IsScalar) { return cond(x.Value.ToScalar()); } @@ -93,9 +93,9 @@ public static partial class Utility public static TensorConstPattern IsConst(string? name, Func cond) => new( x => { - if (DataTypes.IsIntegral(x.ValueType.DType)) + if (DataTypes.IsIntegral(x.CheckedDataType)) { - if (x.ValueType.IsScalar) + if (x.CheckedShape.IsScalar) { return cond(x.Value.ToScalar()); } diff --git a/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs b/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs index 60d79a6b89..2add795b71 100644 --- a/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs +++ b/src/Nncase.Core/TIR/Builders/NestBodyExprBuilder.cs @@ -51,7 +51,7 @@ public T Build() public ISequentialBuilder InsertBody(int index, params object[] exprOrBuilders) { - _subBuilders[_subBuilders.Length - 1].InsertBody(index, exprOrBuilders); + _subBuilders[index < 0 ? 
_subBuilders.Length + index : index].Body(exprOrBuilders); return this; } } diff --git a/src/Nncase.Core/TIR/Script.cs b/src/Nncase.Core/TIR/Script.cs index 28740e43ab..ffda2527df 100644 --- a/src/Nncase.Core/TIR/Script.cs +++ b/src/Nncase.Core/TIR/Script.cs @@ -134,8 +134,16 @@ public static ISequentialBuilder Grid(out Var[] loopVars, LoopMode loopMode { string[] names = { "i", "j", "k", "l" }; var newLoopVars = loopVars = new Var[ranges.Length]; - return new NestBodyExprBuilder(ranges.Select((rg, i) => - T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray()); + var newLoops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray(); + return new NestBodyExprBuilder(newLoops); + } + + public static ISequentialBuilder Grid(out Var[] loopVars, out ISequentialBuilder[] loops, LoopMode loopMode, params TIR.Range[] ranges) + { + string[] names = { "i", "j", "k", "l" }; + var newLoopVars = loopVars = new Var[ranges.Length]; + var newLoops = loops = ranges.Select((rg, i) => T.ForLoop(out newLoopVars[i], rg, loopMode, names[i % 4] + (i / 4 == 0 ? string.Empty : (i / 4).ToString())).Body()).ToArray(); + return new NestBodyExprBuilder(loops); } /// @@ -223,6 +231,49 @@ public static Buffer CreateBuffer(TensorType tensorType, MemoryLocation location return buffer; } + /// + /// create the buffer by expressions. 
+ /// + public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") + { + if (name.StartsWith("var ")) + { + name = name[4..]; + } + + var strides = TensorUtilities.GetStrides(dimensions); + var size = TensorUtilities.GetProduct(dimensions.ToArray()) * dataType.SizeInBytes; + var memspan = new MemSpan(size, location); + buffer = new Buffer(name, dataType, memspan, dimensions, strides); + return buffer; + } + + public static Buffer CreateBuffer(DataType dataType, Expr[] dimensions, Expr[] strides, MemSpan memSpan, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") + { + if (name.StartsWith("var ")) + { + name = name[4..]; + } + + buffer = new Buffer(name, dataType, memSpan, dimensions, strides); + return buffer; + } + + public static Buffer AttachBuffer(Expr start, TensorType tensorType, MemoryLocation location, out Buffer buffer, [CallerArgumentExpression("buffer")] string name = "") + { + if (name.StartsWith("var ")) + { + name = name[4..]; + } + + var dimensions = tensorType.Shape.ToValueArray(); + var strides = TensorUtilities.GetStrides(dimensions); + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * tensorType.DType.SizeInBytes; + var memspan = new MemSpan(start, size, location); + buffer = new Buffer(name, tensorType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); + return buffer; + } + /// /// create buffer by const. 
/// @@ -233,11 +284,11 @@ public static Buffer AttachBuffer(TensorConst @const, out Buffer buffer, [Caller name = name[4..]; } - var dimensions = @const.ValueType.Shape.ToValueArray(); + var dimensions = @const.CheckedShape.ToValueArray(); var strides = TensorUtilities.GetStrides(dimensions); - var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.ValueType.DType.SizeInBytes; + var size = (int)TensorUtilities.GetProduct(dimensions.ToArray()) * @const.CheckedDataType.SizeInBytes; var memspan = new MemSpan(IR.F.Buffer.DDrOf(@const), size, MemoryLocation.Rdata); - buffer = new Buffer(name, @const.ValueType.DType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); + buffer = new Buffer(name, @const.CheckedDataType, memspan, dimensions.Select(i => (Expr)i).ToArray(), strides.Select(i => (Expr)i).ToArray()); return buffer; } diff --git a/src/Nncase.Core/TensorUtilities.cs b/src/Nncase.Core/TensorUtilities.cs index 79f658aefa..146e5c6cfa 100644 --- a/src/Nncase.Core/TensorUtilities.cs +++ b/src/Nncase.Core/TensorUtilities.cs @@ -69,7 +69,7 @@ public static Expr GetProduct(ReadOnlySpan dimensions, int startIndex = 0) for (int i = startIndex; i < dimensions.Length; i++) { var dimension = dimensions[i]; - product *= IR.F.Math.Require(dimension >= 0, dimension, "Dimension is out of range."); + product *= dimension; } return product; diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs index 2061a40958..eb4a84be0d 100644 --- a/src/Nncase.Core/Utilities/DistributedUtility.cs +++ b/src/Nncase.Core/Utilities/DistributedUtility.cs @@ -2,6 +2,7 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. 
using System.Diagnostics.CodeAnalysis; +using NetFabric.Hyperlinq; using Nncase.IR; namespace Nncase.Utilities; @@ -26,11 +27,7 @@ public static IReadOnlyList> GetLeafCandidateNDSBPs(TensorType tens ndsbps.Add(ndsbp); } - return ndsbps.CartesianProduct(). - Select(ndsbp => ndsbp.ToArray()). - Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)). - Select(ndsbp => new IRArray(ndsbp)). - ToArray(); + return ndsbps.CartesianProduct().Select(ndsbp => ndsbp.ToArray()).Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).Select(ndsbp => new IRArray(ndsbp)).ToArray(); } public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedType distributedType) @@ -65,11 +62,7 @@ public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedT } } - return candidateNdsbps.CartesianProduct(). - Select(ndsbp => ndsbp.ToArray()). - Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)). - Select(ndsbp => new IRArray(ndsbp)). - ToArray(); + return candidateNdsbps.CartesianProduct().Select(ndsbp => ndsbp.ToArray()).Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).Select(ndsbp => new IRArray(ndsbp)).ToArray(); } public static bool IsDistributable(TensorType tensorType, ReadOnlySpan ndsbp, Placement placement) @@ -131,24 +124,74 @@ public static Expr[] TryGetNonUniformDividedShape(DistributedType distributedTyp } return hierarchies.Select((divs, axis) => + { + Expr dim; + if (divs.Any()) { - Expr dim; - if (divs.Any()) + var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray()); + var (res, rem) = Math.DivRem(shape[axis], divsor); + if (rem == 0) { - var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray()); - var (res, rem) = Math.DivRem(shape[axis], divsor); - dim = IR.F.Math.Select( - TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1), - 
res, - res + rem); + return res; } - else + + dim = IR.F.Math.Select( + TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1), + res, + res + rem); + } + else + { + dim = distributedType.TensorType.Shape[axis].FixedValue; + } + + return dim; + }).ToArray(); + } + + public static List TryGetNonUniformDividedSlice(DistributedType distributedType) + { + var shape = distributedType.TensorType.Shape.ToValueArray(); + var hierarchies = Enumerable.Range(0, shape.Length).Select(i => new List()).ToArray(); + for (int i = 0; i < distributedType.NdSBP.Count; i++) + { + if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis }) + { + hierarchies[axis].Add(i); + } + } + + var spliList = hierarchies.Select, int[]>((divs, axis) => + { + int[] dim; + if (divs.Any()) + { + var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray()); + var (res, rem) = Math.DivRem(shape[axis], divsor); + if (rem == 0) { - dim = distributedType.TensorType.Shape[axis].FixedValue; + return new[] { res }; } - return dim; - }).ToArray(); + dim = new[] { res, res + rem }; + } + else + { + dim = distributedType.TensorType.Shape.ToValueArray().Skip(axis).Take(1).ToArray(); + } + + return dim; + }).ToList(); + + IEnumerable ret = new[] { Array.Empty() }; + foreach (int[] array in spliList) + { + ret = from seq in ret + from item in array + select seq.Concat(new[] { item }).ToArray(); + } + + return ret.ToList(); } public static bool IsDivideBy(int input, int divisor) @@ -174,17 +217,14 @@ public static bool IsDivideExactly(int input, int divisor) public static float GetDividedTensorEfficiency(DistributedType distributedType, int burstLength) { var (tiles, shape) = GetDividedTile(distributedType); - return Enumerable.Range(0, tiles.Count). - Select(i => tiles[i].Ranges(0, shape[i])). - CartesianProduct(). 
- Select(rgs => - { - var slice = rgs.ToArray(); - var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToArray(), slice, out var contiguousStart); - var size = TensorUtilities.GetProduct(tiles.ToArray(), contiguousStart) * distributedType.TensorType.DType.SizeInBytes; - var (div, rem) = Math.DivRem(size, burstLength); - return ((div * 1.0f) + ((float)rem / burstLength)) / (div + 1); - }).Average(); + return Enumerable.Range(0, tiles.Count).Select(i => tiles[i].Ranges(0, shape[i])).CartesianProduct().Select(rgs => + { + var slice = rgs.ToArray(); + var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToArray(), slice, out var contiguousStart); + var size = TensorUtilities.GetProduct(tiles.ToArray(), contiguousStart) * distributedType.TensorType.DType.SizeInBytes; + var (div, rem) = Math.DivRem(size, burstLength); + return ((div * 1.0f) + ((float)rem / burstLength)) / (div + 1); + }).Average(); } public static TensorType GetDividedTensorType(DistributedType distributedType) diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs index 4a447073b8..079dafc907 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs @@ -248,6 +248,8 @@ public ILPrintVisitor(TextWriter textWriter, bool display_callable, int indent_l _scope = new(textWriter, indent_level); } + public override string DefaultVisitType(IRType type) => type.ToString(); + /// public override string VisitType(AnyType type) => "any"; diff --git a/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs b/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs index 038873173b..dbb9d54a9a 100644 --- a/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs +++ b/src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs @@ -77,7 +77,7 @@ public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary /// public IRType Visit(ITypeInferenceContext context, 
Allocate target) { - return TensorType.Pointer(target.ElemType.DType); + return TensorType.Pointer(target.ElemType); } } diff --git a/src/Nncase.Evaluator/Buffers/DDrOf.cs b/src/Nncase.Evaluator/Buffers/DDrOf.cs index 86c53e04b7..b329ee1787 100644 --- a/src/Nncase.Evaluator/Buffers/DDrOf.cs +++ b/src/Nncase.Evaluator/Buffers/DDrOf.cs @@ -12,8 +12,13 @@ namespace Nncase.Evaluator.Buffers; [TypeInferGenerator] public partial class DDrOfEvaluator : ITypeInferencer { - private IRType Visit(TensorType input) + private IRType Visit(IRType input) { - return TensorType.Pointer(input.DType); + return input switch + { + DistributedType d => TensorType.Pointer(d.TensorType.DType), + TensorType t => TensorType.Pointer(t.DType), + _ => new InvalidType(input.GetType().Name), + }; } } diff --git a/src/Nncase.Evaluator/Imaging/ResizeImage.cs b/src/Nncase.Evaluator/Imaging/ResizeImage.cs index e25db7b8c0..0e9f1c2e09 100644 --- a/src/Nncase.Evaluator/Imaging/ResizeImage.cs +++ b/src/Nncase.Evaluator/Imaging/ResizeImage.cs @@ -163,6 +163,7 @@ public Cost Visit(ICostEvaluateContext context, ResizeImage target) { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType), [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(returnType), + [CostFactorNames.CPUCycles] = CostUtility.GetCPUCycles(returnType, 4), }; } diff --git a/src/Nncase.Evaluator/Math/Binary.cs b/src/Nncase.Evaluator/Math/Binary.cs index 1ac424b0be..4e8ab2e493 100755 --- a/src/Nncase.Evaluator/Math/Binary.cs +++ b/src/Nncase.Evaluator/Math/Binary.cs @@ -214,6 +214,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => (int)System.Math.Floor((float)a / b), + BinaryOp.CeilDiv => (int)System.Math.Ceiling((float)a / b), BinaryOp.Mod => a % b, BinaryOp.Min => System.Math.Min(a, b), BinaryOp.Max => System.Math.Max(a, b), @@ -227,6 +229,8 @@ private IRType Visit(Binary target, DistributedType 
a, DistributedType b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => (uint)System.Math.Floor((float)a / b), + BinaryOp.CeilDiv => (uint)System.Math.Ceiling((float)a / b), BinaryOp.Mod => a % b, BinaryOp.Min => System.Math.Min(a, b), BinaryOp.Max => System.Math.Max(a, b), @@ -242,6 +246,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => (ulong)System.Math.Floor((float)a / b), + BinaryOp.CeilDiv => (ulong)System.Math.Ceiling((float)a / b), BinaryOp.Mod => a % b, BinaryOp.Min => System.Math.Min(a, b), BinaryOp.Max => System.Math.Max(a, b), @@ -262,6 +268,8 @@ private IRType Visit(Binary target, DistributedType a, DistributedType b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => (long)System.Math.Floor((float)a / b), + BinaryOp.CeilDiv => (long)System.Math.Ceiling((float)a / b), BinaryOp.Mod => a % b, BinaryOp.Min => System.Math.Min(a, b), BinaryOp.Max => System.Math.Max(a, b), @@ -298,6 +306,8 @@ static OrtKISharp.Tensor Mod(OrtKISharp.Tensor a, OrtKISharp.Tensor b) BinaryOp.Sub => a - b, BinaryOp.Mul => a * b, BinaryOp.Div => a / b, + BinaryOp.FloorDiv => OrtKI.Floor(a.Cast(OrtDataType.Float) / b.Cast(OrtDataType.Float)).Cast(a.DataType), + BinaryOp.CeilDiv => OrtKI.Ceil(a.Cast(OrtDataType.Float) / b.Cast(OrtDataType.Float)).Cast(a.DataType), BinaryOp.Mod => Mod(a, b), BinaryOp.Min => OrtKI.Min(new[] { a, b }), BinaryOp.Max => OrtKI.Max(new[] { a, b }), diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs index 1f19b64388..4642a4f8d5 100644 --- a/src/Nncase.Evaluator/Math/MatMul.cs +++ b/src/Nncase.Evaluator/Math/MatMul.cs @@ -23,7 +23,7 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b) { if (VisitTensorType(a.TensorType, b.TensorType) is not TensorType outType) { - return new 
InvalidType(string.Empty); + return new InvalidType($"{a.TensorType} {b.TensorType} not supported"); } if (a.Placement != b.Placement) @@ -162,7 +162,7 @@ public IRType Visit(ITypeInferenceContext context, MatMul target) { (DistributedType a, DistributedType b) => VisitDistributedType(a, b), (TensorType a, TensorType b) => VisitTensorType(a, b), - _ => new InvalidType(string.Empty), + _ => new InvalidType($"{lhs} {rhs} not supported"), }; } diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs index 7488739f1e..4d0b9245f6 100644 --- a/src/Nncase.Evaluator/Tensors/Reshape.cs +++ b/src/Nncase.Evaluator/Tensors/Reshape.cs @@ -40,7 +40,8 @@ public IRType Visit(ITypeInferenceContext context, Reshape target) TensorType tensorType => Visit(context, target, tensorType), DistributedType distributedType => Visit(context, target, distributedType), AnyType => AnyType.Default, - _ => throw new NotImplementedException(), + InvalidType => input, + _ => new InvalidType($"Unsupported input type {input.GetType().Name}"), }; } diff --git a/src/Nncase.Evaluator/Tensors/Where.cs b/src/Nncase.Evaluator/Tensors/Where.cs index db614bc6d9..f45d41b737 100644 --- a/src/Nncase.Evaluator/Tensors/Where.cs +++ b/src/Nncase.Evaluator/Tensors/Where.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.RegularExpressions; using NetFabric.Hyperlinq; using Nncase.CostModel; using Nncase.IR; @@ -45,9 +46,20 @@ public IValue Visit(IEvaluateContext context, Where where) /// public IRType Visit(ITypeInferenceContext context, Where target) { - var cond = context.CheckArgumentType(target, Where.Cond); - var x = context.CheckArgumentType(target, Where.X); - var y = context.CheckArgumentType(target, Where.Y); + var cond = context.CheckArgumentType(target, Where.Cond); + var x = context.CheckArgumentType(target, Where.X); + var y = context.CheckArgumentType(target, Where.Y); + + return (cond, x, y) switch + { + 
(DistributedType a, DistributedType b, DistributedType c) => Visit(a, b, c, target), + (TensorType a, TensorType b, TensorType c) => Visit(a, b, c, target), + _ => new InvalidType(cond.GetType().ToString()), + }; + } + + public IRType Visit(TensorType cond, TensorType x, TensorType y, Where target) + { if (target.IsTfWhere) { return new TensorType(DataTypes.Int64, new Shape(Dimension.Unknown, cond.Shape.Rank)); @@ -56,12 +68,60 @@ public IRType Visit(ITypeInferenceContext context, Where target) return TypeInference.BroadcastType(x.DType, cond, x, y); } + public IRType Visit(DistributedType cond, DistributedType x, DistributedType y, Where target) + { + var invalid = new InvalidType($"{cond}, {x}, {y} not supported"); + if (cond.Placement != x.Placement || x.Placement != y.Placement) + { + return invalid; + } + + if (target.IsTfWhere) + { + return invalid; + } + + var targetType = (TensorType)TypeInference.BroadcastType(x.TensorType.DType, cond.TensorType, x.TensorType, y.TensorType); + if (cond.TensorType.Shape != targetType.Shape) + { + return invalid; + } + + var ndsbp = new SBP[cond.Placement.Rank]; + + for (int i = 0; i < cond.Placement.Rank; i++) + { + switch (cond.NdSBP[i], x.NdSBP[i], y.NdSBP[i]) + { + case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPSplit { Axis: int }): + ndsbp[i] = SBP.S(ic); + break; + case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPSplit { Axis: int }): + ndsbp[i] = SBP.S(ic); + break; + case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPBroadCast): + ndsbp[i] = SBP.S(ic); + break; + case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.S(ic); + break; + case (SBPBroadCast, SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.B; + break; + default: + return invalid; + } + } + + return new DistributedType(targetType, ndsbp, cond.Placement); + } + public Cost Visit(ICostEvaluateContext context, Where target) { - var cond = context.GetArgumentType(target, Where.Cond); - var x = 
context.GetArgumentType(target, Where.X); - var y = context.GetArgumentType(target, Where.Y); - var ret = context.GetReturnType(); + var cond = context.GetArgumentType(target, Where.Cond); + var x = context.GetArgumentType(target, Where.X); + var y = context.GetArgumentType(target, Where.Y); + var ret = context.GetReturnType(); return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(cond, x, y), diff --git a/src/Nncase.Importer/Onnx/QLinearConv.cs b/src/Nncase.Importer/Onnx/QLinearConv.cs index 0ab9ebc99c..3e1bb7e07a 100644 --- a/src/Nncase.Importer/Onnx/QLinearConv.cs +++ b/src/Nncase.Importer/Onnx/QLinearConv.cs @@ -29,13 +29,13 @@ private Expr VisitQLinearConv(in NodeProto op) var group = GetIntAttribute(op, "group", 1); var strides = GetStrideAttribute(op); - int? stridesValueLen = ((TensorConst)strides).ValueType.Shape[0].Value; + int? stridesValueLen = ((TensorConst)strides).CheckedShape[0].Value; for (var i = 0; i < stridesValueLen; i++) { System.Diagnostics.Trace.Assert(((TensorConst)strides).Value.Cast()[i] <= (long)int.MaxValue); } - int? dilationValueLen = ((TensorConst)dilation).ValueType.Shape[0].Value; + int? dilationValueLen = ((TensorConst)dilation).CheckedShape[0].Value; for (var i = 0; i < dilationValueLen; i++) { System.Diagnostics.Trace.Assert(((TensorConst)dilation).Value.Cast()[i] <= (long)int.MaxValue); @@ -63,16 +63,16 @@ private Expr VisitQLinearConv(in NodeProto op) if (bias == null) { - int? ocNumber = ((TensorConst)weights).ValueType.Shape[0].Value; + int? ocNumber = ((TensorConst)weights).CheckedShape[0].Value; var zeroBias = new TensorConst(new int[ocNumber == null ? 
default(int) : ocNumber.Value]); var conv = F.NN.Conv2D(inputDeq, weightsDeq, zeroBias, strideConst, pads, dilationConst, PadMode.Constant, group); - return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).ValueType.DType); + return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType); } else { var biasDeq = Dequantize(bias, new QuantParam(0, ((TensorConst)xScale).Value.ToScalar() * ((TensorConst)wScale).Value.ToScalar()), DataTypes.Float32); var conv = F.NN.Conv2D(inputDeq, weightsDeq, biasDeq, strideConst, pads, dilationConst, PadMode.Constant, group); - return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).ValueType.DType); + return Quantize(conv, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType); } } } diff --git a/src/Nncase.Importer/Onnx/QLinearMatmul.cs b/src/Nncase.Importer/Onnx/QLinearMatmul.cs index 5ab2ad73d8..892a6b6c2d 100644 --- a/src/Nncase.Importer/Onnx/QLinearMatmul.cs +++ b/src/Nncase.Importer/Onnx/QLinearMatmul.cs @@ -25,7 +25,7 @@ private Expr VisitQLinearMatMul(in NodeProto op) var aDeq = Dequantize(input_a, new QuantParam(((TensorConst)aZeroPoint).Value.ToScalar(), ((TensorConst)aScale).Value.ToScalar()), DataTypes.Float32); var bDeq = Dequantize(input_b, new QuantParam(((TensorConst)bZeroPoint).Value.ToScalar(), ((TensorConst)bScale).Value.ToScalar()), DataTypes.Float32); var matmul = F.Tensors.MatMul(aDeq, bDeq); - return Quantize(matmul, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), ((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).ValueType.DType); + return Quantize(matmul, new QuantParam(((TensorConst)yZeroPoint).Value.ToScalar(), 
((TensorConst)yScale).Value.ToScalar()), ((TensorConst)yZeroPoint).CheckedDataType); } } } diff --git a/src/Nncase.Importer/Onnx/Quantize.cs b/src/Nncase.Importer/Onnx/Quantize.cs index cc33583711..6a4c771cb4 100644 --- a/src/Nncase.Importer/Onnx/Quantize.cs +++ b/src/Nncase.Importer/Onnx/Quantize.cs @@ -23,7 +23,7 @@ private Expr VisitQuantizeLinear(in NodeProto op) new QuantParam( biasConst.Value.ToScalar(), scaleConst.Value.ToScalar()), - ((TensorConst)bias).ValueType.DType); + ((TensorConst)bias).CheckedDataType); } throw new NotImplementedException("Onnx importer not impl for dynamic scale and bias"); diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index 80aebda267..15a6505686 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -134,7 +134,7 @@ protected override Expr RewriteLeafBuffer(TIR.Buffer expr) protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan) { - if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const { ValueType: TensorType constType } @const) + if (memSpan is { Location: MemoryLocation.Rdata, Start: Call { Target: IR.Buffers.DDrOf, Arguments: var arg } } && arg[0] is Const @const) { if (!ModuleRdataMaps.TryGetValue(Entry.ModuleKind, out var moduleRdataMap)) { @@ -163,7 +163,7 @@ protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan) Changed = true; } - return memSpan.With(new TensorConst(Tensor.FromPointer((ulong)memRange.Min, constType.DType)), memRange.Max - memRange.Min); + return memSpan.With(new TensorConst(Tensor.FromPointer((ulong)memRange.Min, @const.CheckedDataType)), memRange.Max - memRange.Min); } return memSpan; diff --git a/src/Nncase.Passes/ModulePass.cs b/src/Nncase.Passes/ModulePass.cs index dc35cd0843..25f8edd98f 100644 --- a/src/Nncase.Passes/ModulePass.cs +++ b/src/Nncase.Passes/ModulePass.cs @@ -30,7 +30,7 @@ 
protected override Task OnPassStartAsync(IRModule input, RunPassContext context) { foreach (var func in input.Functions) { - DumpScope.Current.DumpIR(func, func.Name, "Start"); + DumpScope.Current.DumpIR(func, string.Empty, "Start"); } } @@ -44,7 +44,7 @@ protected override Task OnPassEndAsync(IRModule post, RunPassContext context) { foreach (var func in post.Functions) { - DumpScope.Current.DumpIR(func, func.Name, "End"); + DumpScope.Current.DumpIR(func, string.Empty, "End"); } } diff --git a/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs b/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs index 9519b011c7..6d80c05be1 100644 --- a/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs +++ b/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs @@ -36,10 +36,10 @@ public sealed partial class CombineQuantizeConcat : RewriteRule IsConcat( "concat", _ => true, - IsTuple(IsVArgsRepeat("tupleInputs", () => IsWildcard()))), + IsTuple("tuple", IsVArgsRepeat("tupleInputs", () => IsWildcard()))), IsWildcard("quantParam")); - private Expr? GetReplace(Quantize quantize, IReadOnlyList tupleInputs, IR.Tensors.Concat concat, Expr quantParam, RunPassContext options) + private Expr? 
GetReplace(Quantize quantize, IReadOnlyList tupleInputs, IR.Tensors.Concat concat, Expr quantParam, RunPassContext options, Expr tuple) { if (options.Driver is DataflowPass) { @@ -50,7 +50,31 @@ public sealed partial class CombineQuantizeConcat : RewriteRule { if (userAnalysis[e].Count() > 1) { - return null; + foreach (var user in userAnalysis[e]) + { + if (user is Call { Target: Nncase.IR.Math.Quantize } userCall) + { + var quantUser = userCall.Arguments[Nncase.IR.Math.Quantize.QuantParam.Index]; + if (quantUser != quantParam) + { + return null; + } + } + else + { + if (user is not Nncase.IR.Tuple) + { + return null; + } + else + { + if (user != tuple) + { + return null; + } + } + } + } } } } diff --git a/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs b/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs deleted file mode 100644 index 83edfe5e5e..0000000000 --- a/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using Nncase.IR; -using Nncase.IR.NN; -using Nncase.PatternMatch; -using static Nncase.IR.F.NN; -using static Nncase.IR.F.Tensors; -using static Nncase.IR.TypePatternUtility; -using static Nncase.PatternMatch.F.Math; -using static Nncase.PatternMatch.F.NN; -using static Nncase.PatternMatch.F.Tensors; -using static Nncase.PatternMatch.Utility; - -namespace Nncase.Passes.Rules.Neutral; - -/// -/// Fold nop . -/// -[RuleGenerator] -public sealed partial class FoldPrePostReshapeSoftmax : IRewriteRule -{ - /// - public IPattern Pattern { get; } = IsReshape( - "reshape", - "reshapeCall", - _ => true, - IsSoftmax("softmax", IsReshape("rehsape2", "reshapeCall2", _ => true, IsWildcard("input"), IsTensorConst("shape2"))), - IsTensorConst("shape1")); - - private Expr? 
GetReplace(Expr input) - { - return Softmax(input, 3); - } -} diff --git a/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs b/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs index 66b040ec0f..6aef026d82 100644 --- a/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs +++ b/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs @@ -42,7 +42,7 @@ public PytestCalibrationDatasetProvider(IReadOnlyList vars, string dataset) } // group by the samples - _sampleSets = sampleItems.GroupBy(item => item.Number).Select(g => g.OrderBy(item => item.InputIndex).ToArray()).ToArray(); + _sampleSets = sampleItems.GroupBy(item => item.Number).OrderBy(x => x.Key).Select(g => g.OrderBy(item => item.InputIndex).ToArray()).ToArray(); Count = _sampleSets.Length; Samples = _sampleSets.Select(samples => diff --git a/src/Nncase.Quantization/Quantization/QuantUtility.cs b/src/Nncase.Quantization/Quantization/QuantUtility.cs index e05e347492..3a5dcef39d 100644 --- a/src/Nncase.Quantization/Quantization/QuantUtility.cs +++ b/src/Nncase.Quantization/Quantization/QuantUtility.cs @@ -4,10 +4,12 @@ using System; using System.Collections.Generic; using System.Data; +using System.Diagnostics; using System.Linq; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; +using Nncase.Evaluator; using Nncase.IR; using OrtKISharp; using SMath = System.Math; @@ -38,14 +40,12 @@ public static Tensor SquantWeights(Tensor inputWeights, Tensor (long)x.FixedValue).ToArray(); OrtKISharp.Tensor x, delta, zeroPoint; if (inputWeightsShape.Length == 4) { var outChannel = inputWeightsShape[0]; - var inChannel = inputWeightsShape[1]; - var filterH = inputWeightsShape[2]; - var filterW = inputWeightsShape[3]; - x = OrtKISharp.Tensor.MakeTensor(inputWeights.PinBuffer(), OrtDataType.Float, new long[] { outChannel, inChannel, filterH, filterW }); + x = inputWeights.ToOrtTensor(); if (isByChannel) { @@ 
-53,7 +53,7 @@ public static Tensor SquantWeights(Tensor inputWeights, Tensor { var xMin = inputWeightsRanges[c, 0]; var xMax = inputWeightsRanges[c, 1]; @@ -64,10 +64,9 @@ public static Tensor SquantWeights(Tensor inputWeights, Tensor SquantWeights(Tensor inputWeights, Tensor { var xMin = inputWeightsRanges[c, 0]; var xMax = inputWeightsRanges[c, 1]; var deltaTmp = (xMax - xMin) / (qMax - qMin); var zeroPointTmp = System.Math.Round(((xMax * qMin) - (xMin * qMax)) / (xMax - xMin)); + for (int i = 0; i < eachChannelSize; i++) { deltaArr[(c * eachChannelSize) + i] = deltaTmp; zeroPointArr[(c * eachChannelSize) + i] = (float)zeroPointTmp; } - } + }); - delta = OrtKISharp.Tensor.MakeTensor(deltaArr, new long[] { outChannel, inChannel }); - zeroPoint = OrtKISharp.Tensor.MakeTensor(zeroPointArr, new long[] { outChannel, inChannel }); + delta = OrtKISharp.Tensor.MakeTensor(deltaArr, inWShape); + zeroPoint = OrtKISharp.Tensor.MakeTensor(zeroPointArr, inWShape); } else { @@ -111,151 +112,102 @@ public static Tensor SquantWeights(Tensor inputWeights, Tensor(qMin), OrtKISharp.Tensor.FromScalar(qMax)); var xDequant = (xQuant - zeroPoint) * delta; - - return Tensor.From(xDequant.ToArray(), inputWeights.Shape); + var res = Tensor.From(xDequant.ToArray(), inputWeights.Shape); + quantTensor.Dispose(); + xQuant.Dispose(); + xDequant.Dispose(); + return res; } - private static void RoundingForward(float roundingErrorSum, OrtKISharp.Tensor roundingNumber, OrtKISharp.Tensor roundingError, OrtKISharp.Tensor number, OrtKISharp.Tensor error, OrtKISharp.Tensor priority, OrtKISharp.Tensor order, OrtKISharp.Tensor priority1) + private static void RoundingForward(float roundingErrorSum, Span roundingNumberMem, Span roundingErrorMem, Span numberMem, Span errorMem, Span priorityMem, Span orderMem, Span priority1Mem) { - var roundingNumberMem = MemoryMarshal.Cast(roundingNumber.BytesBuffer); - var roundingErrorMem = MemoryMarshal.Cast(roundingError.BytesBuffer); - var priorityMem = 
MemoryMarshal.Cast(priority.BytesBuffer); - var priority1Mem = MemoryMarshal.Cast(priority1.BytesBuffer); int topK = (int)System.Math.Round(System.Math.Abs(roundingErrorSum)); bool overSquant = topK >= System.Math.Abs(roundingErrorSum); if (topK > 0) { - var starts = OrtKISharp.Tensor.MakeTensor(new long[] { 0 }, new long[] { 1 }); - var ends = OrtKISharp.Tensor.MakeTensor(new long[] { topK }, new long[] { 1 }); - var axes = OrtKISharp.Tensor.MakeTensor(new long[] { 0 }, new long[] { 1 }); - var steps = OrtKISharp.Tensor.MakeTensor(new long[] { 1 }, new long[] { 1 }); - - var orderTmp = OrtKI.Slice(order, starts, ends, axes, steps); - - var orderTmpArr = orderTmp.ToArray(); - var orderArr = order.ToArray(); - var errorArr = error.ToArray(); - var numberArr = number.ToArray(); - for (int i = 0; i < orderTmp.Length; i++) + var orderTmpArr = orderMem.Slice(0, topK); + + for (int i = 0; i < orderTmpArr.Length; i++) { - var index = orderTmpArr[i]; - roundingErrorMem[(int)index] = errorArr[index]; - roundingNumberMem[(int)index] = numberArr[index]; + var index = (int)orderTmpArr[i]; + roundingErrorMem[index] = errorMem[index]; + roundingNumberMem[index] = numberMem[index]; } if (overSquant) { - var index = orderArr[topK - 1]; + var index = (int)orderMem[topK - 1]; priority1Mem[index] = System.Math.Abs(roundingErrorMem[index]); } else { - var index = orderArr[topK]; + var index = (int)orderMem[topK]; priorityMem[index] = System.Math.Abs(roundingErrorMem[index]); } } } - private static void SQuantFunc(OrtKISharp.Tensor roundingErrorSum, OrtKISharp.Tensor roundingNumber, OrtKISharp.Tensor roundingError, OrtKISharp.Tensor upNumber, OrtKISharp.Tensor upError, OrtKISharp.Tensor upPriority, OrtKISharp.Tensor upOrder, OrtKISharp.Tensor downNumber, OrtKISharp.Tensor downError, OrtKISharp.Tensor downPriority, OrtKISharp.Tensor downOrder) + private static void SQuantFunc(OrtKISharp.Tensor roundingErrorSum, OrtKISharp.Tensor roundingNumber, OrtKISharp.Tensor roundingError, 
OrtKISharp.Tensor upNumber, OrtKISharp.Tensor upError, OrtKISharp.Tensor upPriority, OrtKISharp.Tensor upOrder, OrtKISharp.Tensor downNumber, OrtKISharp.Tensor downError, OrtKISharp.Tensor downPriority, OrtKISharp.Tensor downOrder, bool getNumberOnly) { - var roundingNumberShape = roundingNumber.Shape; - var batches = roundingNumberShape[0]; - var inputChannel = roundingNumberShape[1]; - long totalSize = 1; - for (int i = 0; i < roundingNumberShape.Length; i++) + var roundingNumberShape = roundingNumber.Shape.Select(x => (int)x).ToArray(); + if (roundingNumberShape.Length != 3) { - totalSize *= roundingNumberShape[i]; + throw new InvalidOperationException("Error"); } - var oneBatchSize = totalSize / batches; - var oneInputChannelSize = oneBatchSize / inputChannel; - - var roundingNumberMem = MemoryMarshal.Cast(roundingNumber.BytesBuffer); - var roundingErrorMem = MemoryMarshal.Cast(roundingError.BytesBuffer); - var upPriorityMem = MemoryMarshal.Cast(upPriority.BytesBuffer); - var downPriorityMem = MemoryMarshal.Cast(downPriority.BytesBuffer); - + var batches = roundingNumberShape[0]; + var inputChannel = roundingNumberShape[1]; + var sizePreChannel = roundingNumberShape[2]; var roundingErrorSumArr = roundingErrorSum.ToArray(); - - for (var n = 0; n < batches; n++) + var loopSize = (long)batches * inputChannel; + Parallel.For(0, loopSize, currentIndex => { - for (var c = 0; c < inputChannel; c++) + var n = currentIndex / inputChannel; + var c = currentIndex % inputChannel; + using var starts = OrtKISharp.Tensor.MakeTensor(new long[] { n, c }, new long[] { 2 }); + using var ends = OrtKISharp.Tensor.MakeTensor(new long[] { n + 1, c + 1 }, new long[] { 2 }); + using var axes = OrtKISharp.Tensor.MakeTensor(new long[] { 0, 1 }, new long[] { 2 }); + using var steps = OrtKISharp.Tensor.MakeTensor(new long[] { 1, 1 }, new long[] { 2 }); + + Span Sl(OrtKISharp.Tensor tensor) { - var starts = OrtKISharp.Tensor.MakeTensor(new long[] { n, c }, new long[] { 2 }); - var ends = 
OrtKISharp.Tensor.MakeTensor(new long[] { n + 1, c + 1 }, new long[] { 2 }); - var axes = OrtKISharp.Tensor.MakeTensor(new long[] { 0, 1 }, new long[] { 2 }); - var steps = OrtKISharp.Tensor.MakeTensor(new long[] { 1, 1 }, new long[] { 2 }); - var roundingNumberTmp = OrtKI.Squeeze(OrtKI.Slice(roundingNumber, starts, ends, axes, steps), axes); - var roundingErrorTmp = OrtKI.Squeeze(OrtKI.Slice(roundingError, starts, ends, axes, steps), axes); - var upNumberSlice = OrtKI.Squeeze(OrtKI.Slice(upNumber, starts, ends, axes, steps), axes); - var upErrorSlice = OrtKI.Squeeze(OrtKI.Slice(upError, starts, ends, axes, steps), axes); - var upOrderSlice = OrtKI.Squeeze(OrtKI.Slice(upOrder, starts, ends, axes, steps), axes); - var downNumberSlice = OrtKI.Squeeze(OrtKI.Slice(downNumber, starts, ends, axes, steps), axes); - var downErrorSlice = OrtKI.Squeeze(OrtKI.Slice(downError, starts, ends, axes, steps), axes); - var downOrderSlice = OrtKI.Squeeze(OrtKI.Slice(downOrder, starts, ends, axes, steps), axes); - - if (roundingErrorSumArr[(n * inputChannel) + c] < 0) - { - var priorityTmp = OrtKI.Squeeze(OrtKI.Slice(upPriority, starts, ends, axes, steps), axes); - var priority1Tmp = OrtKI.Squeeze(OrtKI.Slice(downPriority, starts, ends, axes, steps), axes); - RoundingForward(roundingErrorSumArr[(n * inputChannel) + c], roundingNumberTmp, roundingErrorTmp, upNumberSlice, upErrorSlice, priorityTmp, upOrderSlice, priority1Tmp); - - var roundingNumberTmpArr = roundingNumberTmp.ToArray(); - var roundingErrorTmpArr = roundingErrorTmp.ToArray(); - var priorityTmpArr = priorityTmp.ToArray(); - var priority1TmpArr = priority1Tmp.ToArray(); - for (int i = 0; i < roundingNumberTmp.Length; i++) - { - roundingNumberMem[(n * (int)oneBatchSize) + (c * (int)oneInputChannelSize) + i] = roundingNumberTmpArr[i]; - } - - for (int i = 0; i < roundingErrorTmp.Length; i++) - { - roundingErrorMem[(n * (int)oneBatchSize) + (c * (int)oneInputChannelSize) + i] = roundingErrorTmpArr[i]; - } - - for (int i = 0; i 
< priorityTmp.Length; i++) - { - upPriorityMem[(n * (int)oneBatchSize) + (c * (int)oneInputChannelSize) + i] = priorityTmpArr[i]; - } + var span = MemoryMarshal.Cast(tensor.BytesBuffer); + var begin = currentIndex * sizePreChannel; + return span.Slice((int)begin, sizePreChannel); + } - for (int i = 0; i < priority1Tmp.Length; i++) - { - downPriorityMem[(n * (int)oneBatchSize) + (c * (int)oneInputChannelSize) + i] = priority1TmpArr[i]; - } - } - else - { - var priorityTmp = OrtKI.Squeeze(OrtKI.Slice(downPriority, starts, ends, axes, steps), axes); - var priority1Tmp = OrtKI.Squeeze(OrtKI.Slice(upPriority, starts, ends, axes, steps), axes); - RoundingForward(roundingErrorSumArr[(n * inputChannel) + c], roundingNumberTmp, roundingErrorTmp, downNumberSlice, downErrorSlice, priorityTmp, downOrderSlice, priority1Tmp); - - var roundingNumberTmpArr = roundingNumberTmp.ToArray(); - var roundingErrorTmpArr = roundingErrorTmp.ToArray(); - var priorityTmpArr = priorityTmp.ToArray(); - var priority1TmpArr = priority1Tmp.ToArray(); - for (int i = 0; i < roundingNumberTmp.Length; i++) - { - roundingNumberMem[(n * (int)oneBatchSize) + (c * (int)oneInputChannelSize) + i] = roundingNumberTmpArr[i]; - } + Span SlInt(OrtKISharp.Tensor tensor) + { + var span = MemoryMarshal.Cast(tensor.BytesBuffer); + var begin = currentIndex * sizePreChannel; + return span.Slice((int)begin, sizePreChannel); + } - for (int i = 0; i < roundingErrorTmp.Length; i++) - { - roundingErrorMem[(n * (int)oneBatchSize) + (c * (int)oneInputChannelSize) + i] = roundingErrorTmpArr[i]; - } + var roundingNumberTmp = Sl(roundingNumber); + var roundingErrorTmp = Sl(roundingError); - for (int i = 0; i < priorityTmp.Length; i++) - { - downPriorityMem[(n * (int)oneBatchSize) + (c * (int)oneInputChannelSize) + i] = priorityTmpArr[i]; - } + var upNumberSlice = Sl(upNumber); + var upErrorSlice = Sl(upError); + var upOrderSlice = SlInt(upOrder); + var downNumberSlice = Sl(downNumber); + var downErrorSlice = Sl(downError); + 
var downOrderSlice = SlInt(downOrder); - for (int i = 0; i < priority1Tmp.Length; i++) - { - upPriorityMem[(n * (int)oneBatchSize) + (c * (int)oneInputChannelSize) + i] = priority1TmpArr[i]; - } - } + Span priorityTmp; + Span priority1Tmp; + if (roundingErrorSumArr[currentIndex] < 0) + { + priorityTmp = Sl(upPriority); + priority1Tmp = Sl(downPriority); + RoundingForward(roundingErrorSumArr[currentIndex], roundingNumberTmp, roundingErrorTmp, upNumberSlice, upErrorSlice, priorityTmp, upOrderSlice, priority1Tmp); } - } + else + { + priorityTmp = Sl(downPriority); + priority1Tmp = Sl(upPriority); + RoundingForward(roundingErrorSumArr[currentIndex], roundingNumberTmp, roundingErrorTmp, downNumberSlice, downErrorSlice, priorityTmp, downOrderSlice, priority1Tmp); + } + }); } private static OrtKISharp.Tensor AdaptiveRound(OrtKISharp.Tensor x, float tMin, float tMax) @@ -295,11 +247,13 @@ private static OrtKISharp.Tensor AdaptiveRound(OrtKISharp.Tensor x, float tMin, if (squantK) { var roundingErrorSum = OrtKI.ReduceSum(OrtKI.Reshape(roundingError, converShape, 0), new long[] { -1 }, 0, 0); - var upPriorityK = OrtKI.Reshape(upPriority, converShape, 0).Shape[OrtKI.Reshape(upPriority, converShape, 0).Shape.Length - 1]; - var sortRet = OrtKI.TopK(OrtKI.Reshape(upPriority, converShape, 0), OrtKISharp.Tensor.MakeTensor(new long[] { upPriorityK }, new long[] { 1 }), -1, 1, 1); + var reshapeUpPriority = OrtKI.Reshape(upPriority, converShape, 0); + var upPriorityK = reshapeUpPriority.Shape[^1]; + var sortRet = OrtKI.TopK(reshapeUpPriority, OrtKISharp.Tensor.MakeTensor(new long[] { upPriorityK }, new long[] { 1 }), -1, 1, 1); var upOrder = sortRet[1]; - var downPriorityK = OrtKI.Reshape(downPriority, converShape, 0).Shape[OrtKI.Reshape(downPriority, converShape, 0).Shape.Length - 1]; - sortRet = OrtKI.TopK(OrtKI.Reshape(downPriority, converShape, 0), OrtKISharp.Tensor.MakeTensor(new long[] { downPriorityK }, new long[] { 1 }), -1, 1, 1); + var reshapeDownPriority = 
+        var reshapeDownPriority = OrtKI.Reshape(downPriority, converShape, 0);
+        var downPriorityK = reshapeDownPriority.Shape[^1];
+        sortRet = OrtKI.TopK(reshapeDownPriority, OrtKISharp.Tensor.MakeTensor(new long[] { downPriorityK }, new long[] { 1 }), -1, 1, 1);
         var downOrder = sortRet[1];
         upPriority *= 0.0f;
         downPriority *= 0.0f;
@@ -312,7 +266,7 @@ private static OrtKISharp.Tensor AdaptiveRound(OrtKISharp.Tensor x, float tMin,
         downNumber = OrtKI.Reshape(downNumber, converShape, 0);
         downError = OrtKI.Reshape(downError, converShape, 0);
         downPriority = OrtKI.Reshape(downPriority, converShape, 0);
-        SQuantFunc(roundingErrorSum, roundingNumber, roundingError, upNumber, upError, upPriority, upOrder, downNumber, downError, downPriority, downOrder);
+        SQuantFunc(roundingErrorSum, roundingNumber, roundingError, upNumber, upError, upPriority, upOrder, downNumber, downError, downPriority, downOrder, false);
         roundingNumber = OrtKI.Reshape(roundingNumber, x.Shape, 0);
         roundingError = OrtKI.Reshape(roundingError, x.Shape, 0);
         upPriority = OrtKI.Reshape(upPriority, x.Shape, 0);
@@ -323,11 +277,13 @@ private static OrtKISharp.Tensor AdaptiveRound(OrtKISharp.Tensor x, float tMin,
     {
         converShape = new long[] { 1, x.Shape[0], -1 };
         var roundingErrorSum = OrtKI.ReduceSum(OrtKI.Reshape(roundingError, converShape, 0), new long[] { -1 }, 0, 0);
-        var upPriorityK = OrtKI.Reshape(upPriority, converShape, 0).Shape[OrtKI.Reshape(upPriority, converShape, 0).Shape.Length - 1];
-        var sortRet = OrtKI.TopK(OrtKI.Reshape(upPriority, converShape, 0), OrtKISharp.Tensor.MakeTensor(new long[] { upPriorityK }, new long[] { 1 }), -1, 1, 1);
+        var reshapePriority = OrtKI.Reshape(upPriority, converShape, 0);
+        var upPriorityK = reshapePriority.Shape[^1];
+        var sortRet = OrtKI.TopK(reshapePriority, OrtKISharp.Tensor.MakeTensor(new long[] { upPriorityK }, new long[] { 1 }), -1, 1, 1);
         var upOrder = sortRet[1];
-        var downPriorityK = OrtKI.Reshape(downPriority, converShape, 0).Shape[OrtKI.Reshape(downPriority, converShape, 0).Shape.Length - 1];
+        var reshapeDownPriority = OrtKI.Reshape(downPriority, converShape, 0);
+        var downPriorityK = reshapeDownPriority.Shape[^1];
+        sortRet = OrtKI.TopK(reshapeDownPriority, OrtKISharp.Tensor.MakeTensor(new long[] { downPriorityK }, new long[] { 1 }), -1, 1, 1);
         var downOrder = sortRet[1];

         roundingNumber = OrtKI.Reshape(roundingNumber, converShape, 0);
@@ -338,13 +294,10 @@ private static OrtKISharp.Tensor AdaptiveRound(OrtKISharp.Tensor x, float tMin,
         downNumber = OrtKI.Reshape(downNumber, converShape, 0);
         downError = OrtKI.Reshape(downError, converShape, 0);
         downPriority = OrtKI.Reshape(downPriority, converShape, 0);
-        SQuantFunc(roundingErrorSum, roundingNumber, roundingError, upNumber, upError, upPriority, upOrder, downNumber, downError, downPriority, downOrder);
+        SQuantFunc(roundingErrorSum, roundingNumber, roundingError, upNumber, upError, upPriority, upOrder, downNumber, downError, downPriority, downOrder, true);
     }

     roundingNumber = OrtKI.Reshape(roundingNumber, x.Shape, 0);
-    _ = OrtKI.Reshape(roundingError, x.Shape, 0);
-    _ = OrtKI.Reshape(upPriority, x.Shape, 0);
-    _ = OrtKI.Reshape(downPriority, x.Shape, 0);
     return roundingNumber!;
 }
diff --git a/src/Nncase.Quantization/Quantization/Quantizer.cs b/src/Nncase.Quantization/Quantization/Quantizer.cs
index c9b5fa9daa..257ee001eb 100644
--- a/src/Nncase.Quantization/Quantization/Quantizer.cs
+++ b/src/Nncase.Quantization/Quantization/Quantizer.cs
@@ -325,25 +325,21 @@ public async Task RunAsync(RunPassContext options)
         _graph.Rebuild();
     }

-    private async Task RunPassAsync(ICalibrationDatasetProvider calibrationDataset, Action, IReadOnlyDictionary> func)
+    private async Task RunForHistogramsAsync(ICalibrationDatasetProvider calibrationDataset, Action> func)
     {
         await foreach (var sample in calibrationDataset.Samples)
         {
-            IReadOnlyDictionary values, childrenValues;
-            using (var dumpScope = new DumpScope("ep1"))
-            {
-                var evaluator = new CalibrationEvaluator(sample, _rangeOfs);
-                values = evaluator.Evaluate();
-            }
-
+            IReadOnlyDictionary childrenValues;
             using (var dumpScope2 = new DumpScope("ep2"))
             {
                 var childrenEvaluator = new CalibrationEvaluator(sample, _childrenOfRangeOfs);
-                childrenValues = childrenEvaluator.Evaluate();
+                var tmpChildrenValues = childrenEvaluator.Evaluate().ToList();
+                childrenValues = tmpChildrenValues.Zip(_rangeOfs).ToDictionary(pair => pair.Second, pair => pair.First.Value);
             }

             // values are children op range values(only two scalars for each value: Min and Max), childrenValues are children op tensor values.
-            func(values, childrenValues);
+            func(childrenValues);
+            GC.Collect();
         }
     }

@@ -355,6 +351,7 @@ private async Task RunPassAsync(ICalibrationDatasetProvider calibrationDataset,
             var evaluator = new CalibrationEvaluator(sample, _rangeOfs);
             var values = evaluator.Evaluate();
             func(values);
+            GC.Collect();
         }
     }

@@ -511,26 +508,24 @@ private void AssignDataTypeFromConfig(QuantScheme quantScheme)
     private async Task>> GetHistogramsAsync(ICalibrationDatasetProvider calibrationDataset, IDictionary> ranges, int srcBinSize, int dstBinSize)
     {
         var histograms = new Dictionary>(ReferenceEqualityComparer.Instance);
-        await RunPassAsync(calibrationDataset, (values, childrenValues) =>
+        foreach (var (key, value) in ranges)
         {
-            var valuesList = values.ToList();
-            var childrenValuesList = childrenValues.ToList();
-            for (int i = 0; i < valuesList.Count; i++)
+            var initSrcBin = new List<float>(new float[srcBinSize]);
+            var initDstBin = new List<float>(new float[dstBinSize]);
+            histograms[key] = new QuantizeHistogram<float>(initSrcBin, initDstBin);
+        }
+
+        await RunForHistogramsAsync(calibrationDataset, childrenValues =>
+        {
+            foreach (var (key, value) in childrenValues)
             {
-                var r = ranges[valuesList[i].Key].Max - ranges[valuesList[i].Key].Min;
+                var r = ranges[key].Max - ranges[key].Min;
                 var srcBinInterval = r / srcBinSize;
-                if (!histograms.TryGetValue(valuesList[i].Key, out var histogram))
-                {
-                    var initSrcBin = new List<float>(new float[srcBinSize]);
-                    var initDstBin = new List<float>(new float[dstBinSize]);
-                    histogram = new QuantizeHistogram<float>(initSrcBin, initDstBin);
-                    histograms.Add(valuesList[i].Key, histogram);
-                }
-                var childrenTensor = childrenValuesList[i].Value.Cast<float>();
+                var childrenTensor = value.Cast<float>();
                 var childrenBuffer = childrenTensor.Buffer.Span;
-                var valueRange = ranges[valuesList[i].Key];
-
+                var valueRange = ranges[key];
+                var histogram = histograms[key];
                 foreach (var buf in childrenBuffer)
                 {
                     var r_index = (buf - valueRange.Min) / srcBinInterval;
diff --git a/src/Nncase.Studio/Assets/calib.png b/src/Nncase.Studio/Assets/calib.png
new file mode 100644
index 0000000000..b1524fda7e
Binary files /dev/null and b/src/Nncase.Studio/Assets/calib.png differ
diff --git a/src/Nncase.Studio/Util/DataUtil.cs b/src/Nncase.Studio/Util/DataUtil.cs
index 3144a5c518..c15ffcf2da 100644
--- a/src/Nncase.Studio/Util/DataUtil.cs
+++ b/src/Nncase.Studio/Util/DataUtil.cs
@@ -45,15 +45,17 @@ public static (string[] InputFiles, Tensor[] InputList) ReadMultiInputs(List ReadInput(string[] file)
     {
         return file
             .Where(f => Path.GetExtension(f) == ".npy")
-            .Select(f =>
-            {
-                var tensor = np.load(f);
-                return Tensor.FromBytes(new TensorType(DataType.FromType(tensor.dtype), tensor.shape), tensor.ToByteArray());
-            }).ToList();
+            .Select(ReadNumpyAsTensor).ToList();
     }

     public static DataType QuantTypeToDataType(QuantType qt)
diff --git a/src/Nncase.Studio/Util/ViewModelContext.cs b/src/Nncase.Studio/Util/ViewModelContext.cs
index 70e476dafe..c96587302e 100644
--- a/src/Nncase.Studio/Util/ViewModelContext.cs
+++ b/src/Nncase.Studio/Util/ViewModelContext.cs
@@ -32,7 +32,7 @@ public ViewModelContext(MainWindowViewModel windowViewModel)

     public ViewModelBase[] ViewModelBases { get; set; } = Array.Empty<ViewModelBase>();

-    public Function? Entry { get; set; }
+    public Var[] Params { get; set; } = Array.Empty<Var>();

     public async Task> OpenFile(FilePickerOpenOptions options)
     {
diff --git a/src/Nncase.Studio/ViewModels/CompileViewModel.cs b/src/Nncase.Studio/ViewModels/CompileViewModel.cs
index 8977f73139..1d67483ea0 100644
--- a/src/Nncase.Studio/ViewModels/CompileViewModel.cs
+++ b/src/Nncase.Studio/ViewModels/CompileViewModel.cs
@@ -3,6 +3,7 @@

 using System;
 using System.IO;
+using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using System.Windows.Input;
@@ -73,12 +74,13 @@ public async Task Compile()
         var compileSession = CompileSession.Create(target, options);
         var compiler = compileSession.Compiler;
         var module = await compiler.ImportModuleAsync(options.InputFormat, options.InputFile, options.IsBenchmarkOnly);
-        Context.Entry = (Function)module.Entry!;
+        Context.Params = ((Function)module.Entry!).Parameters.ToArray();

         if (options.QuantizeOptions.ModelQuantMode != ModelQuantMode.NoQuant)
         {
             var calib = ((QuantizeViewModel)Context.ViewModelLookup(typeof(QuantizeViewModel))).LoadCalibFiles();
             if (calib == null)
             {
+                Context.OpenDialog("矫正集为空");
                 return;
             }
diff --git a/src/Nncase.Studio/ViewModels/QuantizeViewModel.cs b/src/Nncase.Studio/ViewModels/QuantizeViewModel.cs
index 974e3b147d..130c23c80c 100644
--- a/src/Nncase.Studio/ViewModels/QuantizeViewModel.cs
+++ b/src/Nncase.Studio/ViewModels/QuantizeViewModel.cs
@@ -10,6 +10,7 @@

 using Avalonia.Media.Fonts;
 using CommunityToolkit.Mvvm.ComponentModel;
 using CommunityToolkit.Mvvm.Input;
+using NetFabric.Hyperlinq;
 using Nncase.IR;
 using Nncase.Quantization;
 using Nncase.Studio.Util;
@@ -49,7 +50,7 @@ public partial class QuantizeViewModel : ViewModelBase
     [ObservableProperty]
     private string _calibDir = string.Empty;

-    private string[] _inputFiles = Array.Empty<string>();
+    private string[][] _multiInputFiles = Array.Empty<string[]>();

     public QuantizeViewModel(ViewModelContext context)
     {
@@ -90,6 +91,23 @@ public async Task SelectCalibrationDataSet()
         }

         var inputFiles = Directory.GetFiles(path);
+
+        try
+        {
+            var n = inputFiles.Where(f => Path.GetExtension(f) == ".npy").GroupBy(s => Path.GetFileName(s).Split("_")[0]);
+            _multiInputFiles = n.Select(group =>
+            {
+                var value = group.ToArray();
+                var one = value.OrderBy(s => int.Parse(Path.GetFileName(s).Split("_")[1])).ToArray();
+                return one;
+            }).ToArray();
+        }
+        catch (Exception e)
+        {
+            Context.OpenDialog($"文件夹中的文件解析失败,请检查文件名是否符合格式。\n{e.Message}");
+            return;
+        }
+
         if (inputFiles.Length == 0)
         {
             Context.OpenDialog("empty dir");
@@ -97,37 +115,39 @@
         }

         CalibDir = path;
-        _inputFiles = inputFiles;
+    }
+
+    [RelayCommand]
+    public void ShowCalibFormat()
+    {
+        new QuantizeCalibWindow().Show();
     }

     public ICalibrationDatasetProvider? LoadCalibFiles()
     {
-        Tensor[] input;
         try
         {
-            input = DataUtil.ReadInput(_inputFiles).ToArray();
+            var samples = _multiInputFiles.Select(files =>
+            {
+                var input = files.Select(DataUtil.ReadNumpyAsTensor).ToArray();
+                var sample = Context.Params.Zip(input)
+                    .ToDictionary(pair => pair.First, pair => (IValue)Value.FromTensor(pair.Second));
+                return sample;
+            }).ToArray();
+
+            if (Context.Params.Length == 0)
+            {
+                Context.OpenDialog("Should Import Model first");
+                return null;
+            }
+
+            return new SelfInputCalibrationDatasetProvider(samples);
         }
         catch (Exception e)
         {
             Context.OpenDialog(e.Message);
             return null;
         }
-
-        if (input.Length == 0)
-        {
-            Context.OpenDialog("no file is loaded, only support .npy");
-            return null;
-        }
-
-        if (Context.Entry == null)
-        {
-            Context.OpenDialog("Should Import Model first");
-            return null;
-        }
-
-        var samples = Context.Entry!.Parameters.ToArray().Zip(input)
-            .ToDictionary(pair => pair.First, pair => (IValue)Value.FromTensor(pair.Second));
-        return new SelfInputCalibrationDatasetProvider(samples);
     }

     public override void UpdateViewModelCore(CompileConfig config)
@@ -171,7 +191,7 @@ public override List<string> CheckViewModel()
     {
         if (Directory.Exists(CalibDir))
         {
-            if (_inputFiles.Length == 0)
+            if (_multiInputFiles.Length == 0)
             {
                 list.Add("CalibDir don't exist any .npy file");
             }
@@ -200,9 +220,9 @@ public sealed class SelfInputCalibrationDatasetProvider : ICalibrationDatasetPro

     private readonly IAsyncEnumerable> _samples;

-    public SelfInputCalibrationDatasetProvider(IReadOnlyDictionary sample)
+    public SelfInputCalibrationDatasetProvider(IReadOnlyDictionary[] samples)
     {
-        _samples = new[] { sample }.ToAsyncEnumerable();
+        _samples = samples.ToAsyncEnumerable();
     }

     public int? Count => _count;
diff --git a/src/Nncase.Studio/ViewModels/SimulateViewModel.cs b/src/Nncase.Studio/ViewModels/SimulateViewModel.cs
index c7fc1a62d5..2f6b0521cb 100644
--- a/src/Nncase.Studio/ViewModels/SimulateViewModel.cs
+++ b/src/Nncase.Studio/ViewModels/SimulateViewModel.cs
@@ -60,6 +60,17 @@ public async Task SetRuntimeInput()
         try
         {
             (inputFiles, input) = DataUtil.ReadMultiInputs(path);
+            if (CanBeSort(inputFiles))
+            {
+                var pairList = inputFiles.Zip(input)
+                    .OrderBy(pair => int.Parse(Path.GetFileName(pair.First).Split("_")[0]));
+                inputFiles = pairList.Select(x => x.First).ToArray();
+                input = pairList.Select(x => x.Second).ToArray();
+            }
+            else
+            {
+                Context.OpenDialog("输入文件未排序,可能出现输入无法正确对应的情况");
+            }
         }
         catch (Exception e)
         {
@@ -68,6 +79,14 @@ public async Task SetRuntimeInput()
         }

         UpdateRuntimeInputUI(input, inputFiles);
+
+        bool CanBeSort(string[] inputFilePathList)
+        {
+            var fileNames = inputFilePathList.Select(Path.GetFileName).ToArray();
+            var canBeSort = fileNames.All(x =>
+                x!.Contains("_", StringComparison.Ordinal) && int.TryParse(x.Split("_")[0], out int _));
+            return canBeSort;
+        }
     }

     [RelayCommand]
@@ -163,12 +182,12 @@ public override void UpdateConfig(CompileConfig config)

     private bool CheckInput()
     {
-        if (Context.Entry == null)
+        if (Context.Params.Length == 0)
         {
             return true;
         }

-        var paramList = Context.Entry!.Parameters.ToArray();
+        var paramList = Context.Params!;
         foreach ((var tensor, var param) in RuntimeInput.Zip(paramList))
         {
             var tt = (TensorType)param.TypeAnnotation;
@@ -176,7 +195,7 @@ private bool CheckInput()
             {
                 Context.OpenDialog($"{param.Name} input datatype mismatch");
                 {
-                    return true;
+                    return false;
                 }
             }

@@ -185,12 +204,12 @@ private bool CheckInput()
             {
                 Context.OpenDialog($"{param.Name} input shape mismatch");
                 {
-                    return true;
+                    return false;
                 }
             }
         }

-        return false;
+        return true;
     }

     [RelayCommand]
diff --git a/src/Nncase.Studio/Views/QuantizeCalibWindow.axaml b/src/Nncase.Studio/Views/QuantizeCalibWindow.axaml
new file mode 100644
index 0000000000..418f8a5f47
--- /dev/null
+++ b/src/Nncase.Studio/Views/QuantizeCalibWindow.axaml
@@ -0,0 +1,15 @@
+
+
+ 数据集文件夹中仅会读取npy格式的文件
+ 文件名需要以"第几组数据集_第几个输入_文件名.bin"的格式,例如下图中
+
+
+
diff --git a/src/Nncase.Studio/Views/QuantizeCalibWindow.axaml.cs b/src/Nncase.Studio/Views/QuantizeCalibWindow.axaml.cs
new file mode 100644
index 0000000000..44502dad93
--- /dev/null
+++ b/src/Nncase.Studio/Views/QuantizeCalibWindow.axaml.cs
@@ -0,0 +1,17 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Avalonia;
+using Avalonia.Controls;
+using Avalonia.Markup.Xaml;
+
+namespace Nncase.Studio.Views
+{
+    public partial class QuantizeCalibWindow : Window
+    {
+        public QuantizeCalibWindow()
+        {
+            InitializeComponent();
+        }
+    }
+}
diff --git a/src/Nncase.Studio/Views/QuantizeView.axaml b/src/Nncase.Studio/Views/QuantizeView.axaml
index 68ac94469e..ba0a0d010e 100644
--- a/src/Nncase.Studio/Views/QuantizeView.axaml
+++ b/src/Nncase.Studio/Views/QuantizeView.axaml
@@ -42,23 +42,27 @@
             Content="选择"
             x:CompileBindings="False"
             Command="{Binding SelectCalibrationDataSetCommand}">
-
-
-
-
-
-
-
-
+
+
-
-
-
-
+
+
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Nncase.Studio/Views/SimulateView.axaml b/src/Nncase.Studio/Views/SimulateView.axaml
index 4e6f895ff6..b6eca40cf3 100644
--- a/src/Nncase.Studio/Views/SimulateView.axaml
+++ b/src/Nncase.Studio/Views/SimulateView.axaml
@@ -20,16 +20,19 @@
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
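Reviewer note on the `AdaptiveRound`/`SQuantFunc` hunks in `Squant.cs`: the pass rounds each weight to the nearest integer, then uses the TopK-ordered priorities (`upOrder`/`downOrder`, ranked by rounding-error magnitude) to flip individual roundings until the accumulated rounding error (`roundingErrorSum`) is roughly cancelled; the new boolean argument only selects per-tensor vs. per-channel reshape layout. A rough Python sketch of the flip idea — the function name and the half-LSB stopping threshold are illustrative, not the nncase API:

```python
def squant_flip(x):
    """Round each value to nearest, then flip roundings (up<->down) in
    priority order until the summed rounding error is within half an LSB."""
    rounded = [round(v) for v in x]
    errors = [r - v for r, v in zip(rounded, x)]  # per-element rounding error
    total = sum(errors)                           # ~ roundingErrorSum
    # Largest individual error first -- the "TopK order" in the C# code.
    order = sorted(range(len(x)), key=lambda i: -abs(errors[i]))
    for i in order:
        if abs(total) <= 0.5:
            break  # residual error is at most half an LSB
        if total > 0 and errors[i] > 0:
            rounded[i] -= 1   # was rounded up: flip down
            total -= 1.0
        elif total < 0 and errors[i] < 0:
            rounded[i] += 1   # was rounded down: flip up
            total += 1.0
    return rounded
```

With `[0.4, 0.4, 0.4]` the naive rounding `[0, 0, 0]` has summed error −1.2, so one element is flipped up and the quantized sum tracks the real sum much more closely.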
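On the `GetHistogramsAsync` rewrite in `Quantizer.cs`: histograms are now pre-allocated once per range key, and every calibration sample accumulates into the same fixed-range bins via `r_index = (buf - valueRange.Min) / srcBinInterval`. A minimal Python sketch of that accumulation — names are hypothetical, and the clamp for values on the upper edge is an assumption about how out-of-range values are binned:

```python
def build_histogram(values, vmin, vmax, bins):
    """Accumulate values into `bins` equal-width bins over [vmin, vmax]."""
    interval = (vmax - vmin) / bins   # ~ srcBinInterval
    hist = [0] * bins
    for v in values:
        idx = int((v - vmin) / interval)   # ~ r_index
        # values exactly at vmax (or outliers) land in an edge bin
        idx = min(max(idx, 0), bins - 1)
        hist[idx] += 1
    return hist
```

Calling this repeatedly on the same `hist` per key mirrors the new flow: the per-key histogram object lives across samples, so each calibration batch only adds counts.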
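On `SelectCalibrationDataSet` in `QuantizeViewModel.cs`: calibration files named `<set index>_<input index>_<name>.npy` are grouped by their leading set index and ordered by input index, so each group becomes one sample (one tensor per model input). A Python approximation of that grouping — a sketch, not the Studio code; unlike the C# `GroupBy`, it also sorts the sets numerically:

```python
def group_calib_files(names):
    """Group '<set>_<input>_<name>.npy' files by set, inputs in index order."""
    npy = [n for n in names if n.endswith(".npy")]
    groups = {}
    for n in npy:
        set_idx, input_idx = n.split("_")[0], n.split("_")[1]
        groups.setdefault(int(set_idx), []).append((int(input_idx), n))
    # one list of file names per calibration set, inputs in order
    return [[name for _, name in sorted(files)]
            for _, files in sorted(groups.items())]
```

Files that do not match the pattern raise `ValueError` from `int(...)`, which corresponds to the `catch` branch that shows the "文件夹中的文件解析失败" dialog.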
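On `SetRuntimeInput` in `SimulateViewModel.cs`: runtime input files are reordered by the numeric prefix of their file name, but only when `CanBeSort` confirms every name starts with `<number>_`; otherwise the user is warned that inputs may not line up with model parameters. A Python sketch of the same check-then-sort, using `str.isdigit` as a stand-in for C#'s `int.TryParse`:

```python
import os

def sort_runtime_inputs(paths):
    """Order input files by their leading numeric index, if every name has one."""
    names = [os.path.basename(p) for p in paths]
    sortable = all("_" in n and n.split("_")[0].isdigit() for n in names)
    if not sortable:
        # caller should warn: order may not match the model's input order
        return paths, False
    ordered = sorted(paths, key=lambda p: int(os.path.basename(p).split("_")[0]))
    return ordered, True
```

Returning the original list plus a flag mirrors the C# flow, where the unsorted inputs are still used after the warning dialog.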