Implementing fake_quantize_per_channel_affine #959

Draft · wants to merge 7 commits into main
3 changes: 2 additions & 1 deletion .gitignore
@@ -246,7 +246,7 @@ __pycache__/
# OpenCover UI analysis results
OpenCover/

-# Azure Stream Analytics local run output
+# Azure Stream Analytics local run output
ASALocalRun/

# MSBuild Binary and Structured Log
@@ -271,4 +271,5 @@ packages/
/src/Native/out/build/x64-Debug
*.code-workspace
/.idea
/.vscode
+/test/TorchSharpTest/exportsd.py
15 changes: 15 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.cpp
@@ -2200,3 +2200,18 @@ Tensor THSTensor_unflatten_names(Tensor tensor, const char** names, const int64_

    return nullptr;
}

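// Both wrappers below follow the LibTorchSharp convention: CATCH traps the native C++
// exception and stores its message for the managed side to fetch; on failure the result
// tensor stays undefined, ResultTensor returns nullptr, and the C# caller reacts by
// calling CheckForErrors().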
Tensor THSTensor_fake_quantize_per_channel_affine(Tensor tensor, Tensor scale, Tensor zero_point, int64_t axis, int64_t quant_min, int64_t quant_max)
{
    at::Tensor res;
    CATCH(res = at::fake_quantize_per_channel_affine(*tensor, *scale, *zero_point, axis, quant_min, quant_max);)
    return ResultTensor(res);
}

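// Variant that also returns the elementwise mask ATen computes for the backward pass
// (see aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp); surfaced to the
// managed side through the internal cachemask wrapper.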
Tensor THSTensor_fake_quantize_per_channel_affine_cachemask(Tensor tensor, Tensor scale, Tensor zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, Tensor* mask)
{
    std::tuple<at::Tensor, at::Tensor> res;
    CATCH(res = at::fake_quantize_per_channel_affine_cachemask(*tensor, *scale, *zero_point, axis, quant_min, quant_max);)
    *mask = ResultTensor(std::get<1>(res));
    return ResultTensor(std::get<0>(res));
}
6 changes: 6 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.h
@@ -1712,3 +1712,9 @@ EXPORT_API(Tensor) THSTensor_kaiser_window(const int64_t len, bool periodic, dou

EXPORT_API(Tensor) THSTensor_stft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool normalized, int64_t onesided, bool return_complex);
EXPORT_API(Tensor) THSTensor_istft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool center, bool normalized, int64_t onesided, int64_t length, bool return_complex);


// Pointwise Ops

EXPORT_API(Tensor) THSTensor_fake_quantize_per_channel_affine(const Tensor tensor, const Tensor scale, const Tensor zero_point, int64_t axis, int64_t quant_min, int64_t quant_max);
EXPORT_API(Tensor) THSTensor_fake_quantize_per_channel_affine_cachemask(const Tensor tensor, const Tensor scale, const Tensor zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, Tensor* mask);
7 changes: 7 additions & 0 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
@@ -2071,9 +2071,16 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_searchsorted_t(IntPtr sorted_sequence, IntPtr values, bool out_int32, bool right, IntPtr sorter);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_searchsorted_s(IntPtr sorted_sequence, IntPtr values, bool out_int32, bool right, IntPtr sorter);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_fake_quantize_per_channel_affine(IntPtr tensor, IntPtr scale, IntPtr zero_point, long axis, long quant_min, long quant_max);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_fake_quantize_per_channel_affine_cachemask(IntPtr tensor, IntPtr scale, IntPtr zero_point, long axis, long quant_min, long quant_max, out IntPtr mask);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_histogram_t(IntPtr input, IntPtr bins, IntPtr weight, bool density, out IntPtr r_bin_edges);
        [DllImport("LibTorchSharp")]
50 changes: 50 additions & 0 deletions src/TorchSharp/Tensor/Tensor.PointwiseOps.cs
@@ -0,0 +1,50 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#nullable enable
using System;
using static TorchSharp.PInvoke.LibTorchSharp;

namespace TorchSharp
{
    public static partial class torch
    {
        public partial class Tensor
        {
            // https://pytorch.org/docs/stable/generated/torch.fake_quantize_per_channel_affine
            /// <summary>
            /// Returns a new tensor with the data in this tensor fake quantized per channel using
            /// <paramref name="scale"/>, <paramref name="zero_point"/>, <paramref name="quant_min"/> and <paramref name="quant_max"/>,
            /// across the channel specified by <paramref name="axis"/>.
            /// </summary>
            /// <param name="scale">quantization scale, per channel (float32)</param>
            /// <param name="zero_point">quantization zero_point, per channel (torch.int32, torch.half, or torch.float32)</param>
            /// <param name="axis">channel axis</param>
            /// <param name="quant_min">lower bound of the quantized domain</param>
            /// <param name="quant_max">upper bound of the quantized domain</param>
            /// <returns>A newly fake_quantized per channel torch.float32 tensor</returns>
            public Tensor fake_quantize_per_channel_affine(Tensor scale, Tensor zero_point, long axis, long quant_min, long quant_max)
            {
                var res = THSTensor_fake_quantize_per_channel_affine(
                    Handle, scale.Handle, zero_point.Handle,
                    axis, quant_min, quant_max);

                if (res == IntPtr.Zero)
                    CheckForErrors();

                return new Tensor(res);
            }

            // see: aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp
            internal (Tensor res, Tensor mask) fake_quantize_per_channel_affine_cachemask(Tensor scale, Tensor zero_point, long axis, long quant_min, long quant_max)
            {
                var res = THSTensor_fake_quantize_per_channel_affine_cachemask(
                    Handle, scale.Handle, zero_point.Handle,
                    axis, quant_min, quant_max, out IntPtr mask);

                if (res == IntPtr.Zero || mask == IntPtr.Zero)
                    CheckForErrors();

                return (new Tensor(res), new Tensor(mask));
            }
        }
    }
}
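For reference, per-channel fake quantization computes (clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point) * scale, with scale and zero_point broadcast along axis. Below is a minimal sketch of that reference semantics in TorchSharp, handy for cross-checking the native result in tests; the FakeQuantReference class and its helper are hypothetical, not part of this PR:

using System.Linq;
using static TorchSharp.torch;

internal static class FakeQuantReference
{
    // Hypothetical reference implementation; mirrors the documented formula, not the ATen kernel.
    public static Tensor FakeQuantizePerChannel(
        Tensor x, Tensor scale, Tensor zero_point,
        long axis, long quant_min, long quant_max)
    {
        // Reshape scale/zero_point so they broadcast along `axis`,
        // e.g. axis = 0 on a (2,2,2) input gives shape (2,1,1).
        var shape = Enumerable.Repeat(1L, (int)x.dim()).ToArray();
        shape[(int)axis] = x.shape[axis];
        var s = scale.reshape(shape);
        var z = zero_point.reshape(shape);

        // (clamp(round(x / s + z), quant_min, quant_max) - z) * s
        var q = (x / s + z).round().clamp(quant_min, quant_max);
        return (q - z) * s;
    }
}

ATen rounds halfway cases to even, so exact .5 ties can differ from naive rounding; comparisons against the native op should go through allclose rather than exact equality.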
4 changes: 2 additions & 2 deletions src/TorchSharp/Tensor/torch.PointwiseOps.cs
@@ -741,9 +741,9 @@ public static Tensor addcmul_(Tensor input, Tensor tensor1, Tensor tensor2, Scal
        /// <param name="quant_min">lower bound of the quantized domain</param>
        /// <param name="quant_max">upper bound of the quantized domain</param>
        /// <returns>A newly fake_quantized per channel torch.float32 tensor</returns>
-       [Pure, Obsolete("not implemented", true)]
+       [Pure]
        public static Tensor fake_quantize_per_channel_affine(Tensor input, Tensor scale, Tensor zero_point, int axis, long quant_min, long quant_max)
-           => throw new NotImplementedException();
+           => input.fake_quantize_per_channel_affine(scale, zero_point, axis, quant_min, quant_max);

        // https://pytorch.org/docs/stable/generated/torch.fake_quantize_per_tensor_affine
        /// <summary>
16 changes: 7 additions & 9 deletions src/TorchVision/AdjustGamma.cs
@@ -1,4 +1,5 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+#nullable enable
using System;
using static TorchSharp.torch;

@@ -19,16 +20,16 @@ internal AdjustGamma(double gamma, double gain = 1.0)
            public Tensor call(Tensor img)
            {
                var dtype = img.dtype;
-               if (!torch.is_floating_point(img))
-                   img = transforms.ConvertImageDtype(torch.float32).call(img);
+               if (!is_floating_point(img))
+                   img = transforms.ConvertImageDtype(float32).call(img);

                img = (gain * img.pow(gamma)).clamp(0, 1);

-               return transforms.ConvertImageDtype(dtype).call(img); ;
+               return transforms.ConvertImageDtype(dtype).call(img);
            }

-           private double gamma;
-           private double gain;
+           private readonly double gamma;
+           private readonly double gain;
        }

        public static partial class transforms
@@ -44,10 +45,7 @@ public static partial class transforms
            /// </param>
            /// <param name="gain">The constant multiplier in the gamma correction equation.</param>
            /// <returns></returns>
-           static public ITransform AdjustGamma(double gamma, double gain = 1.0)
-           {
-               return new AdjustGamma(gamma);
-           }
+           public static ITransform AdjustGamma(double gamma, double gain = 1.0) => new AdjustGamma(gamma, gain);
        }
    }
}
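The rewrite of the factory method also fixes a real bug: the old body ignored its gain argument. A quick check, with illustrative values:

// gain now actually participates; before this change it was silently dropped.
var t = torch.ones(1, 3, 4, 4) * 0.5f;
var adjusted = torchvision.transforms.AdjustGamma(gamma: 2.0, gain: 0.5).call(t);
// every element is 0.5 * (0.5 ^ 2.0) = 0.125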
17 changes: 17 additions & 0 deletions test/TorchSharpTest/TestTorchTensor.cs
@@ -7724,6 +7724,7 @@ public void TestCartesianProd()
        }

        [Fact]
        [TestOf(nameof(torch.combinations))]
        public void TestCombinations()
        {
            var t = torch.arange(5);
@@ -7735,6 +7736,7 @@ public void TestCombinations()
        }

        [Fact]
        [TestOf(nameof(torch.cdist))]
        public void TestCDist()
        {
            var a = torch.randn(3, 2);
@@ -7746,6 +7748,7 @@ public void TestCDist()
        }

        [Fact]
        [TestOf(nameof(torch.rot90))]
        public void TestRot90()
        {
            var a = torch.arange(8).view(2, 2, 2);
@@ -7756,6 +7759,7 @@ public void TestRot90()
        }

        [Fact]
        [TestOf(nameof(torch.diag_embed))]
        public void TestDiagembed()
        {
            var a = torch.randn(2, 3);
@@ -7768,6 +7772,7 @@ public void TestDiagembed()
        }

        [Fact]
        [TestOf(nameof(torch.searchsorted))]
        public void TestSearchSorted()
        {
            var ss = torch.from_array(new long[] { 1, 3, 5, 7, 9, 2, 4, 6, 8, 10 }).reshape(2, -1);
@@ -7782,6 +7787,7 @@ public void TestSearchSorted()
        }

        [Fact]
        [TestOf(nameof(torch.histogram))]
        public void TestHistogram()
        {
            // https://pytorch.org/docs/stable/generated/torch.histogram.html
@@ -7804,6 +7810,17 @@ public void TestHistogram()
            Assert.True(bin_edges.allclose(torch.tensor(new double[] { 0, 1, 2, 3 }), 0.001));
        }

        [Fact]
        [TestOf(nameof(torch.fake_quantize_per_channel_affine))]
        public void TestFakeQuantizePerChannelAffine()
        {
            var x = torch.rand(2, 2, 2);
            var scales = (torch.randn(2) + 1d) * 0.05d;
            var zero_points = torch.zeros(2).to(torch.int32);
            var result = torch.fake_quantize_per_channel_affine(x, scales, zero_points, axis: 0, quant_min: 0, quant_max: 255);
            Assert.True(true);
Review comment (Contributor):
Huh? Is this just because the test is really just looking for an exception in the call?

Reply (Contributor Author):
Took this code from the official documentation. It's the minimal amount of testing. It should test the result values, but that would require getting rid of the random initializer.

Reply (Contributor):
Either use a Generator instance, or (I do this frequently) just create a tensor from the random values in the PyTorch doc, and then the correct results should be right there in the documentation already.

        }
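Picking up the review suggestion, a deterministic companion test could pin the input instead of calling torch.rand, so the expected values can be worked out by hand. A sketch under that assumption (the test name and the specific values are illustrative, not part of this PR):

        [Fact]
        [TestOf(nameof(torch.fake_quantize_per_channel_affine))]
        public void TestFakeQuantizePerChannelAffineValues()
        {
            // Fixed input, so the expected output is stable across runs.
            var x = torch.from_array(new float[] { 0.1f, 0.6f, 1.2f, -0.3f }).reshape(2, 2);
            var scales = torch.from_array(new float[] { 0.1f, 0.2f });
            var zero_points = torch.from_array(new int[] { 0, 0 });
            var result = torch.fake_quantize_per_channel_affine(x, scales, zero_points, axis: 0, quant_min: 0, quant_max: 255);
            // Channel 0 (scale 0.1): 0.1 and 0.6 already sit on the quantization grid.
            // Channel 1 (scale 0.2): 1.2 sits on the grid; -0.3 rounds below quant_min and clamps to 0.
            var expected = torch.from_array(new float[] { 0.1f, 0.6f, 1.2f, 0.0f }).reshape(2, 2);
            Assert.True(result.allclose(expected, 0.001));
        }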

        [Fact]
        public void TestHistogramOptimBinNums()
        {