[Mlir-commits] [mlir] [mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui (PR #91491)
Corentin Ferry
llvmlistbot at llvm.org
Sun May 12 23:58:53 PDT 2024
https://github.com/cferry-AMD updated https://github.com/llvm/llvm-project/pull/91491
>From 32ab952d6f53f16132a423396caa1c118440d8c1 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 8 May 2024 14:02:04 +0200
Subject: [PATCH 1/2] Add EmitC lowering for arith.{trunci,extsi,extui}
---
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 76 +++++++++++++++++++
.../arith-to-emitc-unsupported.mlir | 19 +++++
.../ArithToEmitC/arith-to-emitc.mlir | 39 ++++++++++
3 files changed, 134 insertions(+)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 1447b182ccfdb..6216e6ea89b9b 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -112,6 +112,78 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
}
};
+template <typename ArithOp, bool needsUnsigned>
+class CastConversion : public OpConversionPattern<ArithOp> {
+public:
+ using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type opReturnType = this->getTypeConverter()->convertType(op.getType());
+ if (!isa_and_nonnull<IntegerType>(opReturnType)) {
+ return rewriter.notifyMatchFailure(op, "expected integer result type");
+ }
+
+ if (adaptor.getOperands().size() != 1) {
+ return rewriter.notifyMatchFailure(
+ op, "CastConversion only supports unary ops");
+ }
+
+ Type operandType = adaptor.getIn().getType();
+ if (!isa_and_nonnull<IntegerType>(operandType)) {
+ return rewriter.notifyMatchFailure(op, "expected integer operand type");
+ }
+
+ bool isTruncation = operandType.getIntOrFloatBitWidth() >
+ opReturnType.getIntOrFloatBitWidth();
+ bool doUnsigned = needsUnsigned || isTruncation;
+
+ Type castType = opReturnType;
+ // For int conversions: if the op is a ui variant and the type wanted as
+ // return type isn't unsigned, we need to issue an unsigned type to do
+ // the conversion.
+ if (castType.isUnsignedInteger() != doUnsigned) {
+ castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
+ /*isSigned=*/!doUnsigned);
+ }
+
+ Value actualOp = adaptor.getIn();
+ // Fix the signedness of the operand if necessary
+ if (operandType.isUnsignedInteger() != doUnsigned) {
+ Type correctSignednessType =
+ rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
+ /*isSigned=*/!doUnsigned);
+ actualOp = rewriter.template create<emitc::CastOp>(
+ op.getLoc(), correctSignednessType, actualOp);
+ }
+
+ auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
+ actualOp);
+
+ // Fix the signedness of what this operation returns (for integers,
+ // the arith ops want signless results)
+ if (castType != opReturnType) {
+ result = rewriter.template create<emitc::CastOp>(op.getLoc(),
+ opReturnType, result);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+template <typename ArithOp>
+class UnsignedCastConversion : public CastConversion<ArithOp, true> {
+ using CastConversion<ArithOp, true>::CastConversion;
+};
+
+template <typename ArithOp>
+class SignedCastConversion : public CastConversion<ArithOp, false> {
+ using CastConversion<ArithOp, false>::CastConversion;
+};
+
template <typename ArithOp, typename EmitCOp>
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
public:
@@ -313,6 +385,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
CmpIOpConversion,
SelectOpConversion,
+ // Truncation is guaranteed for unsigned types.
+ UnsignedCastConversion<arith::TruncIOp>,
+ SignedCastConversion<arith::ExtSIOp>,
+ UnsignedCastConversion<arith::ExtUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 66dfa8fa3e157..551c3ba7a77ef 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -63,3 +63,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
return %t: i1
}
+// -----
+
+func.func @index_cast(%arg0: i32) -> i32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.index_cast'}}
+ %idx = arith.index_cast %arg0 : i32 to index
+ %int = arith.index_cast %idx : index to i32
+
+ return %int : i32
+}
+
+// -----
+
+func.func @index_castui(%arg0: i32) -> i32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.index_castui'}}
+ %idx = arith.index_castui %arg0 : i32 to index
+ %int = arith.index_castui %idx : index to i32
+
+ return %int : i32
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 79fecd61494d0..80665bacd2a5c 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -177,3 +177,42 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
return
}
+
+// -----
+
+func.func @trunci(%arg0: i32) -> i8 {
+ // CHECK-LABEL: trunci
+ // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+ // CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+ // CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8
+ // CHECK: emitc.cast %[[Trunc]] : ui8 to i8
+ %truncd = arith.trunci %arg0 : i32 to i8
+
+ return %truncd : i8
+}
+
+// -----
+
+func.func @extsi(%arg0: i32) {
+ // CHECK-LABEL: extsi
+ // CHECK-SAME: ([[Arg0:[^ ]*]]: i32)
+ // CHECK: emitc.cast [[Arg0]] : i32 to i64
+
+ %extd = arith.extsi %arg0 : i32 to i64
+
+ return
+}
+
+// -----
+
+func.func @extui(%arg0: i32) {
+ // CHECK-LABEL: extui
+ // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+ // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+ // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64
+ // CHECK: emitc.cast %[[Conv1]] : ui64 to i64
+
+ %extd = arith.extui %arg0 : i32 to i64
+
+ return
+}
>From 1259c29b364d732ed5c526f504ec0bc84ec21760 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 13 May 2024 07:38:04 +0100
Subject: [PATCH 2/2] Review comments
---
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 29 ++++++++++++-------
.../arith-to-emitc-unsupported.mlir | 20 +++++++++++--
.../ArithToEmitC/arith-to-emitc.mlir | 14 ++++-----
3 files changed, 43 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 6216e6ea89b9b..60562d48726f5 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -112,7 +112,7 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
}
};
-template <typename ArithOp, bool needsUnsigned>
+template <typename ArithOp, bool castToUnsigned>
class CastConversion : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;
@@ -122,9 +122,8 @@ class CastConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {
Type opReturnType = this->getTypeConverter()->convertType(op.getType());
- if (!isa_and_nonnull<IntegerType>(opReturnType)) {
+ if (!isa_and_nonnull<IntegerType>(opReturnType))
return rewriter.notifyMatchFailure(op, "expected integer result type");
- }
if (adaptor.getOperands().size() != 1) {
return rewriter.notifyMatchFailure(
@@ -132,16 +131,27 @@ class CastConversion : public OpConversionPattern<ArithOp> {
}
Type operandType = adaptor.getIn().getType();
- if (!isa_and_nonnull<IntegerType>(operandType)) {
+ if (!isa_and_nonnull<IntegerType>(operandType))
return rewriter.notifyMatchFailure(op, "expected integer operand type");
- }
+ // Sign-extending an i1 is not supported: i1 is emitted as a C `bool`, and
+ // converting a `bool` to a signed type yields 0 or 1, not the 0 or -1 that
+ // arith.extsi requires.
+ if (operandType.isInteger(1) && !castToUnsigned)
+ return rewriter.notifyMatchFailure(op, "unsupported i1 sign extension");
+
+ // Truncating to i1 is not supported: converting an integer to a C `bool`
+ // tests for a non-zero value instead of keeping only the least significant
+ // bit, as arith.trunci requires.
+ if (opReturnType.isInteger(1))
+ return rewriter.notifyMatchFailure(op, "unsupported truncation to i1");
+
bool isTruncation = operandType.getIntOrFloatBitWidth() >
opReturnType.getIntOrFloatBitWidth();
- bool doUnsigned = needsUnsigned || isTruncation;
+ bool doUnsigned = castToUnsigned || isTruncation;
Type castType = opReturnType;
- // For int conversions: if the op is a ui variant and the type wanted as
+ // If the op is a ui variant and the type wanted as
// return type isn't unsigned, we need to issue an unsigned type to do
// the conversion.
if (castType.isUnsignedInteger() != doUnsigned) {
@@ -150,7 +160,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
}
Value actualOp = adaptor.getIn();
- // Fix the signedness of the operand if necessary
+ // Adapt the signedness of the operand if necessary
if (operandType.isUnsignedInteger() != doUnsigned) {
Type correctSignednessType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
@@ -162,8 +172,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
actualOp);
- // Fix the signedness of what this operation returns (for integers,
- // the arith ops want signless results)
+ // Cast to the expected output type
if (castType != opReturnType) {
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
opReturnType, result);
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 551c3ba7a77ef..40a06fe9efe72 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -65,7 +65,7 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
// -----
-func.func @index_cast(%arg0: i32) -> i32 {
+func.func @arith_index_cast(%arg0: i32) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.index_cast'}}
%idx = arith.index_cast %arg0 : i32 to index
%int = arith.index_cast %idx : index to i32
@@ -75,10 +75,26 @@ func.func @index_cast(%arg0: i32) -> i32 {
// -----
-func.func @index_castui(%arg0: i32) -> i32 {
+func.func @arith_index_castui(%arg0: i32) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.index_castui'}}
%idx = arith.index_castui %arg0 : i32 to index
%int = arith.index_castui %idx : index to i32
return %int : i32
}
+
+// -----
+
+func.func @arith_extsi_i1_to_i32(%arg0: i1) {
+ // expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
+ %extd = arith.extsi %arg0 : i1 to i32
+ return
+}
+
+// -----
+
+func.func @arith_trunci_to_i1(%arg0: i32) -> i1 {
+ // expected-error @+1 {{failed to legalize operation 'arith.trunci'}}
+ %truncd = arith.trunci %arg0 : i32 to i1
+ return %truncd : i1
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 80665bacd2a5c..274c12a1bae77 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -180,8 +180,8 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
// -----
-func.func @trunci(%arg0: i32) -> i8 {
- // CHECK-LABEL: trunci
+func.func @arith_trunci(%arg0: i32) -> i8 {
+ // CHECK-LABEL: arith_trunci
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
// CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
// CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8
@@ -193,11 +193,10 @@ func.func @trunci(%arg0: i32) -> i8 {
// -----
-func.func @extsi(%arg0: i32) {
- // CHECK-LABEL: extsi
+func.func @arith_extsi(%arg0: i32) {
+ // CHECK-LABEL: arith_extsi
// CHECK-SAME: ([[Arg0:[^ ]*]]: i32)
// CHECK: emitc.cast [[Arg0]] : i32 to i64
-
%extd = arith.extsi %arg0 : i32 to i64
return
@@ -205,13 +204,12 @@ func.func @extsi(%arg0: i32) {
// -----
-func.func @extui(%arg0: i32) {
- // CHECK-LABEL: extui
+func.func @arith_extui(%arg0: i32) {
+ // CHECK-LABEL: arith_extui
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64
// CHECK: emitc.cast %[[Conv1]] : ui64 to i64
-
%extd = arith.extui %arg0 : i32 to i64
return
More information about the Mlir-commits mailing list