[Mlir-commits] [mlir] d7c44a5 - [mlir][tosa] Fix tosa.mul to use tosa.apply_scale

Rob Suderman llvmlistbot at llvm.org
Mon Mar 22 11:05:38 PDT 2021


Author: Rob Suderman
Date: 2021-03-22T11:01:35-07:00
New Revision: d7c44a5c7870f4866f2e0e82c3297ffb7a800013

URL: https://github.com/llvm/llvm-project/commit/d7c44a5c7870f4866f2e0e82c3297ffb7a800013
DIFF: https://github.com/llvm/llvm-project/commit/d7c44a5c7870f4866f2e0e82c3297ffb7a800013.diff

LOG: [mlir][tosa] Fix tosa.mul to use tosa.apply_scale

Multiply-shift requires wider compute types or CPU-specific code to avoid
premature truncation; lowering through tosa.apply_scale fixes this issue.

Also, TOSA's mul op supports different input/output types. Added a path that
sign-extends input values before multiplying (to i32 when a shift is applied,
otherwise to the wider result type).

Differential Revision: https://reviews.llvm.org/D99011

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 698fb5a35cd3c..d6cc45c4ee600 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -115,12 +115,39 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   }
 
   if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
-    auto mul =
-        rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], args[1]);
-    auto constant =
-        rewriter.create<mlir::ConstantOp>(loc, elementTy, op->getAttr("shift"));
-    return rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, mul,
-                                                     constant);
+    Value a = args[0];
+    Value b = args[1];
+    auto shift =
+        op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
+    if (shift > 0) {
+      auto shiftConst =
+          rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+      if (!a.getType().isInteger(32))
+        a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);
+
+      if (!b.getType().isInteger(32))
+        b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);
+
+      auto result = rewriter.create<tosa::ApplyScaleOp>(
+          loc, rewriter.getI32Type(), a, b, shiftConst,
+          rewriter.getBoolAttr(false));
+
+      if (elementTy.isInteger(32))
+        return result;
+
+      return rewriter.create<TruncateIOp>(loc, elementTy, result);
+    }
+
+    int aWidth = a.getType().getIntOrFloatBitWidth();
+    int bWidth = b.getType().getIntOrFloatBitWidth();
+    int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+
+    if (aWidth < cWidth)
+      a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
+    if (bWidth < cWidth)
+      b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);
+
+    return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
   }
 
   // tosa::NegateOp

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index c41770b105ba0..33b82bc9e0fb3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -214,6 +214,19 @@ func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_simple_i16
+func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: sext
+  // CHECK: sext
+  // CHECK: muli
+  %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_simple_i32
 func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: linalg.generic
@@ -228,82 +241,87 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: muli
   %2 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: linalg.generic
+  // CHECK: constant 2
+  // CHECK: apply_scale
+  %3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
   // CHECK: linalg.generic
   // CHECK: muli
-  %3 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: and
-  %4 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %5 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: or
-  %5 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %6 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: xor
-  %6 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %7 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_left
-  %7 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %8 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_right_unsigned
-  %8 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %9 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %10 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %10 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: select
-  %11 = "tosa.select"(%9, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %12 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %13 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %13 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %14 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %14 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %15 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %16 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: trunci
-  %16 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+  %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 
   // CHECK: linalg.generic
   // CHECK: yield
-  %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+  %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: sexti
-  %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+  %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi
-  %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: sitofp
-  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
 
   return
 }


        


More information about the Mlir-commits mailing list