diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/FusionMerger.cs b/modules/Nncase.Modules.CPU/Passes/Rules/FusionMerger.cs new file mode 100644 index 0000000000..0b1d9923ab --- /dev/null +++ b/modules/Nncase.Modules.CPU/Passes/Rules/FusionMerger.cs @@ -0,0 +1,81 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http.Headers; +using System.Reactive; +using System.Text; +using System.Threading.Tasks; +using DryIoc.ImTools; +using Google.OrTools.LinearSolver; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.IR.Math; +using Nncase.IR.NN; +using Nncase.IR.Tensors; +using Nncase.Passes.Rules.Neutral; +using Nncase.PatternMatch; +using Nncase.Targets; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.Utility; +using static Nncase.Utilities.ReplaceUtility; + +namespace Nncase.Passes.Rules; + +/// +/// Shared merger for the MHA, Unet and VAE fusion rules. 
+/// +public sealed class FusionMerger : ExprCloner +{ + private readonly IReadOnlyDictionary _multiVarMap; + + public FusionMerger(IReadOnlyDictionary multiVarMap) + { + _multiVarMap = multiVarMap; + } + + protected override Expr VisitCall(Call expr, Unit context) + { + if (_multiVarMap.TryGetValue(expr, out var newVar)) + { + return newVar; + } + + return base.VisitCall(expr, context); + } + + protected override Expr VisitLeafCall(Call expr, Unit context) + { + var target = Clone(expr.Target, context); + var arguments = CloneArray(expr.Arguments, context); + if (target is Binary) + { + arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray(); + } + + if (target is Conv2D conv) + { + var bias = (TensorConst)arguments[2]; + var fusedClamp = ((TensorConst)arguments[7]).Value.ToArray(); + var newConv = IR.F.NN.Conv2D(arguments[0], arguments[1], Tensor.Zeros(bias.CheckedShape), arguments[3], arguments[4], arguments[5], conv.PadMode, arguments[6], new[] { float.NegativeInfinity, float.PositiveInfinity }); + var newBias = IR.F.Math.Add(newConv, Tensor.FromBytes(bias.CheckedDataType, bias.Value.BytesBuffer.ToArray(), new[] { bias.CheckedShape[0].FixedValue, 1, 1 })); + var newClamp = IR.F.Math.Clamp(newBias, fusedClamp[0], fusedClamp[1]); + return newClamp; + } + + return expr.With(target: target, arguments: arguments); + } + + protected override Expr VisitLeafVar(Var expr, Unit context) + { + if (_multiVarMap.TryGetValue(expr, out var newVar)) + { + return newVar; + } + + throw new InvalidOperationException(); + } +} + diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/MHAFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/MHAFusion.cs index 74ca6d7929..cd1d6040a7 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/MHAFusion.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/MHAFusion.cs @@ -23,51 
+23,6 @@ namespace Nncase.Passes.Rules; -/// -/// MHA Merger for all. -/// -public sealed class MHAMerger : ExprCloner -{ - private readonly IReadOnlyDictionary _multiVarMap; - - public MHAMerger(IReadOnlyDictionary multiVarMap) - { - _multiVarMap = multiVarMap; - } - - protected override Expr VisitCall(Call expr, Unit context) - { - if (_multiVarMap.TryGetValue(expr, out var newVar)) - { - return newVar; - } - - return base.VisitCall(expr, context); - } - - protected override Expr VisitLeafCall(Call expr, Unit context) - { - var target = Clone(expr.Target, context); - var arguments = CloneArray(expr.Arguments, context); - if (target is Binary) - { - arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray(); - } - - return expr.With(target: target, arguments: arguments); - } - - protected override Expr VisitLeafVar(Var expr, Unit context) - { - if (_multiVarMap.TryGetValue(expr, out var newVar)) - { - return newVar; - } - - throw new InvalidOperationException(); - } -} - // pattern from BERT base [RuleGenerator] public sealed partial class FuseMHA1 : FusionMaker @@ -248,7 +203,7 @@ private static Pattern CreatePattern() { position_ids, (Var)newInputs[1] }, { attn_mask, (Var)newInputs[2] }, }; - var merger = new MHAMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("MHALLaMA65B", $"{nameof(FuseMHA2)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), hidden_in, position_ids, attn_mask); @@ -308,7 +263,7 @@ private static Pattern CreatePattern() { { input, (Var)newInputs[0] }, }; - var merger = new MHAMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("SDTextEncoderMHA", 
$"{nameof(FuseMHA3)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input); @@ -345,7 +300,7 @@ private static Pattern CreatePattern() { { input, newInputs[0] }, }; - var merger = new MHAMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("SDTextEncoderHeader", $"{nameof(FuseSDTextEncoderHeader)}_{Count++}", ModuleKind, clonedRoot, newInputs.ToArray()), input); @@ -389,7 +344,7 @@ private static Pattern CreatePattern() { input, newInputs[1] }, }; - var merger = new MHAMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("SDTextEncoderTail", $"{nameof(FuseSDTextEncoderTail)}_{Count++}", ModuleKind, clonedRoot, newInputs), input_ids, input); diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/UnetFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/UnetFusion.cs index ae382f985b..ae30765b80 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/UnetFusion.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/UnetFusion.cs @@ -27,61 +27,6 @@ namespace Nncase.Passes.Rules; -/// -/// Unet Merger for all. 
-/// -public sealed class UnetMerger : ExprCloner -{ - private readonly IReadOnlyDictionary _multiVarMap; - - public UnetMerger(IReadOnlyDictionary multiVarMap) - { - _multiVarMap = multiVarMap; - } - - protected override Expr VisitCall(Call expr, Unit context) - { - if (_multiVarMap.TryGetValue(expr, out var newVar)) - { - return newVar; - } - - return base.VisitCall(expr, context); - } - - protected override Expr VisitLeafCall(Call expr, Unit context) - { - var target = Clone(expr.Target, context); - var arguments = CloneArray(expr.Arguments, context); - if (target is Binary) - { - arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray(); - } - - if (target is Conv2D conv) - { - var bias = (TensorConst)arguments[2]; - var fusedClamp = ((TensorConst)arguments[7]).Value.ToArray(); - var newConv = IR.F.NN.Conv2D(arguments[0], arguments[1], Tensor.Zeros(bias.CheckedShape), arguments[3], arguments[4], arguments[5], conv.PadMode, arguments[6], new[] { float.NegativeInfinity, float.PositiveInfinity }); - var newBias = IR.F.Math.Add(newConv, Tensor.FromBytes(bias.CheckedDataType, bias.Value.BytesBuffer.ToArray(), new[] { bias.CheckedShape[0].FixedValue, 1, 1 })); - var newClamp = IR.F.Math.Clamp(newBias, fusedClamp[0], fusedClamp[1]); - return newClamp; - } - - return expr.With(target: target, arguments: arguments); - } - - protected override Expr VisitLeafVar(Var expr, Unit context) - { - if (_multiVarMap.TryGetValue(expr, out var newVar)) - { - return newVar; - } - - throw new InvalidOperationException(); - } -} - /// /// stable-disffusion Unet spatial transformer. 
/// @@ -194,7 +139,7 @@ private static Pattern CreatePattern() { input, (Var)newInputs[0] }, { encoderHiddenStates, (Var)newInputs[1] }, }; - var merger = new UnetMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("UnetSpatialTransformer", $"{nameof(FuseUnetSpatialTransformer)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input, encoderHiddenStates); @@ -282,7 +227,7 @@ private static Pattern CreatePattern() multiVarMap.Add(oldInputs[i], (Var)newInputs[i]); } - var merger = new UnetMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("UnetResBlock", $"{nameof(FuseUnetResBlock)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), oldInputs.ToArray()); @@ -336,7 +281,7 @@ private static Pattern CreatePattern() { { input, (Var)newInputs[0] }, }; - var merger = new UnetMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("UnetTimeEmb", $"{nameof(FuseUnetTimeEmb)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input); diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/VAEFusion.cs b/modules/Nncase.Modules.CPU/Passes/Rules/VAEFusion.cs index b43271a7b6..59e22d2587 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/VAEFusion.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/VAEFusion.cs @@ -24,51 +24,6 @@ namespace Nncase.Passes.Rules; -/// -/// VAE Merger for all. 
-/// -public sealed class VAEMerger : ExprCloner -{ - private readonly IReadOnlyDictionary _multiVarMap; - - public VAEMerger(IReadOnlyDictionary multiVarMap) - { - _multiVarMap = multiVarMap; - } - - protected override Expr VisitCall(Call expr, Unit context) - { - if (_multiVarMap.TryGetValue(expr, out var newVar)) - { - return newVar; - } - - return base.VisitCall(expr, context); - } - - protected override Expr VisitLeafCall(Call expr, Unit context) - { - var target = Clone(expr.Target, context); - var arguments = CloneArray(expr.Arguments, context); - if (target is Binary) - { - arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray(); - } - - return expr.With(target: target, arguments: arguments); - } - - protected override Expr VisitLeafVar(Var expr, Unit context) - { - if (_multiVarMap.TryGetValue(expr, out var newVar)) - { - return newVar; - } - - throw new InvalidOperationException(); - } -} - /// /// stable-disffusion VAE Decoder res-block. 
/// @@ -120,7 +75,7 @@ private static Pattern CreatePattern() { { input, (Var)newInputs[0] }, }; - var merger = new VAEMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("VAEDecRes", $"{nameof(FuseVAEDecRes)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input); @@ -163,7 +118,7 @@ private static Pattern CreatePattern() { { input, (Var)newInputs[0] }, }; - var merger = new VAEMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("VAEDecHead", $"{nameof(FuseVAEDecHead)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input); @@ -222,7 +177,7 @@ private static Pattern CreatePattern() { { input, (Var)newInputs[0] }, }; - var merger = new VAEMerger(multiVarMap); + var merger = new FusionMerger(multiVarMap); var clonedRoot = merger.Clone(root, default); var callFusion = new Call(new Fusion("VAEDecMHA", $"{nameof(FuseVAEDecMHA)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType().ToArray()), input);