[Mlir-commits] [mlir] 40126e6 - [mlir][arith] Add andi, ori, and xori support to WIE

Jakub Kuderski llvmlistbot at llvm.org
Wed Oct 5 11:35:54 PDT 2022


Author: Jakub Kuderski
Date: 2022-10-05T14:34:42-04:00
New Revision: 40126e66b623273bd3e85308c3edfa602b286b35

URL: https://github.com/llvm/llvm-project/commit/40126e66b623273bd3e85308c3edfa602b286b35
DIFF: https://github.com/llvm/llvm-project/commit/40126e66b623273bd3e85308c3edfa602b286b35.diff

LOG: [mlir][arith] Add andi, ori, and xori support to WIE

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D135204

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
    mlir/test/Dialect/Arith/emulate-wide-int.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 3bec82c5cd03f..ea132079924b6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -287,6 +287,40 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertBitwiseBinary
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
+template <typename BinaryOp>
+struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
+  using OpConversionPattern<BinaryOp>::OpConversionPattern;
+  using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
+  /// Fails to match if the converted result type is null or not a vector.
+  LogicalResult
+  matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = this->getTypeConverter()
+                     ->convertType(op.getType())
+                     .template dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "unsupported type");
+    // Split each converted operand along its last dim: 0 = low, 1 = high.
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    auto [rhsElem0, rhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+    // Bitwise ops never carry between halves, so apply the op to each half.
+    Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
+    Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
+    Value resultVec =
+        constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMulI
 //===----------------------------------------------------------------------===//
@@ -694,6 +728,9 @@ void arith::populateWideIntEmulationPatterns(
       ConvertConstant, ConvertVectorPrint,
       // Binary ops.
       ConvertAddI, ConvertMulI, ConvertShRUI,
+      // Bitwise binary ops.
+      ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
+      ConvertBitwiseBinary<arith::XOrIOp>,
       // Extension and truncation ops.
       ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
                                                  patterns.getContext());

diff  --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 940b125d66192..59451f55d048f 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -331,3 +331,81 @@ func.func @shrui_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64>
     %m = arith.shrui %a, %b : vector<3xi64>
     return %m : vector<3xi64>
 }
+
+// CHECK-LABEL: func @andi_scalar_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT:    [[RES0:%.+]]   = arith.andi [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RES1:%.+]]   = arith.andi [[HIGH0]], [[HIGH1]] : i32
+// CHECK:         [[INS0:%.+]]   = vector.insert [[RES0]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @andi_scalar_a_b(%lhs : i64, %rhs : i64) -> i64 {
+    %res = arith.andi %lhs, %rhs : i64
+    return %res : i64
+}
+
+// CHECK-LABEL: func @andi_vector_a_b
+// CHECK-SAME:   ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK:         {{%.+}} = arith.andi {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT:    {{%.+}} = arith.andi {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         return {{.+}} : vector<3x2xi32>
+func.func @andi_vector_a_b(%lhs : vector<3xi64>, %rhs : vector<3xi64>) -> vector<3xi64> {
+    %res = arith.andi %lhs, %rhs : vector<3xi64>
+    return %res : vector<3xi64>
+}
+
+// CHECK-LABEL: func @ori_scalar_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT:    [[RES0:%.+]]   = arith.ori [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RES1:%.+]]   = arith.ori [[HIGH0]], [[HIGH1]] : i32
+// CHECK:         [[INS0:%.+]]   = vector.insert [[RES0]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @ori_scalar_a_b(%lhs : i64, %rhs : i64) -> i64 {
+    %res = arith.ori %lhs, %rhs : i64
+    return %res : i64
+}
+
+// CHECK-LABEL: func @ori_vector_a_b
+// CHECK-SAME:   ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK:         {{%.+}} = arith.ori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT:    {{%.+}} = arith.ori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         return {{.+}} : vector<3x2xi32>
+func.func @ori_vector_a_b(%lhs : vector<3xi64>, %rhs : vector<3xi64>) -> vector<3xi64> {
+    %res = arith.ori %lhs, %rhs : vector<3xi64>
+    return %res : vector<3xi64>
+}
+
+// CHECK-LABEL: func @xori_scalar_a_b
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]   = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]  = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]   = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH1:%.+]]  = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT:    [[RES0:%.+]]   = arith.xori [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RES1:%.+]]   = arith.xori [[HIGH0]], [[HIGH1]] : i32
+// CHECK:         [[INS0:%.+]]   = vector.insert [[RES0]], {{%.+}} [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]   = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @xori_scalar_a_b(%lhs : i64, %rhs : i64) -> i64 {
+    %res = arith.xori %lhs, %rhs : i64
+    return %res : i64
+}
+
+// CHECK-LABEL: func @xori_vector_a_b
+// CHECK-SAME:   ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK:         {{%.+}} = arith.xori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT:    {{%.+}} = arith.xori {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         return {{.+}} : vector<3x2xi32>
+func.func @xori_vector_a_b(%lhs : vector<3xi64>, %rhs : vector<3xi64>) -> vector<3xi64> {
+    %res = arith.xori %lhs, %rhs : vector<3xi64>
+    return %res : vector<3xi64>
+}


        


More information about the Mlir-commits mailing list