[Mlir-commits] [mlir] c1e9883 - [TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp (#82641)
llvmlistbot at llvm.org
Thu Feb 22 12:16:37 PST 2024
Author: Matthias Gehre
Date: 2024-02-22T21:16:33+01:00
New Revision: c1e9883a813db76c1b108ad715895928bb93f4c2
URL: https://github.com/llvm/llvm-project/commit/c1e9883a813db76c1b108ad715895928bb93f4c2
DIFF: https://github.com/llvm/llvm-project/commit/c1e9883a813db76c1b108ad715895928bb93f4c2.diff
LOG: [TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp (#82641)
tosa.clamp takes its `min_int`/`max_int` attributes as i64, so ensure that the
lowering to linalg works for their whole range instead of truncating them to int32_t.
Co-authored-by: Tiago Trevisan Jost <tiago.trevisanjost at amd.com>
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7eb32ebe3228fb..7c477f2e1412be 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -384,23 +384,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
auto intTy = cast<IntegerType>(elementTy);
- int32_t min = static_cast<int32_t>(
- cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
- int32_t max = static_cast<int32_t>(
- cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
+ int64_t min =
+ cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+ int64_t max =
+ cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
if (intTy.isUnsignedInteger()) {
- min = std::max<int32_t>(min, 0);
- max = std::min<int32_t>(
+ min = std::max(min, (int64_t)0);
+ max = std::min(
max,
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
} else {
- min = std::max<int32_t>(
- min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
- max = std::min<int32_t>(
- max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
+ min =
+ std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
+ max =
+ std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
}
auto minVal = rewriter.create<arith::ConstantIntOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index febe74e8767465..1fa783f05f04ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -759,6 +759,21 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// -----
+// CHECK-LABEL: @test_i64
+func.func @test_i64(%arg0: tensor<1xi64>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: i64,
+ // CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808
+ // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
+ // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
+ // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
+ %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_clamp_f16
func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
// CHECK: linalg.generic
More information about the Mlir-commits mailing list