[Mlir-commits] [mlir] [mlir][tosa] Change ClampOp's min/max attributes (PR #125197)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 31 02:45:26 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Hsiangkai Wang (Hsiangkai)
<details>
<summary>Changes</summary>
This changes the TOSA ClampOp attributes to min_val and max_val, each of which is either an integer attribute or a float attribute, and adds verifier checks requiring that the element types of these attributes match the element types of the input and output.
---
Patch is 56.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125197.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-4)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+8)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+4-4)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+94-35)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+25-11)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+6-35)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+28-37)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+6-6)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+22-22)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9e3e41d288e4ac..41acff74321fdc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -387,10 +387,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
let arguments = (ins
Tosa_Tensor:$input,
- I64Attr:$min_int,
- I64Attr:$max_int,
- Tosa_FloatAttr:$min_fp,
- Tosa_FloatAttr:$max_fp,
+ Tosa_IntOrFloatAttr:$min_val,
+ Tosa_IntOrFloatAttr:$max_val,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5693acf3a01db4..3795d51e5afce3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,6 +202,14 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
let returnType = [{ ::mlir::APFloat }];
}
+def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+ "arbitrary integer attribute"> {
+ let storageType = [{ ::mlir::IntegerAttr }];
+ let returnType = [{ ::llvm::APInt }];
+}
+
+def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
+
//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..49cb87a8786f95 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -385,8 +385,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::ClampOp
if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
bool losesInfo = false;
- APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
- APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
+ APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
+ APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
@@ -401,9 +401,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
auto intTy = cast<IntegerType>(elementTy);
int64_t min =
- cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+ cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
int64_t max =
- cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
+ cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
int64_t minRepresentable = std::numeric_limits<int64_t>::min();
int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 98871268e313b6..71369d81fbe908 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -287,10 +287,12 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
- auto minClamp = op.getMinFp();
- auto maxClamp = op.getMaxFp();
- bool isMin = minClamp.isInfinity() && minClamp.isNegative();
- bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
+ auto minClamp =
+ llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
+ auto maxClamp =
+ llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
+ bool isMin = minClamp.isNegInfinity();
+ bool isMax = maxClamp.isInfinity();
if (isMin && isMax) {
rewriter.replaceOp(op, input);
@@ -300,8 +302,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}
if (inputElementType.isUnsignedInteger()) {
- int64_t minClamp = op.getMinInt();
- int64_t maxClamp = op.getMaxInt();
+ int64_t minClamp =
+ llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
+ int64_t maxClamp =
+ llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
int64_t intMin =
APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
@@ -318,8 +322,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}
if (llvm::isa<IntegerType>(inputElementType)) {
- int64_t minClamp = op.getMinInt();
- int64_t maxClamp = op.getMaxInt();
+ int64_t minClamp =
+ llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
+ int64_t maxClamp =
+ llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
int64_t intMin =
APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
@@ -374,9 +380,10 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
+ Value input = op.getInput();
+
// Check the input to the CLAMP op is itself a CLAMP.
- auto clampOp =
- dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
+ auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
if (!clampOp)
return failure();
@@ -386,34 +393,86 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
return failure();
- // Check we have intersecting ranges.
- const auto opMinInt = op.getMinInt();
- const auto opMaxInt = op.getMaxInt();
- const auto clampOpMinInt = clampOp.getMinInt();
- const auto clampOpMaxInt = clampOp.getMaxInt();
- ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
- ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
- if (!opRangeIntRange.intersects(clampRangeIntRange))
- return failure();
+ auto maxValAttr = op.getMaxValAttr();
+ auto minValAttr = op.getMinValAttr();
+ auto clampOpMaxValAttr = clampOp.getMaxValAttr();
+ auto clampOpMinValAttr = clampOp.getMinValAttr();
- const auto opMinFloat = op.getMinFp();
- const auto opMaxFloat = op.getMaxFp();
- const auto clampOpMinFloat = clampOp.getMinFp();
- const auto clampOpMaxFloat = clampOp.getMaxFp();
- ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
- ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
- if (!opRangeFloatRange.intersects(clampRangeFloatRange))
- return failure();
+ auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
+ inputEType = quantType.getStorageType();
+ }
+
+ Attribute newMinValAttr, newMaxValAttr;
+ if (mlir::isa<FloatType>(inputEType)) {
+ auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
+ auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
+ auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
+ auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
+
+ // Check we have intersecting ranges.
+ const auto opMinFloat = floatMinValAttr.getValue();
+ const auto opMaxFloat = floatMaxValAttr.getValue();
+ const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
+ const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
+ ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
+ ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
+ clampOpMaxFloat);
+ if (!opRangeFloatRange.intersects(clampRangeFloatRange))
+ return failure();
+
+ // Run the transformation.
+ auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
+ auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
+ newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
+ newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
+ } else {
+ assert(mlir::isa<IntegerType>(inputEType));
+ auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
+ auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
+ auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
+ auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
+
+ if (inputEType.isUnsignedInteger()) {
+ // Check we have intersecting ranges.
+ const auto opMinInt = intMinValAttr.getUInt();
+ const auto opMaxInt = intMaxValAttr.getUInt();
+ const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
+ const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
+ ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
+ ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
+ clampOpMaxInt);
+ if (!opRangeIntRange.intersects(clampRangeIntRange))
+ return failure();
+
+ // Run the transformation.
+ auto newMinVal = std::max(opMinInt, clampOpMinInt);
+ auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
+ newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
+ newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
+ } else {
+ // Check we have intersecting ranges.
+ const auto opMinInt = intMinValAttr.getInt();
+ const auto opMaxInt = intMaxValAttr.getInt();
+ const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
+ const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
+ ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
+ ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
+ clampOpMaxInt);
+ if (!opRangeIntRange.intersects(clampRangeIntRange))
+ return failure();
+
+ // Run the transformation.
+ auto newMinVal = std::max(opMinInt, clampOpMinInt);
+ auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
+ newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
+ newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
+ }
+ }
- // Run the transformation.
- const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
- const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
- const auto minInt = std::max(opMinInt, clampOpMinInt);
- const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
- op, op.getType(), clampOp.getInput(),
- rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
- rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
+ op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
: opNanMode));
return success();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c0b419b6f473c8..23c1d45f7c1057 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -381,26 +381,40 @@ LogicalResult tosa::ClampOp::verify() {
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
inputETy = quantType.getStorageType();
}
- mlir::Type maxFpType = getMaxFpAttr().getType();
- mlir::Type minFpType = getMinFpAttr().getType();
mlir::Type outputETy =
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
outputETy = quantType.getStorageType();
}
- unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
-
if (inputETy != outputETy)
return emitOpError("input/output element types are incompatible.");
- // If input datatype is float, check that the two min/max_fp attributes
- // share the same type and that their type is either the same of the input's
- // datatype, or a float type whose bitwidth > input datatype bitwidth.
- if (!inputETy.isInteger(dataTypeBitWidth)) {
- if (((maxFpType != minFpType) ||
- (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
- inputETy.getIntOrFloatBitWidth())))
+ auto maxValAttr = getMaxValAttr();
+ auto minValAttr = getMinValAttr();
+
+ unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
+
+ if (inputETy.isInteger(dataTypeBitWidth)) {
+ // if input datatype is integer, check that the min_val/max_val attributes
+ // are integer attributes, and that their type is the same as the input's
+ // datatype
+ auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
+ auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
+ if (!intMaxValAttr || !intMinValAttr ||
+ (intMaxValAttr.getType() != intMinValAttr.getType()) ||
+ (intMaxValAttr.getType() != inputETy))
+ return emitOpError("min/max attributes types are incompatible with "
+ "input/output element types.");
+ } else {
+ // otherwise, input datatype is float, check that the min_val/max_val
+ // attributes share the same type and that their type is the same as the
+ // input's datatype
+ auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
+ auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
+ if (!floatMaxValAttr || !floatMinValAttr ||
+ (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
+ (floatMaxValAttr.getType() != inputETy))
return emitOpError("min/max attributes types are incompatible with "
"input/output element types.");
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317a..9ba08b427b1ae5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -529,7 +529,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: arith.minimumf
// CHECK: arith.maximumf
- %18 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+ %18 = tosa.clamp %0 {min_val = 1.0 : f32, max_val = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: arith.negf
@@ -729,35 +729,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
// CHECK: linalg.generic
// CHECK-DAG: arith.maxsi
// CHECK-DAG: arith.minsi
- %19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %19 = tosa.clamp %0 {min_val = 1 : i32, max_val = 5 : i32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
// CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
// CHECK-DAG: arith.maxui %[[LB]],
// CHECK-DAG: arith.minui %[[UB]],
- %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
-
- // CHECK: linalg.generic
- // CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
- // CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
- // CHECK-DAG: arith.maxui %[[LB]],
- // CHECK-DAG: arith.minui %[[UB]],
- %u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
-
- // CHECK: linalg.generic
- // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
- // CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
- // CHECK-DAG: arith.maxui %[[LB]],
- // CHECK-DAG: arith.minui %[[UB]],
- %u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
-
- // CHECK: linalg.generic
- // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
- // CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
- // CHECK-DAG: arith.maxui %[[LB]],
- // CHECK-DAG: arith.minui %[[UB]],
- %u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
+ %u0 = tosa.clamp %unsigned {min_val = 4 : ui32, max_val = 32 : ui32} : (tensor<1xui32>) -> tensor<1xui32>
// CHECK: linalg.generic
// CHECK: arith.trunci
@@ -807,15 +786,7 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// CHECK-DAG: %[[C126:.+]] = arith.constant 126
// CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
// CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
- %0 = tosa.clamp %arg0 {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
-
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[ARG1:.+]]: i8,
- // CHECK-DAG: %[[C128:.+]] = arith.constant -128
- // CHECK-DAG: %[[C127:.+]] = arith.constant 127
- // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C128]], %[[ARG1]]
- // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C127]], %[[LOWER]]
- %1 = tosa.clamp %arg0 {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+ %0 = tosa.clamp %arg0 {min_val = -127 : i8, max_val = 126 : i8} : (tensor<1xi8>) -> tensor<1xi8>
return
}
@@ -830,7 +801,7 @@ func.func @test_i64(%arg0: tensor<1xi64>) -> () {
// CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
// CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
// CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
- %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+ %0 = tosa.clamp %arg0 {min_val = -9223372036854775808 : i64, max_val = 9223372036854775807 : i64} : (tensor<1xi64>) -> tensor<1xi64>
return
}
@@ -845,7 +816,7 @@ func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
// CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
// CHECK-DAG: %[[MIN:.+]] = arith.minimumf %[[ARG1]], %[[C6]]
// CHECK-DAG: %[[MAX:.+]] = arith.maximumf %[[MIN]], %[[C0]]
- %0 = tosa.clamp %arg0 {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
+ %0 = tosa.clamp %arg0 {min_val = 0.0 : f16, max_val = 6.0 : f16} : (tensor<1xf16>) -> tensor<1xf16>
return
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63cc..c104ac10f64b92 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -52,25 +52,16 @@ func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
// CHECK-LABEL: @clamp_i32_not_noop
func.func @clamp_i32_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: tosa.clamp
- %0 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 4 : i64, min_fp = 1.0 : f32, max_fp = 4.0 : f32} : (tensor<4xi32>) -> tensor<4xi32>
+ %0 = tosa.clamp %arg0 {min_val = 1 : i32, max_val = 4 : i32} : (tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// -----
-// CHECK-LABEL: @clamp_f16_not_noop
-func.func @clamp_f16_not_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
- // CHECK: tosa.clamp
- %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xf16>) -> tensor<4xf16>
- return %0 : tensor<4xf16>
-}
-
-...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/125197
More information about the Mlir-commits
mailing list