[Mlir-commits] [mlir] d7c44a5 - [mlir][tosa] Fix tosa.mul to use tosa.apply_scale
Rob Suderman
llvmlistbot at llvm.org
Mon Mar 22 11:05:38 PDT 2021
Author: Rob Suderman
Date: 2021-03-22T11:01:35-07:00
New Revision: d7c44a5c7870f4866f2e0e82c3297ffb7a800013
URL: https://github.com/llvm/llvm-project/commit/d7c44a5c7870f4866f2e0e82c3297ffb7a800013
DIFF: https://github.com/llvm/llvm-project/commit/d7c44a5c7870f4866f2e0e82c3297ffb7a800013.diff
LOG: [mlir][tosa] Fix tosa.mul to use tosa.apply_scale
Multiply-shift requires wider compute types or CPU-specific code to avoid
premature truncation; apply_scale fixes this issue.
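
As a minimal scalar sketch (not the lowering itself, and ignoring apply_scale's
rounding modes), the difference between shifting a product that was computed in
the element type and shifting one computed in a wider type looks like this; the
helper names are illustrative only:

    #include <cstdint>
    #include <iostream>

    // Old lowering in spirit: multiply in 32 bits, then shift. The high bits of
    // the product are already gone before the shift runs. (The multiply goes
    // through uint32_t so the wrap-around is well defined for the demo.)
    int32_t narrowMulShift(int32_t a, int32_t b, int64_t shift) {
      int32_t wrapped =
          static_cast<int32_t>(static_cast<uint32_t>(a) * static_cast<uint32_t>(b));
      return wrapped >> shift;
    }

    // apply_scale-style: widen to 64 bits, multiply, shift, then narrow. Only the
    // final value is truncated, so nothing is lost prematurely.
    int32_t wideMulShift(int32_t a, int32_t b, int64_t shift) {
      int64_t wide = static_cast<int64_t>(a) * static_cast<int64_t>(b);
      return static_cast<int32_t>(wide >> shift);
    }

    int main() {
      int32_t a = 1 << 20, b = 1 << 20;             // product is 2^40, far outside i32
      std::cout << narrowMulShift(a, b, 16) << "\n";  // 0: the product wrapped first
      std::cout << wideMulShift(a, b, 16) << "\n";    // 16777216 (2^24)
    }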
Also, TOSA's mul op supports differing input/output types. Added a path that
sign-extends the input values to i32 before multiplying.
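
A similarly hedged scalar sketch of why that path widens first (again only an
illustration, not the generated IR): multiplying in the narrow input type loses
the high half of the product, while sign-extending to the result type keeps it.

    #include <cstdint>
    #include <iostream>

    // Multiply in the narrow input type, then widen: the product is already
    // truncated to 16 bits before the conversion.
    int32_t mulNarrowFirst(int16_t a, int16_t b) {
      int16_t narrow = static_cast<int16_t>(a * b);  // product wrapped to 16 bits
      return static_cast<int32_t>(narrow);
    }

    // Sign-extend the operands to the result type, then multiply: the full
    // 32-bit product survives.
    int32_t mulWidenFirst(int16_t a, int16_t b) {
      return static_cast<int32_t>(a) * static_cast<int32_t>(b);
    }

    int main() {
      int16_t a = 300, b = 300;                    // 90000 does not fit in i16
      std::cout << mulNarrowFirst(a, b) << "\n";   // 24464 after 16-bit wrap-around
      std::cout << mulWidenFirst(a, b) << "\n";    // 90000
    }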
Differential Revision: https://reviews.llvm.org/D99011
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 698fb5a35cd3c..d6cc45c4ee600 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -115,12 +115,39 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
- auto mul =
- rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], args[1]);
- auto constant =
- rewriter.create<mlir::ConstantOp>(loc, elementTy, op->getAttr("shift"));
- return rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, mul,
- constant);
+ Value a = args[0];
+ Value b = args[1];
+ auto shift =
+ op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
+ if (shift > 0) {
+ auto shiftConst =
+ rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+ if (!a.getType().isInteger(32))
+ a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);
+
+ if (!b.getType().isInteger(32))
+ b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);
+
+ auto result = rewriter.create<tosa::ApplyScaleOp>(
+ loc, rewriter.getI32Type(), a, b, shiftConst,
+ rewriter.getBoolAttr(false));
+
+ if (elementTy.isInteger(32))
+ return result;
+
+ return rewriter.create<TruncateIOp>(loc, elementTy, result);
+ }
+
+ int aWidth = a.getType().getIntOrFloatBitWidth();
+ int bWidth = b.getType().getIntOrFloatBitWidth();
+ int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+
+ if (aWidth < cWidth)
+ a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
+ if (bWidth < cWidth)
+ b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);
+
+ return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
}
// tosa::NegateOp
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index c41770b105ba0..33b82bc9e0fb3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -214,6 +214,19 @@ func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
// -----
+// CHECK-LABEL: @test_simple_i16
+func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: sext
+ // CHECK: sext
+ // CHECK: muli
+ %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_simple_i32
func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: linalg.generic
@@ -228,82 +241,87 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: muli
%2 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: linalg.generic
+ // CHECK: constant 2
+ // CHECK: apply_scale
+ %3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
// CHECK: linalg.generic
// CHECK: muli
- %3 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ %4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: and
- %4 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %5 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: or
- %5 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %6 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: xor
- %6 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %7 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: shift_left
- %7 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %8 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: shift_right_unsigned
- %8 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
- %9 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %10 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: cmpi
- %10 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: select
- %11 = "tosa.select"(%9, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %12 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %13 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %13 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %14 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %14 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %15 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %16 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: trunci
- %16 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+ %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
// CHECK: linalg.generic
// CHECK: yield
- %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+ %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: sexti
- %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+ %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpi
- %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+ %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: sitofp
- %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+ %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
return
}