[Mlir-commits] [mlir] 6e8e91a - [MLIR][TOSA] Fix converting tosa.clamp and tosa.relu to linalg

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Jul 11 10:18:53 PDT 2022


Author: jungpark-mlir
Date: 2022-07-11T17:18:47Z
New Revision: 6e8e91a7b63c51f487ddfbe2d6b2372ea1310faf

URL: https://github.com/llvm/llvm-project/commit/6e8e91a7b63c51f487ddfbe2d6b2372ea1310faf
DIFF: https://github.com/llvm/llvm-project/commit/6e8e91a7b63c51f487ddfbe2d6b2372ea1310faf.diff

LOG: [MLIR][TOSA] Fix converting tosa.clamp and tosa.relu to linalg

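TOSA-to-Linalg conversion crashes when the input tensor has a floating-point element type other than fp32.
This happens because tosa.clamp and tosa.reluN carry fp32 min/max attributes, which were materialized as arith.constant ops whose value attribute kept the fp32 type and therefore mismatched the tensor's element type.
This commit fixes the crash by converting the min/max values to the input tensor's float semantics before creating the constants, so each constant's type matches the input element type.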

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D128630

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 5711ffe003f47..d14d5d7d136ac 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -369,10 +369,17 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
 
   // tosa::ClampOp
   if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
-    auto min = rewriter.create<arith::ConstantOp>(loc, elementTy,
-                                                  op->getAttr("min_fp"));
-    auto max = rewriter.create<arith::ConstantOp>(loc, elementTy,
-                                                  op->getAttr("max_fp"));
+    bool losesInfo = false;
+    APFloat min_apf = op->getAttr("min_fp").cast<FloatAttr>().getValue();
+    APFloat max_apf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
+    min_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+                    APFloat::rmNearestTiesToEven, &losesInfo);
+    max_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+                    APFloat::rmNearestTiesToEven, &losesInfo);
+    auto min = rewriter.create<arith::ConstantOp>(
+        loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf));
+    auto max = rewriter.create<arith::ConstantOp>(
+        loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
     return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
                                       arith::CmpFPredicate::OLT, rewriter);
   }
@@ -410,8 +417,12 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
     auto zero =
         rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
-    auto n = rewriter.create<arith::ConstantOp>(loc, elementTy,
-                                                op->getAttr("max_fp"));
+    bool losesInfo = false;
+    APFloat max_apf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
+    max_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+                    APFloat::rmNearestTiesToEven, &losesInfo);
+    auto n = rewriter.create<arith::ConstantOp>(
+        loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
     return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
                                       arith::CmpFPredicate::OLT, rewriter);
   }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 357a77c5ab780..193c49c4ee87b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -473,6 +473,22 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_clamp_f16
+func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
+  // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
+  // CHECK-DAG: %[[CMP1:.+]] = arith.cmpf olt, %arg1, %[[C0]]
+  // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C0]]
+  // CHECK-DAG: %[[CMP2:.+]] = arith.cmpf olt, %[[C6]], %arg1
+  // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C6]], %[[SEL1]]
+  %0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_bool
 func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
   // CHECK: linalg.generic


        


More information about the Mlir-commits mailing list