[Mlir-commits] [mlir] e99e8ad - [mlir][arith] Add shli support to WIE
Jakub Kuderski
llvmlistbot at llvm.org
Wed Oct 5 12:21:22 PDT 2022
Author: Jakub Kuderski
Date: 2022-10-05T15:09:58-04:00
New Revision: e99e8ad24da902bd34902e6e83b6e71e255bf868
URL: https://github.com/llvm/llvm-project/commit/e99e8ad24da902bd34902e6e83b6e71e255bf868
DIFF: https://github.com/llvm/llvm-project/commit/e99e8ad24da902bd34902e6e83b6e71e255bf868.diff
LOG: [mlir][arith] Add shli support to WIE
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D135234
Added:
mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir
Modified:
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arith/emulate-wide-int.mlir
mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index ea132079924b6..c53abbc89da0e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -485,6 +485,95 @@ struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertShLI
+//===----------------------------------------------------------------------===//
+
+struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+
+ Type oldTy = op.getType();
+ auto newTy =
+ getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+ if (!newTy)
+ return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+ Type newOperandTy = reduceInnermostDim(newTy);
+ // `oldBitWidth` == `2 * newBitWidth`
+ unsigned newBitWidth = newTy.getElementTypeBitWidth();
+
+ auto [lhsElem0, lhsElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+ Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
+
+ // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
+ // high halves of the results separately:
+ // 1. low := LHS.low shli RHS
+ //
+ // 2. high := a or b or c, where:
+ // a) Bits from LHS.high, shifted by the RHS.
+ // b) Bits from LHS.low, shifted right. These come into play when
+ // RHS < newBitWidth, e.g.:
+ // [0000][llll] shli 3 --> [0lll][l000]
+ // ^
+ // |
+ // [llll] shrui (4 - 3)
+ // c) Bits from LHS.low, shifted left. These matter when
+ // RHS > newBitWidth, e.g.:
+ // [0000][llll] shli 7 --> [l000][0000]
+ // ^
+ // |
+ // [llll] shli (7 - 4)
+ //
+ // Because shifts by values >= newBitWidth are undefined, we ignore the high
+ // half of RHS, and introduce 'bounds checks' to account for
+ // RHS.low >= newBitWidth.
+ //
+ // TODO: Explore possible optimizations.
+ Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
+ Value elemBitWidth =
+ createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
+
+ Value illegalElemShift = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
+
+ Value shiftedElem0 =
+ rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
+ Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
+ zeroCst, shiftedElem0);
+
+ Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
+ loc, illegalElemShift, elemBitWidth, rhsElem0);
+ Value rightShiftAmount =
+ rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
+ Value shiftedRight =
+ rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
+ Value overshotShiftAmount =
+ rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
+ Value shiftedLeft =
+ rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
+
+ Value shiftedElem1 =
+ rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
+ Value resElem1High = rewriter.create<arith::SelectOp>(
+ loc, illegalElemShift, zeroCst, shiftedElem1);
+ Value resElem1Low = rewriter.create<arith::SelectOp>(
+ loc, illegalElemShift, shiftedLeft, shiftedRight);
+ Value resElem1 =
+ rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
+
+ Value resultVec =
+ constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+ rewriter.replaceOp(op, resultVec);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertShRUI
//===----------------------------------------------------------------------===//
@@ -498,8 +587,13 @@ struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
Location loc = op->getLoc();
Type oldTy = op.getType();
- auto newTy = getTypeConverter()->convertType(oldTy).cast<VectorType>();
+ auto newTy =
+ getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+ if (!newTy)
+ return rewriter.notifyMatchFailure(loc, "unsupported type");
+
Type newOperandTy = reduceInnermostDim(newTy);
+ // `oldBitWidth` == `2 * newBitWidth`
unsigned newBitWidth = newTy.getElementTypeBitWidth();
auto [lhsElem0, lhsElem1] =
@@ -727,7 +821,7 @@ void arith::populateWideIntEmulationPatterns(
// Misc ops.
ConvertConstant, ConvertVectorPrint,
// Binary ops.
- ConvertAddI, ConvertMulI, ConvertShRUI,
+ ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRUI,
// Bitwise binary ops.
ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
ConvertBitwiseBinary<arith::XOrIOp>,
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 59451f55d048f..eebf1d6902b92 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -278,6 +278,46 @@ func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64>
return %m : vector<3xi64>
}
+// CHECK-LABEL: func.func @shli_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[CST0:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT: [[CST32:%.+]] = arith.constant 32 : i32
+// CHECK-NEXT: [[OOB:%.+]] = arith.cmpi uge, [[LOW1]], [[CST32]] : i32
+// CHECK-NEXT: [[SHLOW0:%.+]] = arith.shli [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RES0:%.+]] = arith.select [[OOB]], [[CST0]], [[SHLOW0]] : i32
+// CHECK-NEXT: [[SHAMT:%.+]] = arith.select [[OOB]], [[CST32]], [[LOW1]] : i32
+// CHECK-NEXT: [[RSHAMT:%.+]] = arith.subi [[CST32]], [[SHAMT]] : i32
+// CHECK-NEXT: [[SHRHIGH0:%.+]] = arith.shrui [[LOW0]], [[RSHAMT]] : i32
+// CHECK-NEXT: [[LSHAMT:%.+]] = arith.subi [[LOW1]], [[CST32]] : i32
+// CHECK-NEXT: [[SHLHIGH0:%.+]] = arith.shli [[LOW0]], [[LSHAMT]] : i32
+// CHECK-NEXT: [[SHLHIGH1:%.+]] = arith.shli [[HIGH0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RES1HIGH:%.+]] = arith.select [[OOB]], [[CST0]], [[SHLHIGH1]] : i32
+// CHECK-NEXT: [[RES1LOW:%.+]] = arith.select [[OOB]], [[SHLHIGH0]], [[SHRHIGH0]] : i32
+// CHECK-NEXT: [[RES1:%.+]] = arith.ori [[RES1LOW]], [[RES1HIGH]] : i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @shli_scalar(%a : i64, %b : i64) -> i64 {
+ %c = arith.shli %a, %b : i64
+ return %c : i64
+}
+
+// CHECK-LABEL: func.func @shli_vector
+// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @shli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+ %m = arith.shli %a, %b : vector<3xi64>
+ return %m : vector<3xi64>
+}
+
// CHECK-LABEL: func.func @shrui_scalar
// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
@@ -326,6 +366,10 @@ func.func @shrui_scalar_cst_36(%a : i64) -> i64 {
// CHECK-LABEL: func.func @shrui_vector
// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
// CHECK: return {{%.+}} : vector<3x2xi32>
func.func @shrui_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
%m = arith.shrui %a, %b : vector<3xi64>
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
index 6ca279037d5f0..16e8634a05f34 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
@@ -156,6 +156,53 @@ func.func @test_muli() -> () {
return
}
+//===----------------------------------------------------------------------===//
+// Test arith.shli
+//===----------------------------------------------------------------------===//
+
+// Ops in this function will be emulated using i8 ops.
+func.func @emulate_shli(%lhs : i16, %rhs : i16) -> (i16) {
+ %res = arith.shli %lhs, %rhs : i16
+ return %res : i16
+}
+
+// Performs both wide and emulated `arith.shli`, and checks that the results
+// match.
+func.func @check_shli(%lhs : i16, %rhs : i16) -> () {
+ %wide = arith.shli %lhs, %rhs : i16
+ %emulated = func.call @emulate_shli(%lhs, %rhs) : (i16, i16) -> (i16)
+ func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
+ return
+}
+
+// Checks that `arith.shli` is emulated properly by sampling the input space.
+// Checks all valid shift amounts for i16: 0 to 15.
+// In total, this test function checks 100 * 16 = 1.6k input pairs.
+func.func @test_shli() -> () {
+ %idx0 = arith.constant 0 : index
+ %idx1 = arith.constant 1 : index
+ %idx16 = arith.constant 16 : index
+ %idx100 = arith.constant 100 : index
+
+ %cst0 = arith.constant 0 : i16
+ %cst1 = arith.constant 1 : i16
+
+ scf.for %lhs_idx = %idx0 to %idx100 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
+ %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)
+
+ scf.for %rhs_idx = %idx0 to %idx16 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
+ func.call @check_shli(%arg_lhs, %rhs) : (i16, i16) -> ()
+ %rhs_next = arith.addi %rhs, %cst1 : i16
+ scf.yield %rhs_next : i16
+ }
+
+ %lhs_next = arith.addi %lhs, %cst1 : i16
+ scf.yield %lhs_next : i16
+ }
+
+ return
+}
+
//===----------------------------------------------------------------------===//
// Test arith.shrui
//===----------------------------------------------------------------------===//
@@ -210,6 +257,7 @@ func.func @test_shrui() -> () {
func.func @entry() {
func.call @test_addi() : () -> ()
func.call @test_muli() : () -> ()
+ func.call @test_shli() : () -> ()
func.call @test_shrui() : () -> ()
return
}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir
new file mode 100644
index 0000000000000..1e32d18740187
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir
@@ -0,0 +1,73 @@
+// Check that the wide integer `arith.shli` emulation produces the same result as wide
+// `arith.shli`. Emulate i16 ops with i8 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
+// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i8 types.
+func.func @emulate_shli(%lhs : i16, %rhs : i16) -> (i16) {
+ %res = arith.shli %lhs, %rhs : i16
+ return %res : i16
+}
+
+func.func @check_shli(%lhs : i16, %rhs : i16) -> () {
+ %res = func.call @emulate_shli(%lhs, %rhs) : (i16, i16) -> (i16)
+ vector.print %res : i16
+ return
+}
+
+func.func @entry() {
+ %cst0 = arith.constant 0 : i16
+ %cst1 = arith.constant 1 : i16
+ %cst2 = arith.constant 2 : i16
+ %cst7 = arith.constant 7 : i16
+ %cst8 = arith.constant 8 : i16
+ %cst9 = arith.constant 9 : i16
+ %cst15 = arith.constant 15 : i16
+
+ %cst_n1 = arith.constant -1 : i16
+
+ %cst1337 = arith.constant 1337 : i16
+
+ %cst_i16_min = arith.constant -32768 : i16
+
+ // CHECK: 0
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: -2
+ // CHECK-NEXT: -32768
+ func.call @check_shli(%cst0, %cst0) : (i16, i16) -> ()
+ func.call @check_shli(%cst0, %cst1) : (i16, i16) -> ()
+ func.call @check_shli(%cst1, %cst0) : (i16, i16) -> ()
+ func.call @check_shli(%cst1, %cst1) : (i16, i16) -> ()
+ func.call @check_shli(%cst_n1, %cst1) : (i16, i16) -> ()
+ func.call @check_shli(%cst_n1, %cst15) : (i16, i16) -> ()
+
+ // CHECK-NEXT: 1337
+ // CHECK-NEXT: 5348
+ // CHECK-NEXT: -25472
+ // CHECK-NEXT: 14592
+ // CHECK-NEXT: 29184
+ // CHECK-NEXT: -32768
+ // CHECK-NEXT: 0
+ func.call @check_shli(%cst1337, %cst0) : (i16, i16) -> ()
+ func.call @check_shli(%cst1337, %cst2) : (i16, i16) -> ()
+ func.call @check_shli(%cst1337, %cst7) : (i16, i16) -> ()
+ func.call @check_shli(%cst1337, %cst8) : (i16, i16) -> ()
+ func.call @check_shli(%cst1337, %cst9) : (i16, i16) -> ()
+ func.call @check_shli(%cst1337, %cst15) : (i16, i16) -> ()
+ func.call @check_shli(%cst_i16_min, %cst1) : (i16, i16) -> ()
+
+ return
+}
More information about the Mlir-commits
mailing list