[Mlir-commits] [mlir] aa1c533 - [mlir][tosa] Expand tosa.apply_scale lowering for vectors
Rob Suderman
llvmlistbot at llvm.org
Wed Jan 12 14:10:04 PST 2022
Author: Rob Suderman
Date: 2022-01-12T14:07:52-08:00
New Revision: aa1c533a4e4422b8fe3c65499f2d3a4e8c75949f
URL: https://github.com/llvm/llvm-project/commit/aa1c533a4e4422b8fe3c65499f2d3a4e8c75949f
DIFF: https://github.com/llvm/llvm-project/commit/aa1c533a4e4422b8fe3c65499f2d3a4e8c75949f.diff
LOG: [mlir][tosa] Expand tosa.apply_scale lowering for vectors
tosa.apply_scale may operate on scalar, tensor, or vector operands. Expand the
lowering so that it can handle arbitrary container types.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D117080
Added:
Modified:
mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 8496d227c339d..a150c2c2d8ae7 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -52,6 +52,23 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
}
};
+// Returns `element` reshaped like `container` when it is a ShapedType, else `element`.
+Type matchContainerType(Type element, Type container) {
+ auto shaped = container.dyn_cast<ShapedType>();
+ // Shaped containers keep their shape with the new element type; scalars pass through.
+ return shaped ? Type(shaped.clone(element)) : element;
+}
+
+// Constant `value` of `type`: a splat DenseIntElementsAttr if shaped, else an IntegerAttr.
+Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
+ if (auto shapedTy = type.dyn_cast<ShapedType>()) {
+ // Signed so a negative `value` is encoded correctly for any element width.
+ APInt valueInt(shapedTy.getElementType().getIntOrFloatBitWidth(), value, /*isSigned=*/true);
+ return DenseIntElementsAttr::get(shapedTy, valueInt);
+ }
+ return rewriter.getIntegerAttr(type, value);
+}
+
// This converts the TOSA ApplyScale operator to a set of StandardOps ops,
// using 64-bit operations to perform the necessary multiply, bias, and shift.
// Multiple types are used to use minimal bit width operations.
@@ -65,13 +82,19 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
Value value32 = op.value();
Value multiplier32 = op.multiplier();
Value shift8 = op.shift();
+
bool doubleRound = op.double_round();
Type inType = op.value().getType();
+ Type resultTy = op.getType();
+
+ Type i8Ty = matchContainerType(rewriter.getIntegerType(8), resultTy);
+ Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
+ Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
Value one8 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1));
+ loc, getConstantAttr(i8Ty, 1, rewriter));
Value one64 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+ loc, getConstantAttr(i64Ty, 1, rewriter));
Value shiftSubOne8 = rewriter.create<arith::SubIOp>(loc, shift8, one8);
@@ -85,23 +108,20 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
// Note that minimal bitwidth operators are used throughout the block.
Value round64 = rewriter.create<arith::ShLIOp>(
- loc, one64,
- rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(),
- shiftSubOne8));
+ loc, one64, rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftSubOne8));
// Double rounding is performing a round operation before the shift
if (doubleRound) {
Value one32 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
- Value shift32 =
- rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), shift8);
+ loc, getConstantAttr(i32Ty, 1, rewriter));
+ Value shift32 = rewriter.create<arith::ExtSIOp>(loc, i32Ty, shift8);
Value thirty32 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30));
+ loc, getConstantAttr(i32Ty, 30, rewriter));
Value shiftThirty32 =
rewriter.create<arith::ShLIOp>(loc, one32, thirty32);
- Value shiftThirty64 = rewriter.create<arith::ExtSIOp>(
- loc, rewriter.getI64Type(), shiftThirty32);
+ Value shiftThirty64 =
+ rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftThirty32);
// Round value needs to with be added or subtracted depending on the sign
// of the input value.
@@ -120,7 +140,7 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
// We only perform double rounding if the shift value is greater than 32.
Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32));
+ loc, getConstantAttr(i32Ty, 32, rewriter));
Value shiftGreaterThanThirtyTwo = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
round64 = rewriter.create<mlir::SelectOp>(loc, shiftGreaterThanThirtyTwo,
@@ -133,20 +153,17 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
//
// Note that multiply and shift need to be perform in i64 to preserve bits.
- Value value64 =
- rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), value32);
- Value multiplier64 = rewriter.create<arith::ExtSIOp>(
- loc, rewriter.getI64Type(), multiplier32);
- Value shift64 =
- rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), shift8);
+ Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32);
+ Value multiplier64 =
+ rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
+ Value shift64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, shift8);
// Multiply as a pair of i64 values to guarantee the end value fits.
Value result64 = rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
result64 = rewriter.create<arith::AddIOp>(loc, result64, round64);
result64 = rewriter.create<arith::ShRSIOp>(loc, result64, shift64);
- Value result32 =
- rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), result64);
+ Value result32 = rewriter.create<arith::TruncIOp>(loc, resultTy, result64);
rewriter.replaceOp(op, result32);
return success();
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
index 284ba471569b6..b346f43c37d9c 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -56,6 +56,43 @@ func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// -----
+// CHECK-LABEL: @apply_scale_test_vector
+func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
+ // CHECK-DAG: [[C1_8:%.+]] = arith.constant dense<1> : vector<4xi8>
+ // CHECK-DAG: [[C1_32:%.+]] = arith.constant dense<1> : vector<4xi32>
+ // CHECK-DAG: [[C1_64:%.+]] = arith.constant dense<1> : vector<4xi64>
+ // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]]
+ // Rounding term 1 << (shift - 1) in i64; shift is also widened to i32 for the >= 32 test.
+ // CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi32>
+ // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : vector<4xi8> to vector<4xi64>
+ // CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]]
+ // Double rounding: add or subtract 1 << 30 by the sign of %arg0, used only when shift >= 32.
+ // CHECK-DAG: [[C0_32:%.+]] = arith.constant dense<0> : vector<4xi32>
+ // CHECK-DAG: [[C30_32:%.+]] = arith.constant dense<30> : vector<4xi32>
+ // CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], [[C30_32]]
+ // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : vector<4xi32> to vector<4xi64>
+ // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+ // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+ // CHECK-DAG: [[VALUE_NON_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : vector<4xi32>
+ // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NON_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : vector<4xi1>, vector<4xi64>
+ // CHECK-DAG: [[C32_32:%.+]] = arith.constant dense<32> : vector<4xi32>
+ // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]]
+ // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
+ // Scale in i64: ((value * multiplier) + round) >> shift, then truncate to the i32 result.
+ // CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : vector<4xi32> to vector<4xi64>
+ // CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : vector<4xi32> to vector<4xi64>
+ // CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi64>
+ // CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]]
+ // CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]]
+ // CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]]
+ // CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]] : vector<4xi64> to vector<4xi32>
+ // CHECK: return [[TRUNCATED]] : vector<4xi32>
+ %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @apply_scale_test_i48
func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8
More information about the Mlir-commits
mailing list