[Mlir-commits] [mlir] d86ef43 - [mlir][tosa] Update tosa.rescale for i48 input type

Rob Suderman llvmlistbot at llvm.org
Fri Jun 4 16:38:14 PDT 2021


Author: Rob Suderman
Date: 2021-06-04T16:36:48-07:00
New Revision: d86ef4364fb50728a2b87ec67bd2714d759f72a4

URL: https://github.com/llvm/llvm-project/commit/d86ef4364fb50728a2b87ec67bd2714d759f72a4
DIFF: https://github.com/llvm/llvm-project/commit/d86ef4364fb50728a2b87ec67bd2714d759f72a4.diff

LOG: [mlir][tosa] Update tosa.rescale for i48 input type

i48 integers require slightly different behavior, specifically supporting
zero-point offsetting at a higher bit depth. Updated the TOSA lowerings
(TosaToLinalg and TosaToStandard) accordingly.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D102659

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
    mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index f1dfb78d9a727..af98bd55d0b47 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -37,14 +37,14 @@ def Tosa_ApplyScaleOp: Tosa_Op<"apply_scale", [NoSideEffect] # ElementwiseMappab
   }];
 
   let arguments = (ins
-    Tosa_Int32Like:$value,
-    Tosa_Int32Like:$multiplier,
+    Tosa_Int:$value,
+    Tosa_Int:$multiplier,
     Tosa_Int8Like:$shift,
     BoolAttr:$double_round
   );
 
   let results = (outs
-    Tosa_Int32:$output
+    Tosa_Int:$output
   );
 }
 

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 808bb8d5a5d09..89a13750f99b0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1343,15 +1343,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
         getNParallelLoopsAttrs(rank),
         [&](OpBuilder &nestedBuilder, Location nestedLoc,
             ValueRange blockArgs) {
+          Value value = blockArgs[0];
+
           // For now we do all of our math in 64-bit. This is not optimal but
           // should be correct for now, consider computing correct bit depth
           // later.
+          int32_t inBitwidth =
+              value.getType().getIntOrFloatBitWidth() > 32 ? 48 : 32;
+
           auto inputZp = createConstFromIntAttribute<int32_t>(
-              op, "input_zp", nestedBuilder.getI32Type(), nestedBuilder);
+              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
+              nestedBuilder);
           auto outputZp = createConstFromIntAttribute<int32_t>(
               op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
 
-          Value value = blockArgs[0];
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 579e35ab1b2fb..699926a948618 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -64,11 +64,10 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
     Value multiplier32 = op.multiplier();
     Value shift8 = op.shift();
     bool doubleRound = op.double_round();
+    Type inType = op.value().getType();
 
     Value one8 = rewriter.create<ConstantOp>(
         loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1));
-    Value one32 = rewriter.create<ConstantOp>(
-        loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
     Value one64 = rewriter.create<ConstantOp>(
         loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
 
@@ -83,9 +82,6 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
     //
     // Note that minimal bitwidth operators are used throughout the block.
 
-    Value shift32 = rewriter.create<mlir::SignExtendIOp>(
-        loc, rewriter.getI32Type(), shift8);
-
     Value round64 = rewriter.create<mlir::ShiftLeftOp>(
         loc, one64,
         rewriter.create<SignExtendIOp>(loc, rewriter.getI64Type(),
@@ -93,8 +89,10 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
     // Double rounding is performing a round operation before the shift
     if (doubleRound) {
-      Value zero32 = rewriter.create<ConstantOp>(
-          loc, rewriter.getZeroAttr(rewriter.getI32Type()));
+      Value one32 = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
+      Value shift32 = rewriter.create<mlir::SignExtendIOp>(
+          loc, rewriter.getI32Type(), shift8);
       Value thirty32 = rewriter.create<ConstantOp>(
           loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30));
 
@@ -110,6 +108,8 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
       Value roundSub64 =
           rewriter.create<mlir::SubIOp>(loc, round64, shiftThirty64);
 
+      Value zero32 =
+          rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(inType));
       Value valueGreaterThanZero = rewriter.create<mlir::CmpIOp>(
           loc, CmpIPredicate::sge, value32, zero32);
 

diff  --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
index 2c80c31cf297b..9ffc854b8cd74 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -19,36 +19,70 @@ func @slice(%arg0: tensor<6xf32>) ->() {
 
 // -----
 
-func @apply_scale_test(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
-  // CHECK: [[C1_8:%.+]] = constant 1 : i8
-  // CHECK: [[C1_32:%.+]] = constant 1 : i32
-  // CHECK: [[C1_64:%.+]] = constant 1 : i64
-  // CHECK: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]]
-
-  // CHECK: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32
-  // CHECK: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64
-  // CHECK: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]]
-
-  // CHECK: [[C0_32:%.+]] = constant 0 : i32
-  // CHECK: [[C30_32:%.+]] = constant 30 : i32
-  // CHECK: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]]
-  // CHECK: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64
-  // CHECK: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
-  // CHECK: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
-  // CHECK: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i32
-  // CHECK: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
-  // CHECK: [[C32_32:%.+]] = constant 32 : i32
-  // CHECK: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]]
-  // CHECK: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
-
-  // CHECK: [[VAL_64:%.+]] = sexti %arg0 : i32 to i64
-  // CHECK: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64
-  // CHECK: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64
-  // CHECK: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]]
-  // CHECK: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]]
-  // CHECK: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]]
+// CHECK-LABEL: @apply_scale_test_i32
+func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
+  // CHECK-DAG: [[C1_8:%.+]] = constant 1 : i8
+  // CHECK-DAG: [[C1_32:%.+]] = constant 1 : i32
+  // CHECK-DAG: [[C1_64:%.+]] = constant 1 : i64
+  // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]]
+
+  // CHECK-DAG: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32
+  // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64
+  // CHECK-DAG: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]]
+
+  // CHECK-DAG: [[C0_32:%.+]] = constant 0 : i32
+  // CHECK-DAG: [[C30_32:%.+]] = constant 30 : i32
+  // CHECK-DAG: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]]
+  // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64
+  // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i32
+  // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
+  // CHECK-DAG: [[C32_32:%.+]] = constant 32 : i32
+  // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]]
+  // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
+
+  // CHECK-DAG: [[VAL_64:%.+]] = sexti %arg0 : i32 to i64
+  // CHECK-DAG: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64
+  // CHECK-DAG: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64
+  // CHECK-DAG: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]]
+  // CHECK-DAG: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]]
+  // CHECK-DAG: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]]
   // CHECK: [[TRUNCATED:%.+]] = trunci [[DOWNSHIFTED]]
 
   %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32
   return %0 : i32
 }
+
+// -----
+
+// CHECK-LABEL: @apply_scale_test_i48
+func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
+  // CHECK-DAG: [[C1_8:%.+]] = constant 1 : i8
+  // CHECK-DAG: [[C1_32:%.+]] = constant 1 : i32
+  // CHECK-DAG: [[C1_64:%.+]] = constant 1 : i64
+  // CHECK-DAG: [[C30_32:%.+]] = constant 30 : i32
+  // CHECK-DAG: [[C0_32:%.+]] = constant 0 : i48
+  // CHECK-DAG: [[C32_32:%.+]] = constant 32 : i32
+  // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]]
+  // CHECK-DAG: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32
+  // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64
+  // CHECK-DAG: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]]
+  // CHECK-DAG: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]]
+  // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64
+  // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i48
+  // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
+  // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]]
+  // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
+  // CHECK-DAG: [[VAL_64:%.+]] = sexti %arg0 : i48 to i64
+  // CHECK-DAG: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64
+  // CHECK-DAG: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64
+  // CHECK-DAG: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]]
+  // CHECK-DAG: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]]
+  // CHECK-DAG: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]]
+  // CHECK: [[TRUNCATED:%.+]] = trunci [[DOWNSHIFTED]]
+  %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i48, i32, i8) -> i32
+  return %0 : i32
+}


        


More information about the Mlir-commits mailing list