[Mlir-commits] [mlir] [mlir][emitc] Lower arith.andi, arith.ori, arith.xori to EmitC (PR #93666)

Corentin Ferry llvmlistbot at llvm.org
Fri May 31 01:23:07 PDT 2024


https://github.com/cferry-AMD updated https://github.com/llvm/llvm-project/pull/93666

From 1ede9edd82d7fbf27ccd252aff458b09fed07ca8 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 29 May 2024 10:46:20 +0100
Subject: [PATCH 1/2] Lower arith.andi, ori, xori

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 66 +++++++++++++++++++
 .../ArithToEmitC/arith-to-emitc.mlir          | 38 +++++++++++
 2 files changed, 104 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 388794ec122d2..b40ff00390882 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -40,6 +40,29 @@ class ArithConstantOpConversionPattern
   }
 };
 
+/// Check if the signedness of type \p ty matches the expected
+/// signedness, and issue a type with the correct signedness if
+/// necessary.
+Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
+  if (isa<IntegerType>(ty)) {
+    // Turns signless integers into signed integers.
+    if (ty.isUnsignedInteger() != needsUnsigned) {
+      auto signedness = needsUnsigned
+                            ? IntegerType::SignednessSemantics::Unsigned
+                            : IntegerType::SignednessSemantics::Signed;
+      return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
+                              signedness);
+    }
+  }
+  return ty;
+}
+
+/// Insert a cast operation to type \p ty if \p val
+/// does not have this type.
+Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
+  return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
+}
+
 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -265,6 +288,46 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
   }
 };
 
+template <typename ArithOp, typename EmitCOp>
+class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type type = this->getTypeConverter()->convertType(op.getType());
+    if (!isa_and_nonnull<IntegerType>(type)) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "expected integer type, vector/tensor support not yet implemented");
+    }
+
+    // Bitwise ops can be performed directly on booleans
+    if (type.isInteger(1)) {
+      rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
+                                           adaptor.getRhs());
+      return success();
+    }
+
+    // Bitwise ops are defined by the C standard on unsigned operands.
+    Type arithmeticType =
+        adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
+
+    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
+    Value arithmeticResult = rewriter.template create<EmitCOp>(
+        op.getLoc(), arithmeticType, lhs, rhs);
+
+    Value result = adaptValueType(arithmeticResult, rewriter, type);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
 public:
   using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -401,6 +464,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
     IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
     IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
+    BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
+    BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
+    BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
     CmpIOpConversion,
     SelectOpConversion,
     // Truncation is guaranteed for unsigned types.
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index dac3fd99b607c..e34d93c20bb70 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -88,6 +88,44 @@ func.func @arith_index(%arg0: index, %arg1: index) {
 
 // -----
 
+// CHECK-LABEL: arith_bitwise
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @arith_bitwise(%arg0: i32, %arg1: i32) {
+  // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+  // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+  // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[AND]] : ui32 to i32
+  %5 = arith.andi %arg0, %arg1 : i32
+  // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+  // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+  // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[OR]] : ui32 to i32
+  %6 = arith.ori %arg0, %arg1 : i32
+  // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
+  // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
+  // CHECK: %[[XOR:[^ ]*]] = emitc.bitwise_xor %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
+  // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[XOR]] : ui32 to i32
+  %7 = arith.xori %arg0, %arg1 : i32
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: arith_bitwise_bool
+func.func @arith_bitwise_bool(%arg0: i1, %arg1: i1) {
+  // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %arg0, %arg1 : (i1, i1) -> i1
+  %5 = arith.andi %arg0, %arg1 : i1
+  // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %arg0, %arg1 : (i1, i1) -> i1
+  %6 = arith.ori %arg0, %arg1 : i1
+  // CHECK: %[[xor:[^ ]*]] = emitc.bitwise_xor %arg0, %arg1 : (i1, i1) -> i1
+  %7 = arith.xori %arg0, %arg1 : i1
+  
+  return
+}
+
+// -----
+
 // CHECK-LABEL: arith_signed_integer_div_rem
 func.func @arith_signed_integer_div_rem(%arg0: i32, %arg1: i32) {
   // CHECK: emitc.div %arg0, %arg1 : (i32, i32) -> i32

From 844c9337a0a51887fa626614f68a81d1224fb676 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Fri, 31 May 2024 09:13:58 +0100
Subject: [PATCH 2/2] Address review comments

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  |  8 ++------
 .../ArithToEmitC/arith-to-emitc.mlir          | 19 ++++++++++---------
 2 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index b40ff00390882..9b1e47147861c 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -40,12 +40,9 @@ class ArithConstantOpConversionPattern
   }
 };
 
-/// Check if the signedness of type \p ty matches the expected
-/// signedness, and issue a type with the correct signedness if
-/// necessary.
+/// Get the signed or unsigned type corresponding to \p ty.
 Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
   if (isa<IntegerType>(ty)) {
-    // Turns signless integers into signed integers.
     if (ty.isUnsignedInteger() != needsUnsigned) {
       auto signedness = needsUnsigned
                             ? IntegerType::SignednessSemantics::Unsigned
@@ -57,8 +54,7 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
   return ty;
 }
 
-/// Insert a cast operation to type \p ty if \p val
-/// does not have this type.
+/// Insert a cast operation to type \p ty if \p val does not have this type.
 Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
   return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
 }
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index e34d93c20bb70..5b1a1860b0f91 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -95,17 +95,17 @@ func.func @arith_bitwise(%arg0: i32, %arg1: i32) {
   // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
   // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
   // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[AND]] : ui32 to i32
-  %5 = arith.andi %arg0, %arg1 : i32
+  %0 = arith.andi %arg0, %arg1 : i32
   // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
   // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
   // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
   // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[OR]] : ui32 to i32
-  %6 = arith.ori %arg0, %arg1 : i32
+  %1 = arith.ori %arg0, %arg1 : i32
   // CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
   // CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
   // CHECK: %[[XOR:[^ ]*]] = emitc.bitwise_xor %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
   // CHECK: %[[C3:[^ ]*]] = emitc.cast %[[XOR]] : ui32 to i32
-  %7 = arith.xori %arg0, %arg1 : i32
+  %2 = arith.xori %arg0, %arg1 : i32
 
   return
 }
@@ -113,13 +113,14 @@ func.func @arith_bitwise(%arg0: i32, %arg1: i32) {
 // -----
 
 // CHECK-LABEL: arith_bitwise_bool
+// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1
 func.func @arith_bitwise_bool(%arg0: i1, %arg1: i1) {
-  // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %arg0, %arg1 : (i1, i1) -> i1
-  %5 = arith.andi %arg0, %arg1 : i1
-  // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %arg0, %arg1 : (i1, i1) -> i1
-  %6 = arith.ori %arg0, %arg1 : i1
-  // CHECK: %[[xor:[^ ]*]] = emitc.bitwise_xor %arg0, %arg1 : (i1, i1) -> i1
-  %7 = arith.xori %arg0, %arg1 : i1
+  // CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %[[ARG0]], %[[ARG1]] : (i1, i1) -> i1
+  %0 = arith.andi %arg0, %arg1 : i1
+  // CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %[[ARG0]], %[[ARG1]] : (i1, i1) -> i1
+  %1 = arith.ori %arg0, %arg1 : i1
+  // CHECK: %[[xor:[^ ]*]] = emitc.bitwise_xor %[[ARG0]], %[[ARG1]] : (i1, i1) -> i1
+  %2 = arith.xori %arg0, %arg1 : i1
   
   return
 }



More information about the Mlir-commits mailing list