[Mlir-commits] [mlir] 6433333 - [mlir][arith] Add `arith.shrsi` support to WIE
Jakub Kuderski
llvmlistbot at llvm.org
Mon Nov 14 17:52:51 PST 2022
Author: Jakub Kuderski
Date: 2022-11-14T20:52:10-05:00
New Revision: 64333332db6f1640ea43339d4d72da7379b41b41
URL: https://github.com/llvm/llvm-project/commit/64333332db6f1640ea43339d4d72da7379b41b41
DIFF: https://github.com/llvm/llvm-project/commit/64333332db6f1640ea43339d4d72da7379b41b41.diff
LOG: [mlir][arith] Add `arith.shrsi` support to WIE
This includes LIT tests over the generated ops and runtime tests.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D137965
Added:
mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrsi-i16.mlir
Modified:
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arith/emulate-wide-int.mlir
mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrui-i16.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 6a725927eefc5..4a5e6d09fdc2e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -781,6 +781,69 @@ struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertShRSI
+//===----------------------------------------------------------------------===//
+
+/// Emulates a wide `arith.shrsi` whose type converts to a vector of two narrow halves.
+struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+
+ Type oldTy = op.getType();
+ auto newTy =
+ getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+ if (!newTy)
+ return rewriter.notifyMatchFailure(
+ loc, llvm::formatv("unsupported type: {0}", op.getType()));
+
+ // Only the high half of `lhs` (it holds the sign bit) and the low half of
+ // `rhs` (the shift amount) are needed by the narrow-type computations below.
+ Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
+ Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
+
+ Type narrowTy = rhsElem0.getType();
+ int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
+
+ // Compute `rhs0 == 0 ? lhs : (shrui(lhs, rhs) | signBits)`, where `rhs0`
+ // is the low half of `rhs`. Do as many ops as possible over the narrow type
+ // and let the other emulation patterns convert the rest.
+ Value elemZero =
+ createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
+ Value signBit = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
+ signBit = dropTrailingX1Dim(rewriter, loc, signBit);
+
+ // Create a wide pattern of all ones (negative `lhs`) or all zeros, then
+ // shift it left by `origBitwidth - rhs0` so that only its top `rhs0` bits
+ // remain; they are the sign bit copies that `arith.shrsi` shifts in.
+ Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
+ Value maxShift =
+ createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
+ Value numNonSignExtBits =
+ rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
+ numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
+ numNonSignExtBits =
+ rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
+ Value signBits =
+ rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
+
+ // Use the original (wide) operands; ConvertShRUI emulates this shift.
+ Value shrui = rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
+ Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
+
+ // Handle shifting by zero: the `signBits` shift would then use the full
+ // `origBitwidth`, which is an invalid shift amount, so return `lhs` as is.
+ Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+ rhsElem0, elemZero);
+ isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(), shrsi);
+
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertTruncI
//===----------------------------------------------------------------------===//
@@ -799,7 +862,8 @@ struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
loc, llvm::formatv("unsupported truncation result type: {0}",
op.getType()));
- // Discard the high half of the input. Truncate the low half, if necessary.
+ // Discard the high half of the input. Truncate the low half, if
+ // necessary.
Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
extracted = dropTrailingX1Dim(rewriter, loc, extracted);
Value truncated =
@@ -940,7 +1004,7 @@ void arith::populateArithWideIntEmulationPatterns(
// Misc ops.
ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
// Binary ops.
- ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRUI,
+ ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
// Bitwise binary ops.
ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
ConvertBitwiseBinary<arith::XOrIOp>,
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 3356542b30404..83cf04a268b17 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -587,6 +587,45 @@ func.func @shrui_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64>
return %m : vector<3xi64>
}
+// Scalar i64 shift, emulated over vector<2xi32>. The sign-bit computation at
+// the start of the lowering is matched line by line; the rest only loosely.
+// CHECK-LABEL: func.func @shrsi_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[CST0:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT: [[NEG:%.+]] = arith.cmpi slt, [[HIGH0]], [[CST0]] : i32
+// CHECK-NEXT: [[NEGEXT:%.+]] = arith.extsi [[NEG]] : i1 to i32
+// CHECK: [[CST64:%.+]] = arith.constant 64 : i32
+// CHECK-NEXT: [[SIGNBITS:%.+]] = arith.subi [[CST64]], [[LOW1]] : i32
+// CHECK: arith.shli
+// CHECK: arith.shrui
+// CHECK: arith.shli
+// CHECK: arith.shli
+// CHECK: arith.shrui
+// CHECK: arith.shrui
+// CHECK: arith.shli
+// CHECK: arith.shrui
+// CHECK: return {{%.+}} : vector<2xi32>
+func.func @shrsi_scalar(%a : i64, %b : i64) -> i64 {
+ %c = arith.shrsi %a, %b : i64
+ return %c : i64
+}
+
+// Vector variant: vector<3xi64> becomes vector<3x2xi32>, and the emulated
+// shifts operate on vector<3x1xi32> slices.
+// CHECK-LABEL: func.func @shrsi_vector
+// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @shrsi_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+ %m = arith.shrsi %a, %b : vector<3xi64>
+ return %m : vector<3xi64>
+}
+
// CHECK-LABEL: func @andi_scalar_a_b
// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
index 16e8634a05f34..ee8037c1167d9 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
@@ -203,6 +203,53 @@ func.func @test_shli() -> () {
return
}
+//===----------------------------------------------------------------------===//
+// Test arith.shrsi
+//===----------------------------------------------------------------------===//
+
+// Ops in this function will be emulated using i8 ops.
+// @check_shrsi compares its result against the native i16 `arith.shrsi`.
+func.func @emulate_shrsi(%lhs : i16, %rhs : i16) -> (i16) {
+ %res = arith.shrsi %lhs, %rhs : i16
+ return %res : i16
+}
+
+// Performs both wide and emulated `arith.shrsi`, and checks that the results
+// match. The comparison is delegated to @check_results, which also receives
+// the operands.
+func.func @check_shrsi(%lhs : i16, %rhs : i16) -> () {
+ %wide = arith.shrsi %lhs, %rhs : i16
+ %emulated = func.call @emulate_shrsi(%lhs, %rhs) : (i16, i16) -> (i16)
+ func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
+ return
+}
+
+// Checks that `arith.shrsi` is emulated properly by sampling the input space.
+// Checks all valid shift amounts for i16: 0 to 15.
+// LHS values are produced by hashing the outer loop counter with @xhash.
+// In total, this test function checks 100 * 16 = 1.6k input pairs.
+func.func @test_shrsi() -> () {
+ %idx0 = arith.constant 0 : index
+ %idx1 = arith.constant 1 : index
+ %idx16 = arith.constant 16 : index
+ %idx100 = arith.constant 100 : index
+
+ %cst0 = arith.constant 0 : i16
+ %cst1 = arith.constant 1 : i16
+
+ scf.for %lhs_idx = %idx0 to %idx100 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
+ %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)
+
+ scf.for %rhs_idx = %idx0 to %idx16 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
+ func.call @check_shrsi(%arg_lhs, %rhs) : (i16, i16) -> ()
+ %rhs_next = arith.addi %rhs, %cst1 : i16
+ scf.yield %rhs_next : i16
+ }
+
+ %lhs_next = arith.addi %lhs, %cst1 : i16
+ scf.yield %lhs_next : i16
+ }
+
+ return
+}
+
//===----------------------------------------------------------------------===//
// Test arith.shrui
//===----------------------------------------------------------------------===//
@@ -258,6 +305,7 @@ func.func @entry() {
func.call @test_addi() : () -> ()
func.call @test_muli() : () -> ()
func.call @test_shli() : () -> ()
+ func.call @test_shrsi() : () -> ()
func.call @test_shrui() : () -> ()
return
}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrsi-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrsi-i16.mlir
new file mode 100644
index 0000000000000..3be3792ff6777
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrsi-i16.mlir
@@ -0,0 +1,100 @@
+// Check that the wide integer `arith.shrsi` emulation produces the same result as wide
+// `arith.shrsi`. Emulate i16 ops with i8 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
+// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i8 types.
+// Under the first RUN pipeline this executes natively; under the second it is
+// emulated. Both pipelines must print the same values.
+func.func @emulate_shrsi(%lhs : i16, %rhs : i16) -> (i16) {
+ %res = arith.shrsi %lhs, %rhs : i16
+ return %res : i16
+}
+
+// Prints the result of @emulate_shrsi so that FileCheck can match it.
+func.func @check_shrsi(%lhs : i16, %rhs : i16) -> () {
+ %res = func.call @emulate_shrsi(%lhs, %rhs) : (i16, i16) -> (i16)
+ vector.print %res : i16
+ return
+}
+
+// Shift amounts 7, 8 and 9 straddle the boundary between the two i8 halves
+// used by the emulated RUN pipeline; 0 and 15 are the smallest and largest
+// valid i16 shift amounts. Results for negative inputs round toward negative
+// infinity (e.g. -1337 >> 2 == -335).
+func.func @entry() {
+ %cst0 = arith.constant 0 : i16
+ %cst1 = arith.constant 1 : i16
+ %cst2 = arith.constant 2 : i16
+ %cst7 = arith.constant 7 : i16
+ %cst8 = arith.constant 8 : i16
+ %cst9 = arith.constant 9 : i16
+ %cst15 = arith.constant 15 : i16
+
+ %cst_n1 = arith.constant -1 : i16
+
+ %cst1337 = arith.constant 1337 : i16
+ %cst_n1337 = arith.constant -1337 : i16
+
+ %cst_i16_min = arith.constant -32768 : i16
+
+ // CHECK: -32768
+ // CHECK-NEXT: -16384
+ // CHECK-NEXT: -8192
+ // CHECK-NEXT: -256
+ // CHECK-NEXT: -128
+ // CHECK-NEXT: -64
+ // CHECK-NEXT: -1
+ func.call @check_shrsi(%cst_i16_min, %cst0) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_i16_min, %cst1) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_i16_min, %cst2) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_i16_min, %cst7) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_i16_min, %cst8) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_i16_min, %cst9) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_i16_min, %cst15) : (i16, i16) -> ()
+
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: -1
+ // CHECK-NEXT: -1
+ func.call @check_shrsi(%cst0, %cst0) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst0, %cst1) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst1, %cst1) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst1, %cst0) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_n1, %cst1) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_n1, %cst15) : (i16, i16) -> ()
+
+ // CHECK-NEXT: 1337
+ // CHECK-NEXT: 334
+ // CHECK-NEXT: 10
+ // CHECK-NEXT: 5
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 0
+ func.call @check_shrsi(%cst1337, %cst0) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst1337, %cst2) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst1337, %cst7) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst1337, %cst8) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst1337, %cst9) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst1337, %cst15) : (i16, i16) -> ()
+
+ // CHECK-NEXT: -1337
+ // CHECK-NEXT: -335
+ // CHECK-NEXT: -11
+ // CHECK-NEXT: -6
+ // CHECK-NEXT: -3
+ // CHECK-NEXT: -1
+ func.call @check_shrsi(%cst_n1337, %cst0) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_n1337, %cst2) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_n1337, %cst7) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_n1337, %cst8) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_n1337, %cst9) : (i16, i16) -> ()
+ func.call @check_shrsi(%cst_n1337, %cst15) : (i16, i16) -> ()
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrui-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrui-i16.mlir
index 2aa75f32bd00c..e22e28f2555a3 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrui-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrui-i16.mlir
@@ -38,6 +38,7 @@ func.func @entry() {
%cst_n1 = arith.constant -1 : i16
%cst1337 = arith.constant 1337 : i16
+ %cst_n1337 = arith.constant -1337 : i16
%cst_i16_min = arith.constant -32768 : i16
@@ -67,6 +68,19 @@ func.func @entry() {
func.call @check_shrui(%cst1337, %cst9) : (i16, i16) -> ()
func.call @check_shrui(%cst1337, %cst15) : (i16, i16) -> ()
+ // CHECK-NEXT: -1337
+ // CHECK-NEXT: 16049
+ // CHECK-NEXT: 501
+ // CHECK-NEXT: 250
+ // CHECK-NEXT: 125
+ // CHECK-NEXT: 1
+ func.call @check_shrui(%cst_n1337, %cst0) : (i16, i16) -> ()
+ func.call @check_shrui(%cst_n1337, %cst2) : (i16, i16) -> ()
+ func.call @check_shrui(%cst_n1337, %cst7) : (i16, i16) -> ()
+ func.call @check_shrui(%cst_n1337, %cst8) : (i16, i16) -> ()
+ func.call @check_shrui(%cst_n1337, %cst9) : (i16, i16) -> ()
+ func.call @check_shrui(%cst_n1337, %cst15) : (i16, i16) -> ()
+
// CHECK-NEXT: 16384
// CHECK-NEXT: 8192
// CHECK-NEXT: 256
More information about the Mlir-commits
mailing list