[Mlir-commits] [mlir] [mlir][tosa] Change ClampOp's min/max attributes (PR #125197)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 31 02:45:26 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Hsiangkai Wang (Hsiangkai)

<details>
<summary>Changes</summary>

This changes the TOSA ClampOp attributes to min_val and max_val, which are either integer attributes or float attributes, and adds verifier checks that the element types of these attributes match the element types of the input and output.

---

Patch is 56.57 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/125197.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-4) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+8) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+4-4) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+94-35) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+25-11) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+6-35) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+28-37) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+6-6) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+22-22) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9e3e41d288e4ac..41acff74321fdc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -387,10 +387,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I64Attr:$min_int,
-    I64Attr:$max_int,
-    Tosa_FloatAttr:$min_fp,
-    Tosa_FloatAttr:$max_fp,
+    Tosa_IntOrFloatAttr:$min_val,
+    Tosa_IntOrFloatAttr:$max_val,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5693acf3a01db4..3795d51e5afce3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,6 +202,14 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
   let returnType = [{ ::mlir::APFloat }];
 }
 
+def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+                          "arbitrary integer attribute"> {
+  let storageType = [{ ::mlir::IntegerAttr }];
+  let returnType = [{ ::llvm::APInt }];
+}
+
+def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
+
 //===----------------------------------------------------------------------===//
 // Iterable attributes.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..49cb87a8786f95 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -385,8 +385,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   // tosa::ClampOp
   if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
     bool losesInfo = false;
-    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
-    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
+    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
+    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
     minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                    APFloat::rmNearestTiesToEven, &losesInfo);
     maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
@@ -401,9 +401,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
     auto intTy = cast<IntegerType>(elementTy);
     int64_t min =
-        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
     int64_t max =
-        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
+        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
 
     int64_t minRepresentable = std::numeric_limits<int64_t>::min();
     int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 98871268e313b6..71369d81fbe908 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -287,10 +287,12 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
 
     if (isa<FloatType>(inputElementType)) {
       // Unlike integer types, floating point types can represent infinity.
-      auto minClamp = op.getMinFp();
-      auto maxClamp = op.getMaxFp();
-      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
-      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
+      auto minClamp =
+          llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
+      auto maxClamp =
+          llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
+      bool isMin = minClamp.isNegInfinity();
+      bool isMax = maxClamp.isInfinity();
 
       if (isMin && isMax) {
         rewriter.replaceOp(op, input);
@@ -300,8 +302,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
     }
 
     if (inputElementType.isUnsignedInteger()) {
-      int64_t minClamp = op.getMinInt();
-      int64_t maxClamp = op.getMaxInt();
+      int64_t minClamp =
+          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
+      int64_t maxClamp =
+          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
 
       int64_t intMin =
           APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
@@ -318,8 +322,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
     }
 
     if (llvm::isa<IntegerType>(inputElementType)) {
-      int64_t minClamp = op.getMinInt();
-      int64_t maxClamp = op.getMaxInt();
+      int64_t minClamp =
+          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
+      int64_t maxClamp =
+          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
 
       int64_t intMin =
           APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
@@ -374,9 +380,10 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
 
   LogicalResult matchAndRewrite(tosa::ClampOp op,
                                 PatternRewriter &rewriter) const override {
+    Value input = op.getInput();
+
     // Check the input to the CLAMP op is itself a CLAMP.
-    auto clampOp =
-        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
+    auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
     if (!clampOp)
       return failure();
 
@@ -386,34 +393,86 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
     if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
       return failure();
 
-    // Check we have intersecting ranges.
-    const auto opMinInt = op.getMinInt();
-    const auto opMaxInt = op.getMaxInt();
-    const auto clampOpMinInt = clampOp.getMinInt();
-    const auto clampOpMaxInt = clampOp.getMaxInt();
-    ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
-    ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
-    if (!opRangeIntRange.intersects(clampRangeIntRange))
-      return failure();
+    auto maxValAttr = op.getMaxValAttr();
+    auto minValAttr = op.getMinValAttr();
+    auto clampOpMaxValAttr = clampOp.getMaxValAttr();
+    auto clampOpMinValAttr = clampOp.getMinValAttr();
 
-    const auto opMinFloat = op.getMinFp();
-    const auto opMaxFloat = op.getMaxFp();
-    const auto clampOpMinFloat = clampOp.getMinFp();
-    const auto clampOpMaxFloat = clampOp.getMaxFp();
-    ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
-    ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
-    if (!opRangeFloatRange.intersects(clampRangeFloatRange))
-      return failure();
+    auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
+    if (auto quantType =
+            llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
+      inputEType = quantType.getStorageType();
+    }
+
+    Attribute newMinValAttr, newMaxValAttr;
+    if (mlir::isa<FloatType>(inputEType)) {
+      auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
+      auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
+      auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
+      auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
+
+      // Check we have intersecting ranges.
+      const auto opMinFloat = floatMinValAttr.getValue();
+      const auto opMaxFloat = floatMaxValAttr.getValue();
+      const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
+      const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
+      ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
+      ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
+                                               clampOpMaxFloat);
+      if (!opRangeFloatRange.intersects(clampRangeFloatRange))
+        return failure();
+
+      // Run the transformation.
+      auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
+      auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
+      newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
+      newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
+    } else {
+      assert(mlir::isa<IntegerType>(inputEType));
+      auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
+      auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
+      auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
+      auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
+
+      if (inputEType.isUnsignedInteger()) {
+        // Check we have intersecting ranges.
+        const auto opMinInt = intMinValAttr.getUInt();
+        const auto opMaxInt = intMaxValAttr.getUInt();
+        const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
+        const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
+        ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
+        ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
+                                                     clampOpMaxInt);
+        if (!opRangeIntRange.intersects(clampRangeIntRange))
+          return failure();
+
+        // Run the transformation.
+        auto newMinVal = std::max(opMinInt, clampOpMinInt);
+        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
+        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
+        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
+      } else {
+        // Check we have intersecting ranges.
+        const auto opMinInt = intMinValAttr.getInt();
+        const auto opMaxInt = intMaxValAttr.getInt();
+        const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
+        const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
+        ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
+        ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
+                                                    clampOpMaxInt);
+        if (!opRangeIntRange.intersects(clampRangeIntRange))
+          return failure();
+
+        // Run the transformation.
+        auto newMinVal = std::max(opMinInt, clampOpMinInt);
+        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
+        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
+        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
+      }
+    }
 
-    // Run the transformation.
-    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
-    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
-    const auto minInt = std::max(opMinInt, clampOpMinInt);
-    const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-        op, op.getType(), clampOp.getInput(),
-        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
-        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
+        op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
         rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                            : opNanMode));
     return success();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c0b419b6f473c8..23c1d45f7c1057 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -381,26 +381,40 @@ LogicalResult tosa::ClampOp::verify() {
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
     inputETy = quantType.getStorageType();
   }
-  mlir::Type maxFpType = getMaxFpAttr().getType();
-  mlir::Type minFpType = getMinFpAttr().getType();
   mlir::Type outputETy =
       llvm::cast<ShapedType>(getOutput().getType()).getElementType();
   if (auto quantType =
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
     outputETy = quantType.getStorageType();
   }
-  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
-
   if (inputETy != outputETy)
     return emitOpError("input/output element types are incompatible.");
 
-  // If input datatype is float, check that the two min/max_fp attributes
-  // share the same type and that their type is either the same of the input's
-  // datatype, or a float type whose bitwidth > input datatype bitwidth.
-  if (!inputETy.isInteger(dataTypeBitWidth)) {
-    if (((maxFpType != minFpType) ||
-         (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
-                                       inputETy.getIntOrFloatBitWidth())))
+  auto maxValAttr = getMaxValAttr();
+  auto minValAttr = getMinValAttr();
+
+  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
+
+  if (inputETy.isInteger(dataTypeBitWidth)) {
+    // if input datatype is integer, check that the min_val/max_val attributes
+    // are integer attributes, and that their type is the same as the input's
+    // datatype
+    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
+    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
+    if (!intMaxValAttr || !intMinValAttr ||
+        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
+        (intMaxValAttr.getType() != inputETy))
+      return emitOpError("min/max attributes types are incompatible with "
+                         "input/output element types.");
+  } else {
+    // otherwise, input datatype is float, check that the min_val/max_val
+    // attributes share the same type and that their type is the same as the
+    // input's datatype
+    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
+    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
+    if (!floatMaxValAttr || !floatMinValAttr ||
+        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
+        (floatMaxValAttr.getType() != inputETy))
       return emitOpError("min/max attributes types are incompatible with "
                          "input/output element types.");
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317a..9ba08b427b1ae5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -529,7 +529,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: linalg.generic
   // CHECK: arith.minimumf
   // CHECK: arith.maximumf
-  %18 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+  %18 = tosa.clamp %0 {min_val = 1.0 : f32, max_val = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
@@ -729,35 +729,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   // CHECK: linalg.generic
   // CHECK-DAG: arith.maxsi
   // CHECK-DAG: arith.minsi
-  %19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %19 = tosa.clamp %0 {min_val = 1 : i32, max_val = 5 : i32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
   // CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
   // CHECK-DAG: arith.maxui %[[LB]],
   // CHECK-DAG: arith.minui %[[UB]],
-  %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
-
-  // CHECK: linalg.generic
-  // CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
-  // CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
-  // CHECK-DAG: arith.maxui %[[LB]],
-  // CHECK-DAG: arith.minui %[[UB]],
-  %u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
-
-  // CHECK: linalg.generic
-  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
-  // CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
-  // CHECK-DAG: arith.maxui %[[LB]],
-  // CHECK-DAG: arith.minui %[[UB]],
-  %u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
-
-  // CHECK: linalg.generic
-  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
-  // CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
-  // CHECK-DAG: arith.maxui %[[LB]],
-  // CHECK-DAG: arith.minui %[[UB]],
-  %u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
+  %u0 = tosa.clamp %unsigned {min_val = 4 : ui32, max_val = 32 : ui32} : (tensor<1xui32>) -> tensor<1xui32>
 
   // CHECK: linalg.generic
   // CHECK: arith.trunci
@@ -807,15 +786,7 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
   // CHECK-DAG: %[[C126:.+]] = arith.constant 126
   // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
   // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
-  %0 = tosa.clamp %arg0 {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
-
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[ARG1:.+]]: i8,
-  // CHECK-DAG: %[[C128:.+]] = arith.constant -128
-  // CHECK-DAG: %[[C127:.+]] = arith.constant 127
-  // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C128]], %[[ARG1]]
-  // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C127]], %[[LOWER]]
-  %1 = tosa.clamp %arg0 {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+  %0 = tosa.clamp %arg0 {min_val = -127 : i8, max_val = 126 : i8} : (tensor<1xi8>) -> tensor<1xi8>
 
   return
 }
@@ -830,7 +801,7 @@ func.func @test_i64(%arg0: tensor<1xi64>) -> () {
   // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
   // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
   // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
-  %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+  %0 = tosa.clamp %arg0 {min_val = -9223372036854775808 : i64, max_val = 9223372036854775807 : i64} : (tensor<1xi64>) -> tensor<1xi64>
 
   return
 }
@@ -845,7 +816,7 @@ func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
   // CHECK-DAG: %[[MIN:.+]] = arith.minimumf %[[ARG1]], %[[C6]]
   // CHECK-DAG: %[[MAX:.+]] = arith.maximumf %[[MIN]], %[[C0]]
-  %0 = tosa.clamp %arg0 {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
+  %0 = tosa.clamp %arg0 {min_val = 0.0 : f16, max_val = 6.0 : f16} : (tensor<1xf16>) -> tensor<1xf16>
 
   return
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63cc..c104ac10f64b92 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -52,25 +52,16 @@ func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
 // CHECK-LABEL: @clamp_i32_not_noop
 func.func @clamp_i32_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK: tosa.clamp
-  %0 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 4 : i64, min_fp = 1.0 : f32, max_fp = 4.0 : f32} : (tensor<4xi32>) -> tensor<4xi32>
+  %0 = tosa.clamp %arg0 {min_val = 1 : i32, max_val = 4 : i32} : (tensor<4xi32>) -> tensor<4xi32>
   return %0 : tensor<4xi32>
 }
 
 // -----
 
-// CHECK-LABEL: @clamp_f16_not_noop
-func.func @clamp_f16_not_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
-  // CHECK: tosa.clamp
-  %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xf16>) -> tensor<4xf16>
-  return %0 : tensor<4xf16>
-}
-
-...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/125197


More information about the Mlir-commits mailing list