Skip to content

Commit

Permalink
update fusion merger
Browse files Browse the repository at this point in the history
  • Loading branch information
xhuohai committed Oct 23, 2023
1 parent 33b3ebd commit aafa155
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 155 deletions.
81 changes: 81 additions & 0 deletions modules/Nncase.Modules.CPU/Passes/Rules/FusionMerger.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http.Headers;
using System.Reactive;
using System.Text;
using System.Threading.Tasks;
using DryIoc.ImTools;
using Google.OrTools.LinearSolver;
using Nncase.IR;
using Nncase.IR.CPU;
using Nncase.IR.Math;
using Nncase.IR.NN;
using Nncase.IR.Tensors;
using Nncase.Passes.Rules.Neutral;
using Nncase.PatternMatch;
using Nncase.Targets;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;
using static Nncase.Utilities.ReplaceUtility;

namespace Nncase.Passes.Rules;

/// <summary>
/// Unet Merger for all.
/// </summary>
public sealed class FusionMerger : ExprCloner<Unit>
{
private readonly IReadOnlyDictionary<Expr, Var> _multiVarMap;

public FusionMerger(IReadOnlyDictionary<Expr, Var> multiVarMap)
{
_multiVarMap = multiVarMap;
}

protected override Expr VisitCall(Call expr, Unit context)
{
if (_multiVarMap.TryGetValue(expr, out var newVar))
{
return newVar;
}

return base.VisitCall(expr, context);
}

protected override Expr VisitLeafCall(Call expr, Unit context)
{
var target = Clone(expr.Target, context);
var arguments = CloneArray(expr.Arguments, context);
if (target is Binary)
{
arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray();
}

if (target is Conv2D conv)
{
var bias = (TensorConst)arguments[2];
var fusedClamp = ((TensorConst)arguments[7]).Value.ToArray<float>();
var newConv = IR.F.NN.Conv2D(arguments[0], arguments[1], Tensor.Zeros<float>(bias.CheckedShape), arguments[3], arguments[4], arguments[5], conv.PadMode, arguments[6], new[] { float.NegativeInfinity, float.PositiveInfinity });
var newBias = IR.F.Math.Add(newConv, Tensor.FromBytes(bias.CheckedDataType, bias.Value.BytesBuffer.ToArray(), new[] { bias.CheckedShape[0].FixedValue, 1, 1 }));
var newClamp = IR.F.Math.Clamp(newBias, fusedClamp[0], fusedClamp[1]);
return newClamp;
}

return expr.With(target: target, arguments: arguments);
}

protected override Expr VisitLeafVar(Var expr, Unit context)
{
if (_multiVarMap.TryGetValue(expr, out var newVar))
{
return newVar;
}

throw new InvalidOperationException();
}
}

Check failure on line 80 in modules/Nncase.Modules.CPU/Passes/Rules/FusionMerger.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux


53 changes: 4 additions & 49 deletions modules/Nncase.Modules.CPU/Passes/Rules/MHAFusion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,51 +23,6 @@

namespace Nncase.Passes.Rules;

/// <summary>
/// MHA Merger for all.
/// </summary>
public sealed class MHAMerger : ExprCloner<Unit>
{
private readonly IReadOnlyDictionary<Expr, Var> _multiVarMap;

public MHAMerger(IReadOnlyDictionary<Expr, Var> multiVarMap)
{
_multiVarMap = multiVarMap;
}

protected override Expr VisitCall(Call expr, Unit context)
{
if (_multiVarMap.TryGetValue(expr, out var newVar))
{
return newVar;
}

return base.VisitCall(expr, context);
}

protected override Expr VisitLeafCall(Call expr, Unit context)
{
var target = Clone(expr.Target, context);
var arguments = CloneArray(expr.Arguments, context);
if (target is Binary)
{
arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray();
}

return expr.With(target: target, arguments: arguments);
}

protected override Expr VisitLeafVar(Var expr, Unit context)
{
if (_multiVarMap.TryGetValue(expr, out var newVar))
{
return newVar;
}

throw new InvalidOperationException();
}
}

// pattern from BERT base
[RuleGenerator]
public sealed partial class FuseMHA1 : FusionMaker
Expand Down Expand Up @@ -248,7 +203,7 @@ private static Pattern CreatePattern()
{ position_ids, (Var)newInputs[1] },
{ attn_mask, (Var)newInputs[2] },
};
var merger = new MHAMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("MHALLaMA65B", $"{nameof(FuseMHA2)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType<Var>().ToArray()), hidden_in, position_ids, attn_mask);
Expand Down Expand Up @@ -308,7 +263,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
var merger = new MHAMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("SDTextEncoderMHA", $"{nameof(FuseMHA3)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType<Var>().ToArray()), input);
Expand Down Expand Up @@ -345,7 +300,7 @@ private static Pattern CreatePattern()
{
{ input, newInputs[0] },
};
var merger = new MHAMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("SDTextEncoderHeader", $"{nameof(FuseSDTextEncoderHeader)}_{Count++}", ModuleKind, clonedRoot, newInputs.ToArray()), input);
Expand Down Expand Up @@ -389,7 +344,7 @@ private static Pattern CreatePattern()
{ input, newInputs[1] },
};

var merger = new MHAMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("SDTextEncoderTail", $"{nameof(FuseSDTextEncoderTail)}_{Count++}", ModuleKind, clonedRoot, newInputs), input_ids, input);
Expand Down
61 changes: 3 additions & 58 deletions modules/Nncase.Modules.CPU/Passes/Rules/UnetFusion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,61 +27,6 @@

namespace Nncase.Passes.Rules;

/// <summary>
/// Unet Merger for all.
/// </summary>
public sealed class UnetMerger : ExprCloner<Unit>
{
private readonly IReadOnlyDictionary<Expr, Var> _multiVarMap;

public UnetMerger(IReadOnlyDictionary<Expr, Var> multiVarMap)
{
_multiVarMap = multiVarMap;
}

protected override Expr VisitCall(Call expr, Unit context)
{
if (_multiVarMap.TryGetValue(expr, out var newVar))
{
return newVar;
}

return base.VisitCall(expr, context);
}

protected override Expr VisitLeafCall(Call expr, Unit context)
{
var target = Clone(expr.Target, context);
var arguments = CloneArray(expr.Arguments, context);
if (target is Binary)
{
arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray();
}

if (target is Conv2D conv)
{
var bias = (TensorConst)arguments[2];
var fusedClamp = ((TensorConst)arguments[7]).Value.ToArray<float>();
var newConv = IR.F.NN.Conv2D(arguments[0], arguments[1], Tensor.Zeros<float>(bias.CheckedShape), arguments[3], arguments[4], arguments[5], conv.PadMode, arguments[6], new[] { float.NegativeInfinity, float.PositiveInfinity });
var newBias = IR.F.Math.Add(newConv, Tensor.FromBytes(bias.CheckedDataType, bias.Value.BytesBuffer.ToArray(), new[] { bias.CheckedShape[0].FixedValue, 1, 1 }));
var newClamp = IR.F.Math.Clamp(newBias, fusedClamp[0], fusedClamp[1]);
return newClamp;
}

return expr.With(target: target, arguments: arguments);
}

protected override Expr VisitLeafVar(Var expr, Unit context)
{
if (_multiVarMap.TryGetValue(expr, out var newVar))
{
return newVar;
}

throw new InvalidOperationException();
}
}

/// <summary>
/// stable-disffusion Unet spatial transformer.
/// </summary>
Expand Down Expand Up @@ -194,7 +139,7 @@ private static Pattern CreatePattern()
{ input, (Var)newInputs[0] },
{ encoderHiddenStates, (Var)newInputs[1] },
};
var merger = new UnetMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("UnetSpatialTransformer", $"{nameof(FuseUnetSpatialTransformer)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType<Var>().ToArray()), input, encoderHiddenStates);
Expand Down Expand Up @@ -282,7 +227,7 @@ private static Pattern CreatePattern()
multiVarMap.Add(oldInputs[i], (Var)newInputs[i]);
}

var merger = new UnetMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("UnetResBlock", $"{nameof(FuseUnetResBlock)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType<Var>().ToArray()), oldInputs.ToArray());
Expand Down Expand Up @@ -336,7 +281,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
var merger = new UnetMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("UnetTimeEmb", $"{nameof(FuseUnetTimeEmb)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType<Var>().ToArray()), input);
Expand Down
51 changes: 3 additions & 48 deletions modules/Nncase.Modules.CPU/Passes/Rules/VAEFusion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,51 +24,6 @@

namespace Nncase.Passes.Rules;

/// <summary>
/// VAE Merger for all.
/// </summary>
public sealed class VAEMerger : ExprCloner<Unit>
{
private readonly IReadOnlyDictionary<Expr, Var> _multiVarMap;

public VAEMerger(IReadOnlyDictionary<Expr, Var> multiVarMap)
{
_multiVarMap = multiVarMap;
}

protected override Expr VisitCall(Call expr, Unit context)
{
if (_multiVarMap.TryGetValue(expr, out var newVar))
{
return newVar;
}

return base.VisitCall(expr, context);
}

protected override Expr VisitLeafCall(Call expr, Unit context)
{
var target = Clone(expr.Target, context);
var arguments = CloneArray(expr.Arguments, context);
if (target is Binary)
{
arguments = arguments.Select(e => e switch { TensorConst { Value: Tensor { Shape.IsScalar: true } } tc => Const.FromTensor(Tensor.FromBytes(tc.ValueType.DType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })), _ => e }).ToArray();
}

return expr.With(target: target, arguments: arguments);
}

protected override Expr VisitLeafVar(Var expr, Unit context)
{
if (_multiVarMap.TryGetValue(expr, out var newVar))
{
return newVar;
}

throw new InvalidOperationException();
}
}

/// <summary>
/// stable-disffusion VAE Decoder res-block.
/// </summary>
Expand Down Expand Up @@ -120,7 +75,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
var merger = new VAEMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("VAEDecRes", $"{nameof(FuseVAEDecRes)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType<Var>().ToArray()), input);
Expand Down Expand Up @@ -163,7 +118,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
var merger = new VAEMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("VAEDecHead", $"{nameof(FuseVAEDecHead)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType<Var>().ToArray()), input);
Expand Down Expand Up @@ -222,7 +177,7 @@ private static Pattern CreatePattern()
{
{ input, (Var)newInputs[0] },
};
var merger = new VAEMerger(multiVarMap);
var merger = new FusionMerger(multiVarMap);
var clonedRoot = merger.Clone(root, default);

var callFusion = new Call(new Fusion("VAEDecMHA", $"{nameof(FuseVAEDecMHA)}_{Count++}", ModuleKind, clonedRoot, newInputs.OfType<Var>().ToArray()), input);
Expand Down

0 comments on commit aafa155

Please sign in to comment.