[Mlir-commits] [mlir] aa1c533 - [mlir][tosa] Expand tosa.apply_scale lowering for vectors

Rob Suderman llvmlistbot at llvm.org
Wed Jan 12 14:10:04 PST 2022


Author: Rob Suderman
Date: 2022-01-12T14:07:52-08:00
New Revision: aa1c533a4e4422b8fe3c65499f2d3a4e8c75949f

URL: https://github.com/llvm/llvm-project/commit/aa1c533a4e4422b8fe3c65499f2d3a4e8c75949f
DIFF: https://github.com/llvm/llvm-project/commit/aa1c533a4e4422b8fe3c65499f2d3a4e8c75949f.diff

LOG: [mlir][tosa] Expand tosa.apply_scale lowering for vectors

Apply scale may encounter scalar, tensor, or vector operations. Expand the
lowering so that it can lower arbitrary container types.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D117080

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
    mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 8496d227c339d..a150c2c2d8ae7 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -52,6 +52,23 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+Type matchContainerType(Type element, Type container) {
+  if (auto shapedTy = container.dyn_cast<ShapedType>())
+    return shapedTy.clone(element); // Same shape, new element type.
+
+  return element; // Scalar: there is no container to match.
+}
+
+Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
+  if (auto shapedTy = type.dyn_cast<ShapedType>()) {
+    Type eTy = shapedTy.getElementType();
+    APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
+    return DenseIntElementsAttr::get(shapedTy, valueInt); // Splat constant.
+  }
+
+  return rewriter.getIntegerAttr(type, value); // Scalar integer constant.
+}
+
 // This converts the TOSA ApplyScale operator to a set of StandardOps ops,
 // using 64-bit operations to perform the necessary multiply, bias, and shift.
 // Multiple types are used to use minimal bit width operations.
@@ -65,13 +82,19 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
     Value value32 = op.value();
     Value multiplier32 = op.multiplier();
     Value shift8 = op.shift();
+
     bool doubleRound = op.double_round();
     Type inType = op.value().getType();
+    Type resultTy = op.getType(); // Scalar, vector, or tensor.
+    // Integer types of each width, wrapped in resultTy's container if any.
+    Type i8Ty = matchContainerType(rewriter.getIntegerType(8), resultTy);
+    Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
+    Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
 
     Value one8 = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1));
+        loc, getConstantAttr(i8Ty, 1, rewriter));
     Value one64 = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+        loc, getConstantAttr(i64Ty, 1, rewriter));
 
     Value shiftSubOne8 = rewriter.create<arith::SubIOp>(loc, shift8, one8);
 
@@ -85,23 +108,20 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
     // Note that minimal bitwidth operators are used throughout the block.
 
     Value round64 = rewriter.create<arith::ShLIOp>(
-        loc, one64,
-        rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(),
-                                        shiftSubOne8));
+        loc, one64, rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftSubOne8));
 
     // Double rounding is performing a round operation before the shift
     if (doubleRound) {
       Value one32 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
-      Value shift32 =
-          rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), shift8);
+          loc, getConstantAttr(i32Ty, 1, rewriter));
+      Value shift32 = rewriter.create<arith::ExtSIOp>(loc, i32Ty, shift8);
       Value thirty32 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30));
+          loc, getConstantAttr(i32Ty, 30, rewriter));
 
       Value shiftThirty32 =
           rewriter.create<arith::ShLIOp>(loc, one32, thirty32);
-      Value shiftThirty64 = rewriter.create<arith::ExtSIOp>(
-          loc, rewriter.getI64Type(), shiftThirty32);
+      Value shiftThirty64 =
+          rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftThirty32);
 
       // Round value needs to with be added or subtracted depending on the sign
       // of the input value.
@@ -120,7 +140,7 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
       // We only perform double rounding if the shift value is greater than 32.
       Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32));
+          loc, getConstantAttr(i32Ty, 32, rewriter));
       Value shiftGreaterThanThirtyTwo = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
       round64 = rewriter.create<mlir::SelectOp>(loc, shiftGreaterThanThirtyTwo,
@@ -133,20 +153,17 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
     //
     // Note that multiply and shift need to be perform in i64 to preserve bits.
 
-    Value value64 =
-        rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), value32);
-    Value multiplier64 = rewriter.create<arith::ExtSIOp>(
-        loc, rewriter.getI64Type(), multiplier32);
-    Value shift64 =
-        rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), shift8);
+    Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32);
+    Value multiplier64 =
+        rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
+    Value shift64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, shift8);
 
     // Multiply as a pair of i64 values to guarantee the end value fits.
     Value result64 = rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
     result64 = rewriter.create<arith::AddIOp>(loc, result64, round64);
     result64 = rewriter.create<arith::ShRSIOp>(loc, result64, shift64);
 
-    Value result32 =
-        rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), result64);
+    Value result32 = rewriter.create<arith::TruncIOp>(loc, resultTy, result64);
 
     rewriter.replaceOp(op, result32);
     return success();

diff  --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
index 284ba471569b6..b346f43c37d9c 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -56,6 +56,43 @@ func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
 
 // -----
 
+// CHECK-LABEL: @apply_scale_test_vector
+func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
+  // CHECK-DAG: [[C1_8:%.+]] = arith.constant dense<1> : vector<4xi8>
+  // CHECK-DAG: [[C1_32:%.+]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK-DAG: [[C1_64:%.+]] = arith.constant dense<1> : vector<4xi64>
+  // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]]
+
+  // CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi32>
+  // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : vector<4xi8> to vector<4xi64>
+  // CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]]
+
+  // CHECK-DAG: [[C0_32:%.+]] = arith.constant dense<0> : vector<4xi32>
+  // CHECK-DAG: [[C30_32:%.+]] = arith.constant dense<30> : vector<4xi32>
+  // CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], [[C30_32]]
+  // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : vector<4xi32> to vector<4xi64>
+  // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : vector<4xi32>
+  // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : vector<4xi1>, vector<4xi64>
+  // CHECK-DAG: [[C32_32:%.+]] = arith.constant dense<32> : vector<4xi32>
+  // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]]
+  // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
+
+  // CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : vector<4xi32> to vector<4xi64>
+  // CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : vector<4xi32> to vector<4xi64>
+  // CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi64>
+  // CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]]
+  // CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]]
+  // CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]]
+  // CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]] : vector<4xi64> to vector<4xi32>
+
+  %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @apply_scale_test_i48
 func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8


        


More information about the Mlir-commits mailing list