[Mlir-commits] [mlir] 76c9712 - [mlir][tosa] Fix clamp to restrict only within valid bitwidth range

Rob Suderman llvmlistbot at llvm.org
Wed Aug 18 12:20:18 PDT 2021


Author: Robert Suderman
Date: 2021-08-18T12:14:01-07:00
New Revision: 76c9712196906a1c5c1598e196b6abed139f090e

URL: https://github.com/llvm/llvm-project/commit/76c9712196906a1c5c1598e196b6abed139f090e
DIFF: https://github.com/llvm/llvm-project/commit/76c9712196906a1c5c1598e196b6abed139f090e.diff

LOG: [mlir][tosa] Fix clamp to restrict only within valid bitwidth range

It's possible for the clamp to have min/max values that lie outside the range
representable by its element type. To fix this, we validate the min/max values
and clamp them to the valid range for the element type's bitwidth.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D108256

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ca64e5de39f0e..f600c27d992f0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -428,12 +428,32 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   }
 
   if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
-    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
-                                                    rewriter);
-    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
-                                                    rewriter);
-    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
-                                     rewriter);
+    auto intTy = elementTy.cast<IntegerType>();
+    int32_t min = static_cast<int32_t>(
+        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
+    int32_t max = static_cast<int32_t>(
+        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());
+
+    if (intTy.isUnsignedInteger()) {
+      min = std::max<int32_t>(min, 0);
+      max = std::min<int32_t>(
+          max,
+          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
+    } else {
+      min = std::max<int32_t>(
+          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+                   .getSExtValue());
+      max = std::min<int32_t>(
+          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+                   .getSExtValue());
+    }
+
+    auto minVal =
+        rewriter.create<ConstantIntOp>(loc, min, intTy.getIntOrFloatBitWidth());
+    auto maxVal =
+        rewriter.create<ConstantIntOp>(loc, max, intTy.getIntOrFloatBitWidth());
+    return clampHelper<mlir::CmpIOp>(loc, args[0], minVal, maxVal,
+                                     CmpIPredicate::slt, rewriter);
   }
 
   // tosa::ReluNOp

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 88906b763d390..99c33d9b8eba7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -404,6 +404,31 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_i8
+func @test_i8(%arg0: tensor<1xi8>) -> () {
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[C127:.+]] = constant -127
+  // CHECK-DAG: %[[C126:.+]] = constant 126
+  // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C127]]
+  // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C127]]
+  // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C126]], %arg1
+  // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C126]], %[[SEL1]]
+  %0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[C128:.+]] = constant -128
+  // CHECK-DAG: %[[C127:.+]] = constant 127
+  // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C128]]
+  // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C128]]
+  // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C127]], %arg1
+  // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C127]], %[[SEL1]]
+  %1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_bool
 func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
   // CHECK: linalg.generic


        


More information about the Mlir-commits mailing list