[Mlir-commits] [mlir] [TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp (PR #82641)
Matthias Gehre
llvmlistbot at llvm.org
Thu Feb 22 07:57:05 PST 2024
https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/82641
tosa.clamp takes its `min_int`/`max_int` attributes as i64, so ensure that the lowering to linalg works for the whole int64 range.
From fb85db3b30aebfdad1351efbc6785cf20e9b5dcd Mon Sep 17 00:00:00 2001
From: Tiago Trevisan Jost <tiago.trevisanjost at amd.com>
Date: Mon, 12 Jun 2023 10:09:49 +0000
Subject: [PATCH] [TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 28 +++++++++----------
.../TosaToLinalg/tosa-to-linalg.mlir | 15 ++++++++++
2 files changed, 29 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7eb32ebe3228fb..b706ac35c5ab15 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -382,25 +382,25 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return clampFloatHelper(loc, args[0], min, max, rewriter);
}
- if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
- auto intTy = cast<IntegerType>(elementTy);
- int32_t min = static_cast<int32_t>(
- cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
- int32_t max = static_cast<int32_t>(
- cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
+ if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
+ auto intTy = elementTy.cast<IntegerType>();
+ int64_t min =
+ cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+ int64_t max =
+ cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
if (intTy.isUnsignedInteger()) {
- min = std::max<int32_t>(min, 0);
- max = std::min<int32_t>(
+ min = std::max(min, (int64_t)0);
+ max = std::min(
max,
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
} else {
- min = std::max<int32_t>(
- min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
- max = std::min<int32_t>(
- max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
+ min =
+ std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
+ max =
+ std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
}
auto minVal = rewriter.create<arith::ConstantIntOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index febe74e8767465..1fa783f05f04ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -759,6 +759,21 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// -----
+// CHECK-LABEL: @test_i64
+func.func @test_i64(%arg0: tensor<1xi64>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: i64,
+ // CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808
+ // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
+ // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
+ // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
+ %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_clamp_f16
func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
// CHECK: linalg.generic
More information about the Mlir-commits mailing list