[Mlir-commits] [mlir] c1e9883 - [TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp (#82641)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 22 12:16:37 PST 2024


Author: Matthias Gehre
Date: 2024-02-22T21:16:33+01:00
New Revision: c1e9883a813db76c1b108ad715895928bb93f4c2

URL: https://github.com/llvm/llvm-project/commit/c1e9883a813db76c1b108ad715895928bb93f4c2
DIFF: https://github.com/llvm/llvm-project/commit/c1e9883a813db76c1b108ad715895928bb93f4c2.diff

LOG: [TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp (#82641)

tosa.clamp takes its `min_int`/`max_int` attributes as i64, but the lowering
to linalg truncated them to int32_t. Keep them as int64_t so that the
lowering works for the whole i64 range.

Co-authored-by: Tiago Trevisan Jost <tiago.trevisanjost at amd.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7eb32ebe3228fb..7c477f2e1412be 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -384,23 +384,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
 
   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
     auto intTy = cast<IntegerType>(elementTy);
-    int32_t min = static_cast<int32_t>(
-        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
-    int32_t max = static_cast<int32_t>(
-        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
+    int64_t min =
+        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+    int64_t max =
+        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
 
     if (intTy.isUnsignedInteger()) {
-      min = std::max<int32_t>(min, 0);
-      max = std::min<int32_t>(
+      min = std::max(min, (int64_t)0);
+      max = std::min(
           max,
           APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
     } else {
-      min = std::max<int32_t>(
-          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
-                   .getSExtValue());
-      max = std::min<int32_t>(
-          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
-                   .getSExtValue());
+      min =
+          std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+                            .getSExtValue());
+      max =
+          std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+                            .getSExtValue());
     }
 
     auto minVal = rewriter.create<arith::ConstantIntOp>(

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index febe74e8767465..1fa783f05f04ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -759,6 +759,21 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_i64
+func.func @test_i64(%arg0: tensor<1xi64>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.+]]: i64,
+  // CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808
+  // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
+  // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
+  // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
+  %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_clamp_f16
 func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: linalg.generic


        


More information about the Mlir-commits mailing list