[Mlir-commits] [mlir] 40126e6 - [mlir][arith] Add andi, ori, and xori support to WIE
Jakub Kuderski
llvmlistbot at llvm.org
Wed Oct 5 11:35:54 PDT 2022
Author: Jakub Kuderski
Date: 2022-10-05T14:34:42-04:00
New Revision: 40126e66b623273bd3e85308c3edfa602b286b35
URL: https://github.com/llvm/llvm-project/commit/40126e66b623273bd3e85308c3edfa602b286b35
DIFF: https://github.com/llvm/llvm-project/commit/40126e66b623273bd3e85308c3edfa602b286b35.diff
LOG: [mlir][arith] Add andi, ori, and xori support to WIE (wide integer emulation)
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D135204
Added:
Modified:
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arith/emulate-wide-int.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 3bec82c5cd03f..ea132079924b6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -287,6 +287,40 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertBitwiseBinary
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
+///
+/// The type converter represents a wide integer as a vector whose last
+/// dimension holds two narrow elements (e.g., i64 -> vector<2xi32> and
+/// vector<3xi64> -> vector<3x2xi32>, as exercised by the tests). Bitwise ops
+/// act on every bit independently, so no carry or cross-half interaction is
+/// needed. The wide op is emulated by applying `BinaryOp` to the first halves
+/// of both operands and, separately, to the second halves, then reassembling
+/// the two narrow results. Registered for `arith.andi`, `arith.ori`, and
+/// `arith.xori` in `populateWideIntEmulationPatterns`.
+template <typename BinaryOp>
+struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
+ // The base class is a dependent type, so its adaptor alias must be
+ // re-exposed explicitly (with `typename`).
+ using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
+
+ LogicalResult
+ matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ // Only handle ops whose converted result type is a vector (the emulated
+ // representation). `dyn_cast_or_null` also tolerates a null converted
+ // type, which is reported as a match failure as well. `this->` and
+ // `.template` are required because the base class is dependent.
+ auto newTy = this->getTypeConverter()
+ ->convertType(op.getType())
+ .template dyn_cast_or_null<VectorType>();
+ if (!newTy)
+ return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+ // Split each already-converted operand into its two halves along the
+ // last dimension. The lhs halves are emitted before the rhs halves.
+ auto [lhsElem0, lhsElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+ auto [rhsElem0, rhsElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+ // Apply the narrow op to each pair of halves independently: first halves,
+ // then second halves. The tests match this emission order.
+ Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
+ Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
+ // Reassemble the two narrow results into a value of the converted type.
+ Value resultVec =
+ constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+ rewriter.replaceOp(op, resultVec);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMulI
//===----------------------------------------------------------------------===//
@@ -694,6 +728,9 @@ void arith::populateWideIntEmulationPatterns(
ConvertConstant, ConvertVectorPrint,
// Binary ops.
ConvertAddI, ConvertMulI, ConvertShRUI,
+ // Bitwise binary ops.
+ ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
+ ConvertBitwiseBinary<arith::XOrIOp>,
// Extension and truncation ops.
ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
patterns.getContext());
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 940b125d66192..59451f55d048f 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -331,3 +331,81 @@ func.func @shrui_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64>
%m = arith.shrui %a, %b : vector<3xi64>
return %m : vector<3xi64>
}
+
+// Scalar case: an i64 `arith.andi` is emulated on vector<2xi32>. Each operand
+// is split into its two i32 halves (elements [0] and [1]), `arith.andi` is
+// applied to each pair of halves, and the two results are inserted back into
+// a vector<2xi32>.
+// CHECK-LABEL: func @andi_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT: [[RES0:%.+]] = arith.andi [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RES1:%.+]] = arith.andi [[HIGH0]], [[HIGH1]] : i32
+// CHECK: [[INS0:%.+]] = vector.insert [[RES0]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @andi_scalar_a_b(%a : i64, %b : i64) -> i64 {
+ %x = arith.andi %a, %b : i64
+ return %x : i64
+}
+
+// Vector case: vector<3xi64> operands are converted to vector<3x2xi32>, and
+// `arith.andi` is applied separately to each vector<3x1xi32> half.
+// CHECK-LABEL: func @andi_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: {{%.+}} = arith.andi {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT: {{%.+}} = arith.andi {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: return {{.+}} : vector<3x2xi32>
+func.func @andi_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+ %x = arith.andi %a, %b : vector<3xi64>
+ return %x : vector<3xi64>
+}
+
+// Scalar case: an i64 `arith.ori` is emulated on vector<2xi32>. Each operand
+// is split into its two i32 halves (elements [0] and [1]), `arith.ori` is
+// applied to each pair of halves, and the two results are inserted back into
+// a vector<2xi32>.
+// CHECK-LABEL: func @ori_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT: [[RES0:%.+]] = arith.ori [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RES1:%.+]] = arith.ori [[HIGH0]], [[HIGH1]] : i32
+// CHECK: [[INS0:%.+]] = vector.insert [[RES0]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @ori_scalar_a_b(%a : i64, %b : i64) -> i64 {
+ %x = arith.ori %a, %b : i64
+ return %x : i64
+}
+
+// Vector case: vector<3xi64> operands are converted to vector<3x2xi32>, and
+// `arith.ori` is applied separately to each vector<3x1xi32> half.
+// CHECK-LABEL: func @ori_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: {{%.+}} = arith.ori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT: {{%.+}} = arith.ori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: return {{.+}} : vector<3x2xi32>
+func.func @ori_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+ %x = arith.ori %a, %b : vector<3xi64>
+ return %x : vector<3xi64>
+}
+
+// Scalar case: an i64 `arith.xori` is emulated on vector<2xi32>. Each operand
+// is split into its two i32 halves (elements [0] and [1]), `arith.xori` is
+// applied to each pair of halves, and the two results are inserted back into
+// a vector<2xi32>.
+// CHECK-LABEL: func @xori_scalar_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT: [[RES0:%.+]] = arith.xori [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RES1:%.+]] = arith.xori [[HIGH0]], [[HIGH1]] : i32
+// CHECK: [[INS0:%.+]] = vector.insert [[RES0]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @xori_scalar_a_b(%a : i64, %b : i64) -> i64 {
+ %x = arith.xori %a, %b : i64
+ return %x : i64
+}
+
+// Vector case: vector<3xi64> operands are converted to vector<3x2xi32>, and
+// `arith.xori` is applied separately to each vector<3x1xi32> half.
+// CHECK-LABEL: func @xori_vector_a_b
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: {{%.+}} = arith.xori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT: {{%.+}} = arith.xori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: return {{.+}} : vector<3x2xi32>
+func.func @xori_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+ %x = arith.xori %a, %b : vector<3xi64>
+ return %x : vector<3xi64>
+}
More information about the Mlir-commits
mailing list