[Mlir-commits] [mlir] TosaToTensor: Support reshape on tensors of unsigned integer (PR #91749)

Matthias Gehre llvmlistbot at llvm.org
Fri May 10 07:34:22 PDT 2024


https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/91749

On top of #91734

>From b06875a5c12f8ffc00b0caef26033c62b17909b5 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Tue, 7 May 2024 16:27:53 +0200
Subject: [PATCH 1/2] TosaToTensor: Support reshape on tensors of unsigned
 integer

---
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |  3 ++
 .../Conversion/TosaToTensor/TosaToTensor.h    |  4 ++-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 34 +++++++++++++++++++
 .../Conversion/TosaToTensor/CMakeLists.txt    |  1 +
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 33 +++++++++++-------
 .../TosaToTensor/TosaToTensorPass.cpp         |  6 +++-
 .../TosaToTensor/tosa-to-tensor.mlir          | 14 ++++++++
 7 files changed, 80 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 5fd77c8a0211a..d3024c7389b9c 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -18,6 +18,7 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class TypeConverter;
 
 #define GEN_PASS_DECL_TOSATOLINALG
 #define GEN_PASS_DECL_TOSATOLINALGNAMED
@@ -52,6 +53,8 @@ void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
 void populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
 
+void populateTosaToLinalgTypeConversion(TypeConverter &converter);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
index 3953c83f3aa10..76a4b1b156336 100644
--- a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
+++ b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
@@ -16,6 +16,7 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class TypeConverter;
 
 #define GEN_PASS_DECL_TOSATOTENSOR
 #include "mlir/Conversion/Passes.h.inc"
@@ -24,7 +25,8 @@ namespace tosa {
 
 std::unique_ptr<Pass> createTosaToTensor();
 
-void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToTensorConversionPatterns(TypeConverter &converter,
+                                            RewritePatternSet *patterns);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e6ba6e6bc602d..dcb15012bda88 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2617,3 +2617,37 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       TileConverter>(patterns->getContext());
   // clang-format on
 }
+
+void mlir::tosa::populateTosaToLinalgTypeConversion(TypeConverter &converter) {
+  converter.addConversion([&](Type type) -> std::optional<Type> {
+    if (type.isUnsignedInteger()) {
+      return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
+                              IntegerType::SignednessSemantics::Signless);
+    }
+    return type;
+  });
+  converter.addConversion([&](TensorType type) -> std::optional<Type> {
+    auto converted = converter.convertType(type.getElementType());
+    if (!converted)
+      return {};
+    return type.clone(converted);
+  });
+  converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                                         ValueRange inputs,
+                                         Location loc) -> std::optional<Value> {
+    if (inputs.size() != 1)
+      return std::nullopt;
+
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+  converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                                         ValueRange inputs,
+                                         Location loc) -> std::optional<Value> {
+    if (inputs.size() != 1)
+      return std::nullopt;
+
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+}
diff --git a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
index 2870baa20757b..b1e7c9cba1a78 100644
--- a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRTosaToTensor
   MLIRIR
   MLIRPass
   MLIRTosaDialect
+  MLIRTosaToLinalg
   MLIRTosaTransforms
   MLIRSupport
   )
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index cd6da35582469..33f388faf6648 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -225,8 +225,17 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = reshape.getLoc();
-    auto resultType = reshape.getResult().getType();
-    auto input = reshape.getInput1();
+    auto resultType = cast_if_present<ShapedType>(
+        getTypeConverter()->convertType(reshape.getType()));
+    if (!resultType) {
+      return rewriter.notifyMatchFailure(reshape.getLoc(),
+                                         "could not convert result type");
+    }
+    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
+    if (!input) {
+      return rewriter.notifyMatchFailure(reshape.getLoc(),
+                                         "expected input type to be tensor");
+    }
     auto newShape = reshape.getNewShape();
 
     // Infer all intermediate types
@@ -289,12 +298,13 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
   }
 };
 
-class PadConverter : public OpRewritePattern<tosa::PadOp> {
+class PadConverter : public OpConversionPattern<tosa::PadOp> {
 public:
-  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+  using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(tosa::PadOp padOp,
-                                PatternRewriter &rewriter) const final {
+  LogicalResult
+  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
     auto loc = padOp.getLoc();
     auto input = padOp.getInput1();
     auto padding = padOp.getPadding();
@@ -429,11 +439,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
-    RewritePatternSet *patterns) {
-  patterns->add<
-    ConcatConverter,
-    PadConverter,
-    ReshapeConverter,
-    SliceConverter
-  >(patterns->getContext());
+    TypeConverter &converter, RewritePatternSet *patterns) {
+  patterns
+      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
+          converter, patterns->getContext());
 }
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 50dc55667fb94..9ae5edcce291e 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <mlir/Conversion/TosaToLinalg/TosaToLinalg.h>
 
 namespace mlir {
 #define GEN_PASS_DEF_TOSATOTENSOR
@@ -42,7 +43,10 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
     target.addLegalDialect<arith::ArithDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
 
-    mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
+    TypeConverter converter;
+    mlir::tosa::populateTosaToLinalgTypeConversion(converter);
+
+    mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index b8c3d56f21f10..2eddde9a55660 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -405,6 +405,20 @@ func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) ->
 
 // -----
 
+// CHECK-LABEL: @test_reshape_samerank_unsigned
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>)
+func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> {
+  // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8>
+  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
+  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8>
+  // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xui8>) -> tensor<2x3xui8>
+  // CHECK-NEXT: return %[[CAST2]]
+  return %0 : tensor<2x3xui8>
+}
+
+// -----
+
 // CHECK-LABEL: func @slice
 func.func @slice(%arg0: tensor<6xf32>) ->() {
   // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]

>From 41f64ca52336ea30ffb7a3dc84f098460e5a4720 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Thu, 4 Apr 2024 12:51:51 +0200
Subject: [PATCH 2/2] TosaToLinalg: Fix unsigned tosa.clamp

Plumb the TypeConverter into PointwiseConverter,
and emit unsigned comparisons when the input type is unsigned.
---
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |  3 +-
 .../mlir/Dialect/Tosa/Utils/ConversionUtils.h |  2 +-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 83 ++++++++++++-------
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  3 +-
 .../TosaToLinalg/TosaToLinalgPass.cpp         |  5 +-
 .../Dialect/Tosa/Utils/ConversionUtils.cpp    |  6 +-
 .../TosaToLinalg/tosa-to-linalg-invalid.mlir  |  9 ++
 .../TosaToLinalg/tosa-to-linalg.mlir          |  9 +-
 8 files changed, 83 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index d3024c7389b9c..76c28d5e6dd40 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -47,7 +47,8 @@ void addTosaToLinalgPasses(
 void registerTosaToLinalgPipelines();
 
 /// Populates conversion passes from TOSA dialect to Linalg dialect.
-void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToLinalgConversionPatterns(TypeConverter &converter,
+                                            RewritePatternSet *patterns);
 
 /// Populates conversion passes from TOSA dialect to Linalg named operations.
 void populateTosaToLinalgNamedConversionPatterns(
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ca59b221d03eb..ceab7d9c628a5 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -37,7 +37,7 @@ Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
 // Takes the parameters for a clamp and turns it into a series of ops for
 // integer inputs.
 Value clampIntHelper(Location loc, Value arg, Value min, Value max,
-                     OpBuilder &rewriter);
+                     OpBuilder &rewriter, bool isUnsigned);
 
 // Determines whether the integer value falls witin the range of integer type.
 bool validIntegerRange(IntegerType ty, int64_t value);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index dcb15012bda88..2de0f06f8a506 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -46,10 +46,9 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
 
-static Value
-createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
-                                            ArrayRef<Type> resultTypes,
-                                            PatternRewriter &rewriter) {
+static Value createLinalgBodyCalculationForElementwiseOp(
+    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
+    ConversionPatternRewriter &rewriter) {
   Location loc = op->getLoc();
   auto elementTy =
       cast<ShapedType>(op->getOperand(0).getType()).getElementType();
@@ -186,7 +185,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     Value max = rewriter.create<arith::ConstantIntOp>(
         loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
         intermediateType);
-    auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
+    auto clamp =
+        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
 
     // Truncate to the final value.
     return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
@@ -390,10 +390,15 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
         cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
 
     if (intTy.isUnsignedInteger()) {
+      if (intTy.getIntOrFloatBitWidth() > 63) {
+        (void)rewriter.notifyMatchFailure(
+            op, "support for 64-bit or larger integers is not implemented");
+        return {};
+      }
       min = std::max(min, (int64_t)0);
-      max = std::min(
-          max,
-          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
+      max = std::min(max,
+                     (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
+                         .getZExtValue());
     } else {
       min =
           std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
@@ -407,7 +412,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
         loc, min, intTy.getIntOrFloatBitWidth());
     auto maxVal = rewriter.create<arith::ConstantIntOp>(
         loc, max, intTy.getIntOrFloatBitWidth());
-    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
+    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
+                          intTy.isUnsignedInteger());
   }
 
   // tosa::SigmoidOp
@@ -615,10 +621,9 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
 }
 
 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
-                                           Location loc, Operation *operation) {
-  auto rank =
-      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
+                                           Location loc, ValueRange operands,
+                                           int64_t rank) {
+  return llvm::map_to_vector(operands, [&](Value operand) {
     return expandRank(rewriter, loc, operand, rank);
   });
 }
@@ -843,11 +848,16 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
 }
 
 static LogicalResult
-emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
+emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                            Operation *operation, ValueRange operands,
-                           ArrayRef<OpFoldResult> targetShape) {
+                           ArrayRef<OpFoldResult> targetShape,
+                           const TypeConverter &converter) {
   // Generate output tensor
-  auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
+  auto resultType = cast_or_null<RankedTensorType>(
+      converter.convertType(operation->getResultTypes().front()));
+  if (!resultType) {
+    return rewriter.notifyMatchFailure(operation, "failed to convert type");
+  }
   Value outputTensor = rewriter.create<tensor::EmptyOp>(
       loc, targetShape, resultType.getElementType());
 
@@ -894,8 +904,9 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
 }
 
 static LogicalResult
-elementwiseMatchAndRewriteHelper(Operation *operation,
-                                 PatternRewriter &rewriter) {
+elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
+                                 ConversionPatternRewriter &rewriter,
+                                 const TypeConverter &converter) {
 
   // Collect op properties
   assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
@@ -908,13 +919,15 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   // Lower operation
   IndexPool indexPool;
   auto loc = operation->getLoc();
-  auto expandedOperands = expandInputRanks(rewriter, loc, operation);
+  auto rank =
+      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
+  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
   auto [targetShape, masterOperands] =
       computeTargetShape(rewriter, loc, indexPool, expandedOperands);
   auto broadcastOperands = broadcastDynamicDimensions(
       rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
   return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
-                                    targetShape);
+                                    targetShape, converter);
 }
 
 // Returns the constant initial value for a given reduction operation. The
@@ -1100,13 +1113,16 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
 namespace {
 
 template <typename SrcOp>
-class PointwiseConverter : public OpRewritePattern<SrcOp> {
+class PointwiseConverter : public OpConversionPattern<SrcOp> {
 public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+  using typename OpConversionPattern<SrcOp>::OpAdaptor;
 
-  LogicalResult matchAndRewrite(SrcOp op,
-                                PatternRewriter &rewriter) const final {
-    return elementwiseMatchAndRewriteHelper(op, rewriter);
+  LogicalResult
+  matchAndRewrite(SrcOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    return elementwiseMatchAndRewriteHelper(
+        op, operands.getOperands(), rewriter, *this->getTypeConverter());
   }
 };
 
@@ -1279,7 +1295,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               loc, nestedBuilder.getI32IntegerAttr(intMax));
 
           value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
-                                 nestedBuilder);
+                                 nestedBuilder, /*isUnsigned=*/false);
 
           if (outIntType.getWidth() < 32) {
             value = nestedBuilder.create<arith::TruncIOp>(
@@ -1643,7 +1659,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
 
           auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
           val = b.create<arith::AddIOp>(val, offset);
-          val = clampIntHelper(loc, val, zeroI32, max, b);
+          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
           return b.create<arith::IndexCastOp>(b.getIndexType(), val);
         };
 
@@ -1664,8 +1680,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
                                   Value max, ImplicitLocOpBuilder &b) {
           val0 = in;
           val1 = b.create<arith::AddIOp>(val0, oneVal);
-          val0 = clampIntHelper(loc, val0, zeroI32, max, b);
-          val1 = clampIntHelper(loc, val1, zeroI32, max, b);
+          val0 =
+              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
+          val1 =
+              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
           val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
           val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
         };
@@ -2552,7 +2570,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgConversionPatterns(
-    RewritePatternSet *patterns) {
+    TypeConverter &converter, RewritePatternSet *patterns) {
 
   // We have multiple resize coverters to handle degenerate cases.
   patterns->add<GenericResizeConverter>(patterns->getContext(),
@@ -2599,7 +2617,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       PointwiseConverter<tosa::CeilOp>,
       PointwiseConverter<tosa::FloorOp>,
       PointwiseConverter<tosa::ClampOp>,
-      PointwiseConverter<tosa::SigmoidOp>,
+      PointwiseConverter<tosa::SigmoidOp>
+        >(converter, patterns->getContext());
+
+  patterns->add<
       IdentityNConverter<tosa::IdentityOp>,
       ReduceConverter<tosa::ReduceAllOp>,
       ReduceConverter<tosa::ReduceAnyOp>,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d8fb3abc0bef8..77c3d2e875791 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1015,7 +1015,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             auto max = rewriter.create<arith::ConstantIntOp>(
                 loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                 accETy);
-            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
+            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
+                                        /*isUnsigned=*/false);
 
             poolVal = clamp;
             // Convert type.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index ad7f6cf84e5ed..08b0d91f3ed9d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -63,8 +63,11 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
+    TypeConverter converter;
+    tosa::populateTosaToLinalgTypeConversion(converter);
+
     FunctionOpInterface func = getOperation();
-    mlir::tosa::populateTosaToLinalgConversionPatterns(&patterns);
+    mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();
   }
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 4fc97115064f3..f276924a8a9f6 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -38,7 +38,11 @@ Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
 }
 
 Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
-                                 OpBuilder &rewriter) {
+                                 OpBuilder &rewriter, bool isUnsigned) {
+  if (isUnsigned) {
+    auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
+    return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
+  }
   auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
   return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index ad65410e635e9..f372e26096ab5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -27,3 +27,12 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
   %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<*xf32>) -> tensor<10x10xf32>
   return %2 : tensor<10x10xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @clamp_on_large_int
+func.func @clamp_on_large_int(%arg0: tensor<1xui64>) -> tensor<1xui64> {
+  // expected-error at +1 {{failed to legalize operation 'tosa.clamp'}}
+  %0 = tosa.clamp %arg0 {min_int = -1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
+  return %0 : tensor<1xui64>
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 45b39f79a2a72..97142884617e2 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -606,7 +606,7 @@ func.func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
 // -----
 
 // CHECK-LABEL: @test_simple_i32
-func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
+func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>) -> () {
   // CHECK: linalg.generic
   // CHECK: arith.addi
   %0 = tosa.add %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
@@ -700,6 +700,13 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK-DAG: arith.minsi
   %19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[UB:.*]] = arith.constant 5 : i32
+  // CHECK-DAG: arith.maxui %[[LB]],
+  // CHECK-DAG: arith.minui %[[UB]],
+  %u19 = tosa.clamp %unsigned {min_int = -1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
   // CHECK: linalg.generic
   // CHECK: arith.trunci
   %20 = tosa.cast %0 : (tensor<1xi32>) -> tensor<1xi16>



More information about the Mlir-commits mailing list