[Mlir-commits] [mlir] [mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui (PR #91491)

Corentin Ferry llvmlistbot at llvm.org
Wed May 22 06:48:41 PDT 2024


https://github.com/cferry-AMD updated https://github.com/llvm/llvm-project/pull/91491

>From 32ab952d6f53f16132a423396caa1c118440d8c1 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 8 May 2024 14:02:04 +0200
Subject: [PATCH 1/4] Add EmitC lowering for arith.{trunci,extsi,extui}

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 76 +++++++++++++++++++
 .../arith-to-emitc-unsupported.mlir           | 19 +++++
 .../ArithToEmitC/arith-to-emitc.mlir          | 39 ++++++++++
 3 files changed, 134 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 1447b182ccfdb..6216e6ea89b9b 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -112,6 +112,78 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
   }
 };
 
+template <typename ArithOp, bool needsUnsigned>
+class CastConversion : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+  /// Emits emitc.cast ops; unsigned types when truncating or needsUnsigned.
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type opReturnType = this->getTypeConverter()->convertType(op.getType());
+    if (!isa_and_nonnull<IntegerType>(opReturnType)) {
+      return rewriter.notifyMatchFailure(op, "expected integer result type");
+    }
+
+    if (adaptor.getOperands().size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "CastConversion only supports unary ops");
+    }
+
+    Type operandType = adaptor.getIn().getType();
+    if (!isa_and_nonnull<IntegerType>(operandType)) {
+      return rewriter.notifyMatchFailure(op, "expected integer operand type");
+    }
+
+    bool isTruncation = operandType.getIntOrFloatBitWidth() >
+                        opReturnType.getIntOrFloatBitWidth();
+    bool doUnsigned = needsUnsigned || isTruncation;
+
+    Type castType = opReturnType;
+    // For int conversions: if the op is a ui variant and the type wanted as
+    // return type isn't unsigned, we need to issue an unsigned type to do
+    // the conversion.
+    if (castType.isUnsignedInteger() != doUnsigned) {
+      castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
+                                         /*isSigned=*/!doUnsigned);
+    }
+
+    Value actualOp = adaptor.getIn();
+    // Fix the signedness of the operand if necessary
+    if (operandType.isUnsignedInteger() != doUnsigned) {
+      Type correctSignednessType =
+          rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
+                                  /*isSigned=*/!doUnsigned);
+      actualOp = rewriter.template create<emitc::CastOp>(
+          op.getLoc(), correctSignednessType, actualOp);
+    }
+    // NOTE(review): if i1 is C bool, narrowing here yields v != 0, not v & 1.
+    auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
+                                                          actualOp);
+
+    // Fix the signedness of what this operation returns (for integers,
+    // the arith ops want signless results)
+    if (castType != opReturnType) {
+      result = rewriter.template create<emitc::CastOp>(op.getLoc(),
+                                                       opReturnType, result);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+/// Cast pattern that always converts through unsigned integer types.
+template <typename ArithOp>
+struct UnsignedCastConversion : CastConversion<ArithOp, true> {
+  using CastConversion<ArithOp, true>::CastConversion;
+};
+/// Cast pattern that uses unsigned types only when it truncates.
+template <typename ArithOp>
+struct SignedCastConversion : CastConversion<ArithOp, false> {
+  using CastConversion<ArithOp, false>::CastConversion;
+};
+
 template <typename ArithOp, typename EmitCOp>
 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
 public:
@@ -313,6 +385,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
     CmpIOpConversion,
     SelectOpConversion,
+    // Truncation is guaranteed for unsigned types.
+    UnsignedCastConversion<arith::TruncIOp>,
+    SignedCastConversion<arith::ExtSIOp>,
+    UnsignedCastConversion<arith::ExtUIOp>,
     ItoFCastOpConversion<arith::SIToFPOp>,
     ItoFCastOpConversion<arith::UIToFPOp>,
     FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 66dfa8fa3e157..551c3ba7a77ef 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -63,3 +63,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
   return %t: i1
 }
 
+// -----
+
+func.func @index_cast(%arg0: i32) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.index_cast'}}
+  %idx = arith.index_cast %arg0 : i32 to index
+  %int = arith.index_cast %idx : index to i32
+  // index_cast is not among the new cast patterns, so it stays illegal.
+  return %int : i32
+}
+
+// -----
+
+func.func @index_castui(%arg0: i32) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.index_castui'}}
+  %idx = arith.index_castui %arg0 : i32 to index
+  %int = arith.index_castui %idx : index to i32
+  // index_castui is not among the new cast patterns, so it stays illegal.
+  return %int : i32
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 79fecd61494d0..80665bacd2a5c 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -177,3 +177,42 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
 
   return
 }
+
+// -----
+
+func.func @trunci(%arg0: i32) -> i8 {
+  // CHECK-LABEL: trunci
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+  // CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8
+  // CHECK: emitc.cast %[[Trunc]] : ui8 to i8
+  %truncd = arith.trunci %arg0 : i32 to i8
+  // Narrowing goes through unsigned types: i32 -> ui32 -> ui8 -> i8.
+  return %truncd : i8
+}
+
+// -----
+
+func.func @extsi(%arg0: i32) {
+  // CHECK-LABEL: extsi
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: emitc.cast %[[Arg0]] : i32 to i64
+  // Sign extension casts the signless operand directly from i32 to i64.
+  %extd = arith.extsi %arg0 : i32 to i64
+
+  return
+}
+
+// -----
+
+func.func @extui(%arg0: i32) {
+  // CHECK-LABEL: extui
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+  // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64
+  // CHECK: emitc.cast %[[Conv1]] : ui64 to i64
+  // Zero extension widens through unsigned types: i32 -> ui32 -> ui64 -> i64.
+  %extd = arith.extui %arg0 : i32 to i64
+
+  return
+}

>From 1259c29b364d732ed5c526f504ec0bc84ec21760 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 13 May 2024 07:38:04 +0100
Subject: [PATCH 2/4] Address review comments (naming, comments, test names)

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp    | 17 +++++++----------
 .../arith-to-emitc-unsupported.mlir             |  4 ++--
 .../Conversion/ArithToEmitC/arith-to-emitc.mlir | 14 ++++++--------
 3 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 6216e6ea89b9b..60562d48726f5 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -112,7 +112,7 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
   }
 };
 
-template <typename ArithOp, bool needsUnsigned>
+template <typename ArithOp, bool castToUnsigned>
 class CastConversion : public OpConversionPattern<ArithOp> {
 public:
   using OpConversionPattern<ArithOp>::OpConversionPattern;
@@ -122,9 +122,8 @@ class CastConversion : public OpConversionPattern<ArithOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     Type opReturnType = this->getTypeConverter()->convertType(op.getType());
-    if (!isa_and_nonnull<IntegerType>(opReturnType)) {
+    if (!isa_and_nonnull<IntegerType>(opReturnType))
       return rewriter.notifyMatchFailure(op, "expected integer result type");
-    }
 
     if (adaptor.getOperands().size() != 1) {
       return rewriter.notifyMatchFailure(
@@ -132,16 +131,15 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     }
 
     Type operandType = adaptor.getIn().getType();
-    if (!isa_and_nonnull<IntegerType>(operandType)) {
+    if (!isa_and_nonnull<IntegerType>(operandType))
       return rewriter.notifyMatchFailure(op, "expected integer operand type");
-    }
 
     bool isTruncation = operandType.getIntOrFloatBitWidth() >
                         opReturnType.getIntOrFloatBitWidth();
-    bool doUnsigned = needsUnsigned || isTruncation;
+    bool doUnsigned = castToUnsigned || isTruncation;
 
     Type castType = opReturnType;
-    // For int conversions: if the op is a ui variant and the type wanted as
+    // If the op is a ui variant and the type wanted as
     // return type isn't unsigned, we need to issue an unsigned type to do
     // the conversion.
     if (castType.isUnsignedInteger() != doUnsigned) {
@@ -150,7 +148,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     }
 
     Value actualOp = adaptor.getIn();
-    // Fix the signedness of the operand if necessary
+    // Adapt the signedness of the operand if necessary
     if (operandType.isUnsignedInteger() != doUnsigned) {
       Type correctSignednessType =
           rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
@@ -162,8 +160,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
                                                           actualOp);
 
-    // Fix the signedness of what this operation returns (for integers,
-    // the arith ops want signless results)
+    // Cast to the expected output type
     if (castType != opReturnType) {
       result = rewriter.template create<emitc::CastOp>(op.getLoc(),
                                                        opReturnType, result);
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 551c3ba7a77ef..40a06fe9efe72 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -65,7 +65,7 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
 
 // -----
 
-func.func @index_cast(%arg0: i32) -> i32 {
+func.func @arith_index_cast(%arg0: i32) -> i32 {
   // expected-error @+1 {{failed to legalize operation 'arith.index_cast'}}
   %idx = arith.index_cast %arg0 : i32 to index
   %int = arith.index_cast %idx : index to i32
@@ -75,7 +75,7 @@ func.func @index_cast(%arg0: i32) -> i32 {
 
 // -----
 
-func.func @index_castui(%arg0: i32) -> i32 {
+func.func @arith_index_castui(%arg0: i32) -> i32 {
   // expected-error @+1 {{failed to legalize operation 'arith.index_castui'}}
   %idx = arith.index_castui %arg0 : i32 to index
   %int = arith.index_castui %idx : index to i32
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 80665bacd2a5c..274c12a1bae77 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -180,8 +180,8 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
 
 // -----
 
-func.func @trunci(%arg0: i32) -> i8 {
-  // CHECK-LABEL: trunci
+func.func @arith_trunci(%arg0: i32) -> i8 {
+  // CHECK-LABEL: arith_trunci
   // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
   // CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
   // CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8
@@ -193,11 +193,10 @@ func.func @trunci(%arg0: i32) -> i8 {
 
 // -----
 
-func.func @extsi(%arg0: i32) {
-  // CHECK-LABEL: extsi
+func.func @arith_extsi(%arg0: i32) {
+  // CHECK-LABEL: arith_extsi
   // CHECK-SAME: ([[Arg0:[^ ]*]]: i32)
   // CHECK: emitc.cast [[Arg0]] : i32 to i64
-
   %extd = arith.extsi %arg0 : i32 to i64
 
   return
@@ -205,13 +204,12 @@ func.func @extsi(%arg0: i32) {
 
 // -----
 
-func.func @extui(%arg0: i32) {
-  // CHECK-LABEL: extui
+func.func @arith_extui(%arg0: i32) {
+  // CHECK-LABEL: arith_extui
   // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
   // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
   // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64
   // CHECK: emitc.cast %[[Conv1]] : ui64 to i64
-
   %extd = arith.extui %arg0 : i32 to i64
 
   return

>From b413a9c6986cb80d7c66500f1f72797103ea8ccb Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 22 May 2024 13:58:16 +0100
Subject: [PATCH 3/4] Address review comments; handle i1 in integer casts

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 18 +++++++++++++
 .../arith-to-emitc-unsupported.mlir           | 20 +++-----------
 .../ArithToEmitC/arith-to-emitc.mlir          | 26 +++++++++++++++++++
 3 files changed, 48 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 60562d48726f5..496c197ce983f 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Tools/PDLL/AST/Types.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -134,6 +135,23 @@ class CastConversion : public OpConversionPattern<ArithOp> {
     if (!isa_and_nonnull<IntegerType>(operandType))
       return rewriter.notifyMatchFailure(op, "expected integer operand type");
 
+    // Signed (sign-extending) casts from i1 are not supported.
+    if(operandType.isInteger(1) && !castToUnsigned)
+      return rewriter.notifyMatchFailure(op, "operation not supported on i1 type");
+
+    // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
+    // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
+    // truncation.
+    if (opReturnType.isInteger(1)) {
+      auto constOne = rewriter.create<emitc::ConstantOp>(
+          op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
+      auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
+          op.getLoc(), operandType, adaptor.getIn(), constOne);
+      rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
+                                                 oneAndOperand);
+      return success();
+    }
+
     bool isTruncation = operandType.getIntOrFloatBitWidth() >
                         opReturnType.getIntOrFloatBitWidth();
     bool doUnsigned = castToUnsigned || isTruncation;
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 40a06fe9efe72..97e4593f97b90 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -65,20 +65,8 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
 
 // -----
 
-func.func @arith_index_cast(%arg0: i32) -> i32 {
-  // expected-error @+1 {{failed to legalize operation 'arith.index_cast'}}
-  %idx = arith.index_cast %arg0 : i32 to index
-  %int = arith.index_cast %idx : index to i32
-
-  return %int : i32
-}
-
-// -----
-
-func.func @arith_index_castui(%arg0: i32) -> i32 {
-  // expected-error @+1 {{failed to legalize operation 'arith.index_castui'}}
-  %idx = arith.index_castui %arg0 : i32 to index
-  %int = arith.index_castui %idx : index to i32
-
-  return %int : i32
+func.func @arith_extsi_i1_to_i32(%arg0: i1) {
+  // expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
+  %extd = arith.extsi %arg0 : i1 to i32
+  return
 }
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 274c12a1bae77..bedaee8c3be11 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -193,6 +193,20 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
 
 // -----
 
+func.func @arith_trunci_to_i1(%arg0: i32) -> i1 {
+  // CHECK-LABEL: arith_trunci_to_i1
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: %[[Const:.*]] = "emitc.constant"
+  // CHECK-SAME: value = 1
+  // CHECK: %[[And:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
+  // CHECK: emitc.cast %[[And]] : i32 to i1
+  %truncd = arith.trunci %arg0 : i32 to i1
+  // Truncation to i1 keeps only the low bit: (bool)(v & 1), not (bool)v.
+  return %truncd : i1
+}
+
+// -----
+
 func.func @arith_extsi(%arg0: i32) {
   // CHECK-LABEL: arith_extsi
   // CHECK-SAME: ([[Arg0:[^ ]*]]: i32)
@@ -214,3 +228,15 @@ func.func @arith_extui(%arg0: i32) {
 
   return
 }
+
+// -----
+// Zero extension of i1 goes through unsigned types: i1 -> ui1 -> ui32 -> i32.
+func.func @arith_extui_i1_to_i32(%arg0: i1) {
+  // CHECK-LABEL: arith_extui_i1_to_i32
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i1)
+  // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i1 to ui1
+  // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui1 to ui32
+  // CHECK: emitc.cast %[[Conv1]] : ui32 to i32
+  %idx = arith.extui %arg0 : i1 to i32
+  return
+}
\ No newline at end of file

>From ccb8b03f839493f147ed40699665e7539e28a387 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 22 May 2024 14:48:26 +0100
Subject: [PATCH 4/4] Fix formatting and add missing newline at end of file

---
 mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp     | 5 +++--
 mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 496c197ce983f..0be3d76f556de 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -136,8 +136,9 @@ class CastConversion : public OpConversionPattern<ArithOp> {
       return rewriter.notifyMatchFailure(op, "expected integer operand type");
 
     // Signed (sign-extending) casts from i1 are not supported.
-    if(operandType.isInteger(1) && !castToUnsigned)
-      return rewriter.notifyMatchFailure(op, "operation not supported on i1 type");
+    if (operandType.isInteger(1) && !castToUnsigned)
+      return rewriter.notifyMatchFailure(op,
+                                         "operation not supported on i1 type");
 
     // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
     // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index bedaee8c3be11..b453b69a214e8 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -239,4 +239,4 @@ func.func @arith_extui_i1_to_i32(%arg0: i1) {
   // CHECK: emitc.cast %[[Conv1]] : ui32 to i32
   %idx = arith.extui %arg0 : i1 to i32
   return
-}
\ No newline at end of file
+}



More information about the Mlir-commits mailing list