[Mlir-commits] [mlir] [TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp (PR #82641)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 22 07:57:39 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Gehre (mgehre-amd)

<details>
<summary>Changes</summary>

tosa.clamp takes its `min_int`/`max_int` attributes as i64, but the lowering to linalg truncated them to `int32_t`. Keep them as `int64_t` so that the lowering works for the whole i64 range.

---
Full diff: https://github.com/llvm/llvm-project/pull/82641.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+14-14) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+15) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7eb32ebe3228fb..b706ac35c5ab15 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -382,25 +382,25 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return clampFloatHelper(loc, args[0], min, max, rewriter);
   }
 
-  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
-    auto intTy = cast<IntegerType>(elementTy);
-    int32_t min = static_cast<int32_t>(
-        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
-    int32_t max = static_cast<int32_t>(
-        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
+  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
+    auto intTy = elementTy.cast<IntegerType>();
+    int64_t min =
+        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+    int64_t max =
+        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
 
     if (intTy.isUnsignedInteger()) {
-      min = std::max<int32_t>(min, 0);
-      max = std::min<int32_t>(
+      min = std::max(min, (int64_t)0);
+      max = std::min(
           max,
           APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
     } else {
-      min = std::max<int32_t>(
-          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
-                   .getSExtValue());
-      max = std::min<int32_t>(
-          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
-                   .getSExtValue());
+      min =
+          std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+                            .getSExtValue());
+      max =
+          std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+                            .getSExtValue());
     }
 
     auto minVal = rewriter.create<arith::ConstantIntOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index febe74e8767465..1fa783f05f04ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -759,6 +759,21 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_i64
+func.func @test_i64(%arg0: tensor<1xi64>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.+]]: i64,
+  // CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808
+  // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
+  // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
+  // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
+  %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_clamp_f16
 func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: linalg.generic

``````````

</details>


https://github.com/llvm/llvm-project/pull/82641


More information about the Mlir-commits mailing list