[Mlir-commits] [mlir] [MLIR] TosaToLinalgNamed: Lower unsigned tosa.max_pool2d (PR #123290)
Matthias Gehre
llvmlistbot at llvm.org
Thu Jan 16 23:23:17 PST 2025
https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/123290
This PR allows lowering **unsigned** `tosa.max_pool2d` to linalg.
```
// CHECK-LABEL: @max_pool_ui8
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
// CHECK: arith.constant 0
// CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8>
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
return %0 : tensor<1x4x32x62xui8>
}
```
It does this by
- converting the MaxPool2dConverter from an OpRewritePattern to an OpConversionPattern
- adjusting the padding value to the minimum unsigned value when the max_pool is unsigned
- lowering to `linalg.pooling_nhwc_max_unsigned` (which uses `arith.maxui`) when the max_pool is unsigned; a rough sketch of the resulting IR follows below
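
For reference, here is a hand-written sketch of roughly what the example above lowers to. The SSA value names, attribute spellings, and exact op order are illustrative only (reconstructed from the CHECK lines in the test added by this patch), not verbatim pass output:
```
// Hand-written approximation of the lowered IR; names and attributes are illustrative.
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
  // The TOSA type converter maps ui8 tensors to signless i8 tensors; the casts
  // materialize that conversion at the function boundary.
  %in = builtin.unrealized_conversion_cast %arg0 : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
  // Unsigned max pooling seeds the accumulator with the minimum unsigned
  // value (0) instead of the minimum signed value (-128).
  %cst = arith.constant 0 : i8
  %empty = tensor.empty() : tensor<1x4x32x62xi8>
  %filled = linalg.fill ins(%cst : i8) outs(%empty : tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8>
  %window = tensor.empty() : tensor<3x3xi8>
  // pooling_nhwc_max_unsigned reduces with arith.maxui rather than arith.maxsi.
  %pool = linalg.pooling_nhwc_max_unsigned
            {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
            ins(%in, %window : tensor<1x6x34x62xi8>, tensor<3x3xi8>)
            outs(%filled : tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8>
  %out = builtin.unrealized_conversion_cast %pool : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
  return %out : tensor<1x4x32x62xui8>
}
```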
>From 68aded8a408b62c1f7d4ddc26a59fa066caf07da Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 17 Jan 2025 08:04:36 +0100
Subject: [PATCH] [MLIR] TosaToLinalgNamed: Lower unsigned tosa.max_pool2d
This PR allows lowering unsigned `tosa.max_pool2d` to linalg.
```
// CHECK-LABEL: @max_pool_ui8
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
// CHECK: arith.constant 0
// CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8>
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
return %0 : tensor<1x4x32x62xui8>
}
```
It does this by
- converting the MaxPool2dConverter from an OpRewritePattern to an OpConversionPattern
- adjusting the padding value to the minimum unsigned value when the max_pool is unsigned
- lowering to `linalg.pooling_nhwc_max_unsigned` (which uses `arith.maxui`) when the max_pool is unsigned
---
.../Conversion/TosaToLinalg/TosaToLinalg.h | 3 +-
.../TosaToLinalg/TosaToLinalgNamed.cpp | 52 +++++++++++++------
.../TosaToLinalg/TosaToLinalgNamedPass.cpp | 6 ++-
.../TosaToLinalg/tosa-to-linalg-named.mlir | 13 +++++
4 files changed, 56 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 1822016fc88fe6..a1eb22eba69877 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -52,7 +52,8 @@ void populateTosaToLinalgConversionPatterns(const TypeConverter &converter,
/// Populates conversion passes from TOSA dialect to Linalg named operations.
void populateTosaToLinalgNamedConversionPatterns(
- RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
+ const TypeConverter &converter, RewritePatternSet *patterns,
+ const TosaToLinalgNamedOptions &options);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d537aef5791031..b7af37d293ac1c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -695,17 +695,18 @@ class FullyConnectedConverter
}
};
-class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
+class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
- using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+ using OpConversionPattern::OpConversionPattern;
// Compute the dynamic output sizes of the maxpool operation.
static SmallVector<Value>
- computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+ computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) {
TensorType resultTy = op.getType();
Location loc = op.getLoc();
- TypedValue<TensorType> input = op.getInput();
+ Value input = adaptor.getInput();
ArrayRef<int64_t> kernel = op.getKernel();
ArrayRef<int64_t> pad = op.getPad();
ArrayRef<int64_t> stride = op.getStride();
@@ -744,16 +745,22 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
return dynamicDims;
}
- LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
- TypedValue<TensorType> input = op.getInput();
- ShapedType inputTy = input.getType();
+ Value input = adaptor.getInput();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
- ShapedType resultTy = op.getType();
+ bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
+ ShapedType resultTy =
+ cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
+ if (!resultTy)
+ return rewriter.notifyMatchFailure(op, "failed to convert type");
Type resultETy = inputTy.getElementType();
- SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
+ SmallVector<Value> dynamicDims =
+ computeDynamicOutputSizes(op, adaptor, rewriter);
// Determine what the initial value needs to be for the max pool op.
TypedAttr initialAttr;
@@ -762,7 +769,10 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
resultETy, APFloat::getLargest(
cast<FloatType>(resultETy).getFloatSemantics(), true));
- if (isa<IntegerType>(resultETy))
+ else if (isUnsigned)
+ initialAttr = rewriter.getIntegerAttr(
+ resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
+ else if (isa<IntegerType>(resultETy))
initialAttr = rewriter.getIntegerAttr(
resultETy,
APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
@@ -798,9 +808,15 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
Value fakeWindowDims =
rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);
- rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
- op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
- filledEmptyTensor, strideAttr, dilationAttr);
+ if (isUnsigned) {
+ rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
+ op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+ filledEmptyTensor, strideAttr, dilationAttr);
+ } else {
+ rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
+ op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+ filledEmptyTensor, strideAttr, dilationAttr);
+ }
return success();
}
};
@@ -1070,7 +1086,8 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
} // namespace
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
- RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
+ const TypeConverter &converter, RewritePatternSet *patterns,
+ const TosaToLinalgNamedOptions &options) {
if (options.preferConv2DKernelLayoutHWCF) {
patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
linalg::Conv2DNhwcHwcfQOp>>(
@@ -1085,10 +1102,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
DepthwiseConvConverter,
MatMulConverter,
- MaxPool2dConverter,
AvgPool2dConverter,
FullyConnectedConverter,
TransposeConverter
>(patterns->getContext());
+
+ patterns->add<
+ MaxPool2dConverter
+ >(converter, patterns->getContext());
// clang-format on
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 096969391e51b9..7d943b3779fb02 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -47,6 +47,9 @@ struct TosaToLinalgNamed
}
void runOnOperation() override {
+ TypeConverter converter;
+ tosa::populateTosaTypeConversion(converter);
+
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect,
@@ -67,7 +70,8 @@ struct TosaToLinalgNamed
FunctionOpInterface func = getOperation();
TosaToLinalgNamedOptions options;
options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
- tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
+ tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns,
+ options);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 453a8610e7169a..5eeaebb384e408 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -200,6 +200,19 @@ func.func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
return
}
+// CHECK-LABEL: @max_pool_ui8
+func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
+ // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
+ // CHECK: arith.constant 0
+ // CHECK: linalg.pooling_nhwc_max_unsigned
+ // CHECK-SAME: ins({{.*}} : tensor<1x6x34x62xi8>, tensor<3x3xi8>)
+ // CHECK-SAME: outs({{.*}} : tensor<1x4x32x62xi8>)
+ // CHECK-SAME: -> tensor<1x4x32x62xi8>
+ // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
+ %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
+ return %0 : tensor<1x4x32x62xui8>
+}
+
// CHECK-LABEL: @max_pool_i16
func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
// CHECK: arith.constant -32768