[Mlir-commits] [mlir] 76c9712 - [mlir][tosa] Fix clamp to restrict only within valid bitwidth range
Rob Suderman
llvmlistbot at llvm.org
Wed Aug 18 12:20:18 PDT 2021
Author: Robert Suderman
Date: 2021-08-18T12:14:01-07:00
New Revision: 76c9712196906a1c5c1598e196b6abed139f090e
URL: https://github.com/llvm/llvm-project/commit/76c9712196906a1c5c1598e196b6abed139f090e
DIFF: https://github.com/llvm/llvm-project/commit/76c9712196906a1c5c1598e196b6abed139f090e.diff
LOG: [mlir][tosa] Fix clamp to restrict only within valid bitwidth range
It's possible for the clamp to have min/max values that lie outside the range
representable by its element type. To fix this, we validate the min/max values
and clamp them to the valid range for the element type's bitwidth.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D108256
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ca64e5de39f0e..f600c27d992f0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -428,12 +428,32 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
- auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
- rewriter);
- auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
- rewriter);
- return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
- rewriter);
+ auto intTy = elementTy.cast<IntegerType>();
+ int32_t min = static_cast<int32_t>(
+ op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
+ int32_t max = static_cast<int32_t>(
+ op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());
+
+ if (intTy.isUnsignedInteger()) {
+ min = std::max<int32_t>(min, 0);
+ max = std::min<int32_t>(
+ max,
+ APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
+ } else {
+ min = std::max<int32_t>(
+ min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
+ max = std::min<int32_t>(
+ max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
+ }
+
+ auto minVal =
+ rewriter.create<ConstantIntOp>(loc, min, intTy.getIntOrFloatBitWidth());
+ auto maxVal =
+ rewriter.create<ConstantIntOp>(loc, max, intTy.getIntOrFloatBitWidth());
+ return clampHelper<mlir::CmpIOp>(loc, args[0], minVal, maxVal,
+ CmpIPredicate::slt, rewriter);
}
// tosa::ReluNOp
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 88906b763d390..99c33d9b8eba7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -404,6 +404,31 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// -----
+// CHECK-LABEL: @test_i8
+func @test_i8(%arg0: tensor<1xi8>) -> () {
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[C127:.+]] = constant -127
+ // CHECK-DAG: %[[C126:.+]] = constant 126
+ // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C127]]
+ // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C127]]
+ // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C126]], %arg1
+ // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C126]], %[[SEL1]]
+ %0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+
+ // CHECK: linalg.generic
+ // CHECK-DAG: %[[C128:.+]] = constant -128
+ // CHECK-DAG: %[[C127:.+]] = constant 127
+ // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C128]]
+ // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C128]]
+ // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C127]], %arg1
+ // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C127]], %[[SEL1]]
+ %1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_bool
func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
// CHECK: linalg.generic
More information about the Mlir-commits
mailing list