[Mlir-commits] [mlir] d86ef43 - [mlir][tosa] Update tosa.rescale for i48 input type
Rob Suderman
llvmlistbot at llvm.org
Fri Jun 4 16:38:14 PDT 2021
Author: Rob Suderman
Date: 2021-06-04T16:36:48-07:00
New Revision: d86ef4364fb50728a2b87ec67bd2714d759f72a4
URL: https://github.com/llvm/llvm-project/commit/d86ef4364fb50728a2b87ec67bd2714d759f72a4
DIFF: https://github.com/llvm/llvm-project/commit/d86ef4364fb50728a2b87ec67bd2714d759f72a4.diff
LOG: [mlir][tosa] Update tosa.rescale for i48 input type
i48 integers require slightly different behavior: specifically, the zero-point
offset must be applied at a higher bit depth than for i32 inputs. The affected
lowerings were updated accordingly.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D102659
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index f1dfb78d9a727..af98bd55d0b47 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -37,14 +37,14 @@ def Tosa_ApplyScaleOp: Tosa_Op<"apply_scale", [NoSideEffect] # ElementwiseMappab
}];
let arguments = (ins
- Tosa_Int32Like:$value,
- Tosa_Int32Like:$multiplier,
+ Tosa_Int:$value,
+ Tosa_Int:$multiplier,
Tosa_Int8Like:$shift,
BoolAttr:$double_round
);
let results = (outs
- Tosa_Int32:$output
+ Tosa_Int:$output
);
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 808bb8d5a5d09..89a13750f99b0 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1343,15 +1343,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
getNParallelLoopsAttrs(rank),
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
+ Value value = blockArgs[0];
+
// For now we do all of our math in 64-bit. This is not optimal but
// should be correct for now, consider computing correct bit depth
// later.
+ int32_t inBitwidth =
+ value.getType().getIntOrFloatBitWidth() > 32 ? 48 : 32;
+
auto inputZp = createConstFromIntAttribute<int32_t>(
- op, "input_zp", nestedBuilder.getI32Type(), nestedBuilder);
+ op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
+ nestedBuilder);
auto outputZp = createConstFromIntAttribute<int32_t>(
op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
- Value value = blockArgs[0];
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 579e35ab1b2fb..699926a948618 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -64,11 +64,10 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
Value multiplier32 = op.multiplier();
Value shift8 = op.shift();
bool doubleRound = op.double_round();
+ Type inType = op.value().getType();
Value one8 = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1));
- Value one32 = rewriter.create<ConstantOp>(
- loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
Value one64 = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
@@ -83,9 +82,6 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
//
// Note that minimal bitwidth operators are used throughout the block.
- Value shift32 = rewriter.create<mlir::SignExtendIOp>(
- loc, rewriter.getI32Type(), shift8);
-
Value round64 = rewriter.create<mlir::ShiftLeftOp>(
loc, one64,
rewriter.create<SignExtendIOp>(loc, rewriter.getI64Type(),
@@ -93,8 +89,10 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
// Double rounding is performing a round operation before the shift
if (doubleRound) {
- Value zero32 = rewriter.create<ConstantOp>(
- loc, rewriter.getZeroAttr(rewriter.getI32Type()));
+ Value one32 = rewriter.create<ConstantOp>(
+ loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
+ Value shift32 = rewriter.create<mlir::SignExtendIOp>(
+ loc, rewriter.getI32Type(), shift8);
Value thirty32 = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30));
@@ -110,6 +108,8 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
Value roundSub64 =
rewriter.create<mlir::SubIOp>(loc, round64, shiftThirty64);
+ Value zero32 =
+ rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(inType));
Value valueGreaterThanZero = rewriter.create<mlir::CmpIOp>(
loc, CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
index 2c80c31cf297b..9ffc854b8cd74 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -19,36 +19,70 @@ func @slice(%arg0: tensor<6xf32>) ->() {
// -----
-func @apply_scale_test(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
- // CHECK: [[C1_8:%.+]] = constant 1 : i8
- // CHECK: [[C1_32:%.+]] = constant 1 : i32
- // CHECK: [[C1_64:%.+]] = constant 1 : i64
- // CHECK: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]]
-
- // CHECK: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32
- // CHECK: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64
- // CHECK: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]]
-
- // CHECK: [[C0_32:%.+]] = constant 0 : i32
- // CHECK: [[C30_32:%.+]] = constant 30 : i32
- // CHECK: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]]
- // CHECK: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64
- // CHECK: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
- // CHECK: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
- // CHECK: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i32
- // CHECK: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
- // CHECK: [[C32_32:%.+]] = constant 32 : i32
- // CHECK: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]]
- // CHECK: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
-
- // CHECK: [[VAL_64:%.+]] = sexti %arg0 : i32 to i64
- // CHECK: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64
- // CHECK: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64
- // CHECK: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]]
- // CHECK: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]]
- // CHECK: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]]
+// CHECK-LABEL: @apply_scale_test_i32
+func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
+ // CHECK-DAG: [[C1_8:%.+]] = constant 1 : i8
+ // CHECK-DAG: [[C1_32:%.+]] = constant 1 : i32
+ // CHECK-DAG: [[C1_64:%.+]] = constant 1 : i64
+ // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]]
+
+ // CHECK-DAG: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32
+ // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64
+ // CHECK-DAG: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]]
+
+ // CHECK-DAG: [[C0_32:%.+]] = constant 0 : i32
+ // CHECK-DAG: [[C30_32:%.+]] = constant 30 : i32
+ // CHECK-DAG: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]]
+ // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64
+ // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+ // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+ // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i32
+ // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
+ // CHECK-DAG: [[C32_32:%.+]] = constant 32 : i32
+ // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]]
+ // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
+
+ // CHECK-DAG: [[VAL_64:%.+]] = sexti %arg0 : i32 to i64
+ // CHECK-DAG: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64
+ // CHECK-DAG: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64
+ // CHECK-DAG: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]]
+ // CHECK-DAG: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]]
+ // CHECK-DAG: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]]
// CHECK: [[TRUNCATED:%.+]] = trunci [[DOWNSHIFTED]]
%0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32
return %0 : i32
}
+
+// -----
+
+// CHECK-LABEL: @apply_scale_test_i48
+func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
+ // CHECK-DAG: [[C1_8:%.+]] = constant 1 : i8
+ // CHECK-DAG: [[C1_32:%.+]] = constant 1 : i32
+ // CHECK-DAG: [[C1_64:%.+]] = constant 1 : i64
+ // CHECK-DAG: [[C30_32:%.+]] = constant 30 : i32
+ // CHECK-DAG: [[C0_32:%.+]] = constant 0 : i48
+ // CHECK-DAG: [[C32_32:%.+]] = constant 32 : i32
+ // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]]
+ // CHECK-DAG: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32
+ // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64
+ // CHECK-DAG: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]]
+ // CHECK-DAG: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]]
+ // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64
+ // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+ // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+ // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i48
+ // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64
+ // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]]
+ // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
+ // CHECK-DAG: [[VAL_64:%.+]] = sexti %arg0 : i48 to i64
+ // CHECK-DAG: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64
+ // CHECK-DAG: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64
+ // CHECK-DAG: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]]
+ // CHECK-DAG: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]]
+ // CHECK-DAG: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]]
+ // CHECK: [[TRUNCATED:%.+]] = trunci [[DOWNSHIFTED]]
+ %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i48, i32, i8) -> i32
+ return %0 : i32
+}
More information about the Mlir-commits mailing list