[Mlir-commits] [mlir] [MLIR] TosaToLinalgNamed: Lower unsigned tosa.max_pool2d (PR #123290)

Matthias Gehre llvmlistbot at llvm.org
Thu Jan 16 23:23:17 PST 2025


https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/123290

This PR allows lowering **unsigned** `tosa.max_pool2d` to linalg.
```
// CHECK-LABEL: @max_pool_ui8
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
  // CHECK: arith.constant 0
  // CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8>
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
  return %0 : tensor<1x4x32x62xui8>
}
```
It does this by
- converting the MaxPool2dConverter from an `OpRewritePattern` to an `OpConversionPattern`
- adjusting the padding value to the minimum unsigned value (zero) when the max_pool is unsigned
- lowering to `linalg.pooling_nhwc_max_unsigned` (which uses `arith.maxui`) when the max_pool is unsigned

>From 68aded8a408b62c1f7d4ddc26a59fa066caf07da Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 17 Jan 2025 08:04:36 +0100
Subject: [PATCH] [MLIR] TosaToLinalgNamed: Lower unsigned tosa.max_pool2d

This PR allows lowering unsigned `tosa.max_pool2d` to linalg.
```
// CHECK-LABEL: @max_pool_ui8
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
  // CHECK: arith.constant 0
  // CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8>
  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
  return %0 : tensor<1x4x32x62xui8>
}
```
It does this by
- converting the MaxPool2dConverter from an `OpRewritePattern` to an `OpConversionPattern`
- adjusting the padding value to the minimum unsigned value (zero) when the max_pool is unsigned
- lowering to `linalg.pooling_nhwc_max_unsigned` (which uses `arith.maxui`) when the max_pool is unsigned
---
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |  3 +-
 .../TosaToLinalg/TosaToLinalgNamed.cpp        | 52 +++++++++++++------
 .../TosaToLinalg/TosaToLinalgNamedPass.cpp    |  6 ++-
 .../TosaToLinalg/tosa-to-linalg-named.mlir    | 13 +++++
 4 files changed, 56 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 1822016fc88fe6..a1eb22eba69877 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -52,7 +52,8 @@ void populateTosaToLinalgConversionPatterns(const TypeConverter &converter,
 
 /// Populates conversion passes from TOSA dialect to Linalg named operations.
 void populateTosaToLinalgNamedConversionPatterns(
-    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
+    const TypeConverter &converter, RewritePatternSet *patterns,
+    const TosaToLinalgNamedOptions &options);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d537aef5791031..b7af37d293ac1c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -695,17 +695,18 @@ class FullyConnectedConverter
   }
 };
 
-class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
+class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
 public:
-  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+  using OpConversionPattern::OpConversionPattern;
 
   // Compute the dynamic output sizes of the maxpool operation.
   static SmallVector<Value>
-  computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
+                            ConversionPatternRewriter &rewriter) {
     TensorType resultTy = op.getType();
     Location loc = op.getLoc();
 
-    TypedValue<TensorType> input = op.getInput();
+    Value input = adaptor.getInput();
     ArrayRef<int64_t> kernel = op.getKernel();
     ArrayRef<int64_t> pad = op.getPad();
     ArrayRef<int64_t> stride = op.getStride();
@@ -744,16 +745,22 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     return dynamicDims;
   }
 
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const final {
+  LogicalResult
+  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    TypedValue<TensorType> input = op.getInput();
-    ShapedType inputTy = input.getType();
+    Value input = adaptor.getInput();
+    ShapedType inputTy = cast<ShapedType>(input.getType());
 
-    ShapedType resultTy = op.getType();
+    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
+    ShapedType resultTy =
+        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
+    if (!resultTy)
+      return rewriter.notifyMatchFailure(op, "failed to convert type");
     Type resultETy = inputTy.getElementType();
 
-    SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
+    SmallVector<Value> dynamicDims =
+        computeDynamicOutputSizes(op, adaptor, rewriter);
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
@@ -762,7 +769,10 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
           resultETy, APFloat::getLargest(
                          cast<FloatType>(resultETy).getFloatSemantics(), true));
 
-    if (isa<IntegerType>(resultETy))
+    else if (isUnsigned)
+      initialAttr = rewriter.getIntegerAttr(
+          resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
+    else if (isa<IntegerType>(resultETy))
       initialAttr = rewriter.getIntegerAttr(
           resultETy,
           APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
@@ -798,9 +808,15 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     Value fakeWindowDims =
         rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);
 
-    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-        filledEmptyTensor, strideAttr, dilationAttr);
+    if (isUnsigned) {
+      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
+          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+          filledEmptyTensor, strideAttr, dilationAttr);
+    } else {
+      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
+          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
+          filledEmptyTensor, strideAttr, dilationAttr);
+    }
     return success();
   }
 };
@@ -1070,7 +1086,8 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
-    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
+    const TypeConverter &converter, RewritePatternSet *patterns,
+    const TosaToLinalgNamedOptions &options) {
   if (options.preferConv2DKernelLayoutHWCF) {
     patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                 linalg::Conv2DNhwcHwcfQOp>>(
@@ -1085,10 +1102,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
-      MaxPool2dConverter,
       AvgPool2dConverter,
       FullyConnectedConverter,
       TransposeConverter
   >(patterns->getContext());
+
+  patterns->add<
+      MaxPool2dConverter
+    >(converter, patterns->getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 096969391e51b9..7d943b3779fb02 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -47,6 +47,9 @@ struct TosaToLinalgNamed
   }
 
   void runOnOperation() override {
+    TypeConverter converter;
+    tosa::populateTosaTypeConversion(converter);
+
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect,
@@ -67,7 +70,8 @@ struct TosaToLinalgNamed
     FunctionOpInterface func = getOperation();
     TosaToLinalgNamedOptions options;
     options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF;
-    tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options);
+    tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns,
+                                                      options);
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 453a8610e7169a..5eeaebb384e408 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -200,6 +200,19 @@ func.func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
   return
 }
 
+// CHECK-LABEL: @max_pool_ui8
+func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
+  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
+  // CHECK: arith.constant 0
+  // CHECK: linalg.pooling_nhwc_max_unsigned
+  // CHECK-SAME: ins({{.*}} : tensor<1x6x34x62xi8>, tensor<3x3xi8>)
+  // CHECK-SAME: outs({{.*}} : tensor<1x4x32x62xi8>)
+  // CHECK-SAME: -> tensor<1x4x32x62xi8>
+  // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
+  return %0 : tensor<1x4x32x62xui8>
+}
+
 // CHECK-LABEL: @max_pool_i16
 func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
   // CHECK: arith.constant -32768



More information about the Mlir-commits mailing list