[Mlir-commits] [mlir] [mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui (PR #91491)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 8 08:49:45 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Corentin Ferry (cferry-AMD)

<details>
<summary>Changes</summary>

These operations can be lowered to EmitC casts, provided their sign-extension, zero-extension and truncation semantics are preserved.

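Per [C++ Reference](https://en.cppreference.com/w/cpp/language/implicit_conversion#Numeric_conversions): when converting to a narrower integer type, the result is guaranteed to be the source value modulo 2^N if the destination type is unsigned. For a signed destination type, this is only guaranteed since C++20; before that, out-of-range results are implementation-defined. This implementation performs the truncating cast of `trunci` on unsigned types, so the truncation itself does not rely on C++20.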

This implementation is slightly more generic than these three operations require, so that it can later accommodate `index_cast` and `index_castui` (dedicated `emitc.size_t` and `emitc.ssize_t` types are being discussed).

---
Full diff: https://github.com/llvm/llvm-project/pull/91491.diff


3 Files Affected:

- (modified) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+76) 
- (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir (+19) 
- (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+39) 


``````````diff
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 1447b182ccfdb..6216e6ea89b9b 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -112,6 +112,118 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
   }
 };
 
+/// Lowers an integer-to-integer arith cast (trunci, extsi, extui) to
+/// `emitc.cast` ops.
+///
+/// The bitwidth-changing cast is performed on types whose signedness makes
+/// the C/C++ integral conversion rules match the arith semantics: if
+/// `needsUnsigned` is set or the cast narrows, operand and result are viewed
+/// as unsigned (narrowing then wraps modulo 2^N and widening zero-extends);
+/// otherwise they are viewed as signed (signless types are left unchanged).
+/// Bitwidth-preserving casts move between the arith types (signless) and that
+/// view.
+///
+/// i1 is emitted as C/C++ `bool`, which needs special care:
+///  - converting an integer to `bool` yields `v != 0`, not the low bit, so
+///    truncation to i1 is lowered as `(bool)(v & 1)`;
+///  - converting `bool` to a wider integer yields 0 or 1, never -1, so
+///    sign-extension from i1 is rejected.
+template <typename ArithOp, bool needsUnsigned>
+class CastConversion : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type opReturnType = this->getTypeConverter()->convertType(op.getType());
+    if (!isa_and_nonnull<IntegerType>(opReturnType)) {
+      return rewriter.notifyMatchFailure(op, "expected integer result type");
+    }
+
+    if (adaptor.getOperands().size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "CastConversion only supports unary ops");
+    }
+
+    Type operandType = adaptor.getIn().getType();
+    if (!isa_and_nonnull<IntegerType>(operandType)) {
+      return rewriter.notifyMatchFailure(op, "expected integer operand type");
+    }
+
+    // A `bool` (i1) operand converts to 0 or 1, so it cannot implement
+    // arith.extsi, which must turn `true` into -1.
+    if (operandType.isInteger(1) && !needsUnsigned) {
+      return rewriter.notifyMatchFailure(
+          op, "sign-extension from i1 is not supported");
+    }
+
+    // Converting to `bool` (i1) computes (v != 0), whereas arith truncation
+    // keeps only the least significant bit: emit (bool)(v & 1) instead.
+    if (opReturnType.isInteger(1)) {
+      auto constOne = rewriter.template create<emitc::ConstantOp>(
+          op.getLoc(), operandType, rewriter.getOneAttr(operandType));
+      auto lowBit = rewriter.template create<emitc::BitwiseAndOp>(
+          op.getLoc(), operandType, adaptor.getIn(), constOne);
+      rewriter.template replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
+                                                          lowBit);
+      return success();
+    }
+
+    bool isTruncation = operandType.getIntOrFloatBitWidth() >
+                        opReturnType.getIntOrFloatBitWidth();
+    bool doUnsigned = needsUnsigned || isTruncation;
+
+    // Returns a type of the same bitwidth as `type` that is unsigned if
+    // `doUnsigned` and signed otherwise; `type` itself (possibly signless) is
+    // returned when it already has the requested signedness.
+    auto adaptSignedness = [&](Type type) -> Type {
+      if (type.isUnsignedInteger() == doUnsigned)
+        return type;
+      return rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
+                                     /*isSigned=*/!doUnsigned);
+    };
+    Type castSrcType = adaptSignedness(operandType);
+    Type castDestType = adaptSignedness(opReturnType);
+
+    // Fix the signedness of the operand (bitwidth-preserving).
+    Value actualOp = adaptor.getIn();
+    if (castSrcType != operandType) {
+      actualOp = rewriter.template create<emitc::CastOp>(
+          op.getLoc(), castSrcType, actualOp);
+    }
+
+    // The actual conversion, which may change the bitwidth.
+    Value result = rewriter.template create<emitc::CastOp>(
+        op.getLoc(), castDestType, actualOp);
+
+    // Go back to the type arith expects (signless for integers).
+    // NOTE(review): unsigned -> signed conversions of values that do not fit
+    // are implementation-defined before C++20 (modulo 2^N since C++20).
+    if (castDestType != opReturnType) {
+      result = rewriter.template create<emitc::CastOp>(op.getLoc(),
+                                                       opReturnType, result);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Cast lowering with unsigned (zero-extending) semantics.
+template <typename ArithOp>
+class UnsignedCastConversion : public CastConversion<ArithOp, true> {
+  using CastConversion<ArithOp, true>::CastConversion;
+};
+
+/// Cast lowering with signed (sign-extending) semantics; narrowing casts
+/// still go through unsigned types.
+template <typename ArithOp>
+class SignedCastConversion : public CastConversion<ArithOp, false> {
+  using CastConversion<ArithOp, false>::CastConversion;
+};
+
 template <typename ArithOp, typename EmitCOp>
 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
 public:
@@ -313,6 +385,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
     CmpIOpConversion,
     SelectOpConversion,
+    // Truncation is guaranteed for unsigned types.
+    UnsignedCastConversion<arith::TruncIOp>,
+    SignedCastConversion<arith::ExtSIOp>,
+    UnsignedCastConversion<arith::ExtUIOp>,
     ItoFCastOpConversion<arith::SIToFPOp>,
     ItoFCastOpConversion<arith::UIToFPOp>,
     FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 66dfa8fa3e157..551c3ba7a77ef 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -63,3 +63,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
   return %t: i1
 }
 
+// -----
+
+func.func @index_cast(%arg0: i32) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.index_cast'}}
+  %as_index = arith.index_cast %arg0 : i32 to index
+  %as_i32 = arith.index_cast %as_index : index to i32
+  // The expected error is attached to the first cast only.
+  return %as_i32 : i32
+}
+
+// -----
+
+func.func @index_castui(%arg0: i32) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.index_castui'}}
+  %as_index = arith.index_castui %arg0 : i32 to index
+  %as_i32 = arith.index_castui %as_index : index to i32
+  // The expected error is attached to the first cast only.
+  return %as_i32 : i32
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 79fecd61494d0..80665bacd2a5c 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -177,3 +177,42 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
 
   return
 }
+
+// -----
+
+func.func @trunci(%arg0: i32) -> i8 {
+  // CHECK-LABEL: trunci
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: %[[AsUnsigned:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+  // CHECK: %[[Narrowed:.*]] = emitc.cast %[[AsUnsigned]] : ui32 to ui8
+  // CHECK: emitc.cast %[[Narrowed]] : ui8 to i8
+  %narrowed = arith.trunci %arg0 : i32 to i8
+  // Truncation happens on unsigned types, then the value is cast back.
+  return %narrowed : i8
+}
+
+// -----
+
+func.func @extsi(%arg0: i32) {
+  // CHECK-LABEL: extsi
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: emitc.cast %[[Arg0]] : i32 to i64
+  // Sign-extension needs no signedness adjustment: a single cast suffices.
+  %extd = arith.extsi %arg0 : i32 to i64
+
+  return
+}
+
+// -----
+
+func.func @extui(%arg0: i32) {
+  // CHECK-LABEL: extui
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: %[[AsUnsigned:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+  // CHECK: %[[Widened:.*]] = emitc.cast %[[AsUnsigned]] : ui32 to ui64
+  // CHECK: emitc.cast %[[Widened]] : ui64 to i64
+  // Zero-extension widens on unsigned types, then the value is cast back.
+  %widened = arith.extui %arg0 : i32 to i64
+
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/91491


More information about the Mlir-commits mailing list