[Mlir-commits] [mlir] [TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp (PR #82641)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 22 07:57:39 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Matthias Gehre (mgehre-amd)
<details>
<summary>Changes</summary>
tosa.clamp takes its `min_int`/`max_int` attributes as i64, so make sure the lowering to linalg handles the full i64 range instead of truncating the bounds to int32_t.
---
Full diff: https://github.com/llvm/llvm-project/pull/82641.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+14-14)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+15)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7eb32ebe3228fb..b706ac35c5ab15 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -382,25 +382,25 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return clampFloatHelper(loc, args[0], min, max, rewriter);
}
- if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
- auto intTy = cast<IntegerType>(elementTy);
- int32_t min = static_cast<int32_t>(
- cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
- int32_t max = static_cast<int32_t>(
- cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
+ if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
+ auto intTy = elementTy.cast<IntegerType>();
+ int64_t min =
+ cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+ int64_t max =
+ cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
if (intTy.isUnsignedInteger()) {
- min = std::max<int32_t>(min, 0);
- max = std::min<int32_t>(
+ min = std::max(min, (int64_t)0);
+ max = std::min(
max,
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
} else {
- min = std::max<int32_t>(
- min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
- max = std::min<int32_t>(
- max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
+ min =
+ std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
+ max =
+ std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
}
auto minVal = rewriter.create<arith::ConstantIntOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index febe74e8767465..1fa783f05f04ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -759,6 +759,21 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// -----
+// CHECK-LABEL: @test_i64
+func.func @test_i64(%arg0: tensor<1xi64>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: i64,
+ // CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808
+ // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
+ // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
+ // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
+ %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_clamp_f16
func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
// CHECK: linalg.generic
``````````
</details>
https://github.com/llvm/llvm-project/pull/82641
More information about the Mlir-commits mailing list