[Mlir-commits] [mlir] 7630379 - [mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 22 07:33:42 PDT 2024
Author: Corentin Ferry
Date: 2024-05-22T16:33:37+02:00
New Revision: 7630379156ec08c9d7b1ea3c03c09e7dc89ef4ee
URL: https://github.com/llvm/llvm-project/commit/7630379156ec08c9d7b1ea3c03c09e7dc89ef4ee
DIFF: https://github.com/llvm/llvm-project/commit/7630379156ec08c9d7b1ea3c03c09e7dc89ef4ee.diff
LOG: [mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui
This commit adds EmitC conversions for the arith dialect's integer-to-integer casts (arith.trunci, arith.extsi, arith.extui); casts involving index types are excluded for now.
Added:
Modified:
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 1447b182ccfdb..0be3d76f556de 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Tools/PDLL/AST/Types.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -112,6 +113,93 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
}
};
+/// Lowers an integer-to-integer arith cast (`ArithOp`) to `emitc.cast` ops.
+///
+/// The operand is first re-typed to the signedness the conversion is carried
+/// out in, then cast to the (re-typed) result width, and finally cast back to
+/// the converted result type if the intermediate type differs from it.
+/// Unsigned types are used when `castToUnsigned` is set or when the cast
+/// narrows the value; otherwise the original signedness is kept (signless
+/// types stay signless).
+///
+/// Conversions to i1 are special-cased: the operand is masked with 1 before
+/// the cast, so that the C conversion to bool truncates instead of testing
+/// against zero.
+template <typename ArithOp, bool castToUnsigned>
+class CastConversion : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Type resultType = this->getTypeConverter()->convertType(op.getType());
+    if (!isa_and_nonnull<IntegerType>(resultType))
+      return rewriter.notifyMatchFailure(op, "expected integer result type");
+
+    if (adaptor.getOperands().size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "CastConversion only supports unary ops");
+
+    Value input = adaptor.getIn();
+    Type inputType = input.getType();
+    if (!isa_and_nonnull<IntegerType>(inputType))
+      return rewriter.notifyMatchFailure(op, "expected integer operand type");
+
+    // Only the unsigned variants accept an i1 operand; sign-extending casts
+    // from i1 are rejected.
+    if (!castToUnsigned && inputType.isInteger(1))
+      return rewriter.notifyMatchFailure(op,
+                                         "operation not supported on i1 type");
+
+    // arith truncation to i1 keeps the lowest bit, whereas a C conversion to
+    // bool compares against zero. Masking with 1 first makes them agree.
+    if (resultType.isInteger(1)) {
+      Value one = rewriter.create<emitc::ConstantOp>(
+          loc, inputType, rewriter.getIntegerAttr(inputType, 1));
+      Value lowBit =
+          rewriter.create<emitc::BitwiseAndOp>(loc, inputType, input, one);
+      rewriter.replaceOpWithNewOp<emitc::CastOp>(op, resultType, lowBit);
+      return success();
+    }
+
+    // Converting to a narrower unsigned type is guaranteed to truncate in C,
+    // so narrowing casts are performed on unsigned types as well.
+    const bool narrows = inputType.getIntOrFloatBitWidth() >
+                         resultType.getIntOrFloatBitWidth();
+    const bool useUnsigned = narrows || castToUnsigned;
+
+    // Returns `type` with the signedness chosen above (same width). Types
+    // whose unsignedness already matches are returned unchanged.
+    auto withSignedness = [&](Type type) -> Type {
+      if (type.isUnsignedInteger() == useUnsigned)
+        return type;
+      return rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
+                                     /*isSigned=*/!useUnsigned);
+    };
+
+    Type intermediateType = withSignedness(resultType);
+    Type sourceType = withSignedness(inputType);
+
+    Value converted = input;
+    if (sourceType != inputType)
+      converted = rewriter.create<emitc::CastOp>(loc, sourceType, converted);
+
+    converted =
+        rewriter.create<emitc::CastOp>(loc, intermediateType, converted);
+
+    // Restore the type the rest of the IR expects.
+    if (intermediateType != resultType)
+      converted = rewriter.create<emitc::CastOp>(loc, resultType, converted);
+
+    rewriter.replaceOp(op, converted);
+    return success();
+  }
+};
+
+/// CastConversion variant that always performs the C conversion on unsigned
+/// types (registered for arith.trunci and arith.extui).
+template <typename ArithOp>
+class UnsignedCastConversion : public CastConversion<ArithOp, true> {
+public:
+  using CastConversion<ArithOp, true>::CastConversion;
+};
+
+/// CastConversion variant that keeps the operand's signedness unless the cast
+/// narrows (registered for arith.extsi); i1 operands are rejected.
+template <typename ArithOp>
+class SignedCastConversion : public CastConversion<ArithOp, false> {
+public:
+  using CastConversion<ArithOp, false>::CastConversion;
+};
+
template <typename ArithOp, typename EmitCOp>
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
public:
@@ -313,6 +401,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
CmpIOpConversion,
SelectOpConversion,
+ // Truncation is guaranteed for unsigned types.
+ UnsignedCastConversion<arith::TruncIOp>,
+ SignedCastConversion<arith::ExtSIOp>,
+ UnsignedCastConversion<arith::ExtUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 66dfa8fa3e157..97e4593f97b90 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -63,3 +63,10 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
return %t: i1
}
+// -----
+
+// The signed cast pattern rejects i1 operands, so legalization must fail.
+func.func @arith_extsi_i1_to_i32(%arg0: i1) {
+  // expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
+  %ext = arith.extsi %arg0 : i1 to i32
+  return
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 79fecd61494d0..b453b69a214e8 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -177,3 +177,66 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
return
}
+
+// -----
+
+func.func @arith_trunci(%arg0: i32) -> i8 {
+  // Truncation is carried out on unsigned types: i32 -> ui32 -> ui8 -> i8.
+  // CHECK-LABEL: arith_trunci
+  // CHECK-SAME: (%[[ARG:[^ ]*]]: i32)
+  // CHECK: %[[AS_UNSIGNED:.*]] = emitc.cast %[[ARG]] : i32 to ui32
+  // CHECK: %[[NARROWED:.*]] = emitc.cast %[[AS_UNSIGNED]] : ui32 to ui8
+  // CHECK: emitc.cast %[[NARROWED]] : ui8 to i8
+  %narrowed = arith.trunci %arg0 : i32 to i8
+
+  return %narrowed : i8
+}
+
+// -----
+
+func.func @arith_trunci_to_i1(%arg0: i32) -> i1 {
+  // Truncating to i1 keeps only the low bit: the operand is masked with 1
+  // before being cast to i1.
+  // CHECK-LABEL: arith_trunci_to_i1
+  // CHECK-SAME: (%[[ARG:[^ ]*]]: i32)
+  // CHECK: %[[ONE:.*]] = "emitc.constant"
+  // CHECK-SAME: value = 1
+  // CHECK: %[[LOW_BIT:.*]] = emitc.bitwise_and %[[ARG]], %[[ONE]] : (i32, i32) -> i32
+  // CHECK: emitc.cast %[[LOW_BIT]] : i32 to i1
+  %low = arith.trunci %arg0 : i32 to i1
+
+  return %low : i1
+}
+
+// -----
+
+func.func @arith_extsi(%arg0: i32) {
+  // Sign extension of signless integers lowers to exactly one C cast; the
+  // CHECK-NOT guards against the extra signedness casts of the unsigned path.
+  // CHECK-LABEL: arith_extsi
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: emitc.cast %[[Arg0]] : i32 to i64
+  // CHECK-NOT: emitc.cast
+  // CHECK: return
+  %extd = arith.extsi %arg0 : i32 to i64
+
+  return
+}
+
+// -----
+
+func.func @arith_extui(%arg0: i32) {
+  // Zero extension is carried out on unsigned types: i32 -> ui32 -> ui64 -> i64.
+  // CHECK-LABEL: arith_extui
+  // CHECK-SAME: (%[[ARG:[^ ]*]]: i32)
+  // CHECK: %[[AS_UNSIGNED:.*]] = emitc.cast %[[ARG]] : i32 to ui32
+  // CHECK: %[[WIDENED:.*]] = emitc.cast %[[AS_UNSIGNED]] : ui32 to ui64
+  // CHECK: emitc.cast %[[WIDENED]] : ui64 to i64
+  %ext = arith.extui %arg0 : i32 to i64
+
+  return
+}
+
+// -----
+
+func.func @arith_extui_i1_to_i32(%arg0: i1) {
+  // Unlike sign extension, zero-extending i1 is supported and also goes
+  // through unsigned types: i1 -> ui1 -> ui32 -> i32.
+  // CHECK-LABEL: arith_extui_i1_to_i32
+  // CHECK-SAME: (%[[ARG:[^ ]*]]: i1)
+  // CHECK: %[[AS_UNSIGNED:.*]] = emitc.cast %[[ARG]] : i1 to ui1
+  // CHECK: %[[WIDENED:.*]] = emitc.cast %[[AS_UNSIGNED]] : ui1 to ui32
+  // CHECK: emitc.cast %[[WIDENED]] : ui32 to i32
+  %ext = arith.extui %arg0 : i1 to i32
+  return
+}
More information about the Mlir-commits
mailing list