[Mlir-commits] [mlir] [mlir][emitc] Arith to EmitC: handle floating-point<->integer conversions (PR #87614)
Corentin Ferry
llvmlistbot at llvm.org
Mon Apr 15 06:49:50 PDT 2024
https://github.com/cferry-AMD updated https://github.com/llvm/llvm-project/pull/87614
>From 133a8ba87f3e3fa25e60542714334df935dba585 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 27 Mar 2024 08:58:15 +0000
Subject: [PATCH 1/2] [mlir][emitc] Arith to EmitC: handle FP<->Integer
conversions
---
.../Conversion/ArithToEmitC/ArithToEmitC.h | 3 +-
mlir/include/mlir/Conversion/Passes.td | 17 ++++
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 85 ++++++++++++++++++-
.../ArithToEmitC/ArithToEmitCPass.cpp | 4 +-
.../arith-to-emitc-cast-truncate.mlir | 20 +++++
.../arith-to-emitc-cast-unsupported.mlir | 48 +++++++++++
.../arith-to-emitc-unsupported.mlir | 7 ++
.../ArithToEmitC/arith-to-emitc.mlir | 15 ++++
8 files changed, 194 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir
create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
index 9cb43689d1ce64..32d039e9c89185 100644
--- a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -14,7 +14,8 @@ class RewritePatternSet;
class TypeConverter;
void populateArithToEmitCPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns,
+ bool optionFloatToIntTruncates); // If false, fptosi/fptoui are not lowered.
} // namespace mlir
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab95..029cbd7aec2819 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -139,7 +139,24 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
let summary = "Convert Arith dialect to EmitC dialect";
+ let description = [{
+ This pass converts `arith` dialect operations to `emitc`.
+
+ The semantics of floating-point to integer conversions `arith.fptosi`,
+ `arith.fptoui` require rounding towards zero. The C and C++ standards
+ specify that float-to-integer casts truncate towards zero, but this
+ guarantee may not hold with non-conforming compilers or compiler options.
+
+ If casts can be guaranteed to use round-to-zero, use the
+ `float-to-int-truncates` flag to allow conversion of `arith.fptosi` and
+ `arith.fptoui` operations; otherwise they are not converted.
+ }];
let dependentDialects = ["emitc::EmitCDialect"];
+ let options = [
+ Option<"floatToIntTruncates", "float-to-int-truncates", "bool",
+ /*default=*/"false",
+ "Whether the behavior of float-to-int cast in emitc is truncation">,
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index db493c1294ba2d..311978ea6c40e0 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -128,6 +128,78 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
}
};
+// Floating-point to integer conversions (arith.fptosi, arith.fptoui).
+template <typename CastOp>
+class FtoICastOpConversion : public OpConversionPattern<CastOp> {
+private:
+ bool floatToIntTruncates; // Lower only if emitted casts are known to truncate.
+
+public:
+ FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context,
+ bool optionFloatToIntTruncates)
+ : OpConversionPattern<CastOp>(typeConverter, context),
+ floatToIntTruncates(optionFloatToIntTruncates) {}
+
+ LogicalResult
+ matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only source types accepted by emitc::isSupportedFloatType are lowered.
+ Type operandType = adaptor.getIn().getType();
+ if (!emitc::isSupportedFloatType(operandType))
+ return rewriter.notifyMatchFailure(castOp,
+ "unsupported cast source type");
+ // Bail out unless the caller guarantees that emitted casts truncate.
+ if (!floatToIntTruncates)
+ return rewriter.notifyMatchFailure(
+ castOp, "conversion currently requires EmitC casts to use truncation "
+ "as rounding mode");
+
+ Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+ if (!dstType)
+ return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+ if (!emitc::isSupportedIntegerType(dstType))
+ return rewriter.notifyMatchFailure(castOp,
+ "unsupported cast destination type");
+
+ rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
+ adaptor.getOperands());
+
+ return success();
+ }
+};
+
+// Integer to floating-point conversions (arith.sitofp, arith.uitofp).
+template <typename CastOp>
+class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
+public:
+ ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
+ : OpConversionPattern<CastOp>(typeConverter, context) {}
+
+ LogicalResult
+ matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Type operandType = adaptor.getIn().getType();
+ if (!emitc::isSupportedIntegerType(operandType))
+ return rewriter.notifyMatchFailure(castOp,
+ "unsupported cast source type");
+ // The converted result type must be a float type EmitC supports.
+ Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+ if (!dstType)
+ return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+ if (!emitc::isSupportedFloatType(dstType))
+ return rewriter.notifyMatchFailure(castOp,
+ "unsupported cast destination type");
+
+ rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
+ adaptor.getOperands());
+
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -135,7 +207,8 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
//===----------------------------------------------------------------------===//
void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns,
+ bool optionFloatToIntTruncates) {
MLIRContext *ctx = patterns.getContext();
// clang-format off
@@ -148,7 +221,13 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
- SelectOpConversion
- >(typeConverter, ctx);
+ SelectOpConversion,
+ ItoFCastOpConversion<arith::SIToFPOp>,
+ ItoFCastOpConversion<arith::UIToFPOp>
+ >(typeConverter, ctx)
+ .add< // Float-to-int patterns additionally take the truncation flag.
+ FtoICastOpConversion<arith::FPToSIOp>,
+ FtoICastOpConversion<arith::FPToUIOp>
+ >(typeConverter, ctx, optionFloatToIntTruncates);
// clang-format on
}
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
index 45a088ed144f17..546bbfe2082eff 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -29,6 +29,8 @@ using namespace mlir;
namespace {
struct ConvertArithToEmitC
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
+ using Base::Base; // Inherit the base constructors (e.g. the one taking pass options).
+
void runOnOperation() override;
};
} // namespace
@@ -44,7 +46,7 @@ void ConvertArithToEmitC::runOnOperation() {
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
- populateArithToEmitCPatterns(typeConverter, patterns);
+ populateArithToEmitCPatterns(typeConverter, patterns, floatToIntTruncates);
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
new file mode 100644
index 00000000000000..f45b6306b0292b
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt -split-input-file --pass-pipeline="builtin.module(convert-arith-to-emitc{float-to-int-truncates})" %s | FileCheck %s
+// float-to-int-truncates is set, so fptosi/fptoui are expected to be lowered.
+func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
+ // CHECK: emitc.cast %arg0 : f32 to i32
+ %0 = arith.fptosi %arg0 : f32 to i32
+
+ // CHECK: emitc.cast %arg1 : f64 to i32
+ %1 = arith.fptosi %arg1 : f64 to i32
+
+ // CHECK: emitc.cast %arg0 : f32 to i16
+ %2 = arith.fptosi %arg0 : f32 to i16
+
+ // CHECK: emitc.cast %arg1 : f64 to i16
+ %3 = arith.fptosi %arg1 : f64 to i16
+
+ // CHECK: emitc.cast %arg0 : f32 to i32
+ %4 = arith.fptoui %arg0 : f32 to i32
+
+ return
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir
new file mode 100644
index 00000000000000..34fc9f3dffc0c8
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -split-input-file --pass-pipeline="builtin.module(convert-arith-to-emitc{float-to-int-truncates})" -verify-diagnostics %s
+// Even with float-to-int-truncates, casts involving these types are not lowered.
+func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
+ return %t: tensor<5xi32>
+}
+
+// -----
+
+func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
+ return %t: vector<5xi32>
+}
+
+// -----
+
+func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : bf16 to i32
+ return %t: i32
+}
+
+// -----
+
+func.func @arith_cast_f16(%arg0: f16) -> i32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : f16 to i32
+ return %t: i32
+}
+
+
+// -----
+
+func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
+ // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
+ %t = arith.sitofp %arg0 : i32 to bf16
+ return %t: bf16
+}
+
+// -----
+
+func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
+ // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
+ %t = arith.sitofp %arg0 : i32 to f16
+ return %t: f16
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
new file mode 100644
index 00000000000000..bbec664100564b
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
+// Without float-to-int-truncates, fptosi is left unconverted.
+func.func @arith_cast_f32(%arg0: f32) -> i32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : f32 to i32
+ return %t: i32
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 76ba518577ab8e..406aa254ecfee1 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -93,3 +93,18 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
return
}
+
+// -----
+// Int-to-float casts do not depend on the float-to-int-truncates option.
+func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
+ // CHECK: emitc.cast %arg0 : i8 to f32
+ %0 = arith.sitofp %arg0 : i8 to f32
+
+ // CHECK: emitc.cast %arg1 : i64 to f32
+ %1 = arith.sitofp %arg1 : i64 to f32
+
+ // CHECK: emitc.cast %arg0 : i8 to f32
+ %2 = arith.uitofp %arg0 : i8 to f32
+
+ return
+}
>From db3765b7252ef606f73e1ccac52ed101f4961741 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 15 Apr 2024 09:28:03 +0200
Subject: [PATCH 2/2] Merge pull request #160 from Xilinx/corentin.fix_itofp
[FXML-4281] Fix signedness behavior of unsigned integer <-> floating-point conversions
---
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 44 ++++++++++++++++---
.../arith-to-emitc-cast-truncate.mlir | 7 ++-
.../arith-to-emitc-cast-unsupported.mlir | 16 +++++++
.../ArithToEmitC/arith-to-emitc.mlir | 3 +-
4 files changed, 63 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 311978ea6c40e0..dee110dbd79323 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -162,8 +162,30 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");
- rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
- adaptor.getOperands());
+ // Float-to-i1 casts are not supported: C++ converts any nonzero value
+ // (e.g. 0.5) to `true`, whereas fptosi/fptoui truncate it to 0.
+ if (dstType.isInteger(1))
+ return rewriter.notifyMatchFailure(castOp,
+ "unsupported cast destination type");
+
+ // For the "ui" variant, first cast to an unsigned integer of the
+ // destination's bit width. Signless is interpreted as signed, so no
+ // intermediate cast is needed for "si".
+ Type actualResultType = dstType;
+ if (isa<arith::FPToUIOp>(castOp)) {
+ actualResultType =
+ rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
+ /*isSigned=*/false);
+ }
+
+ Value result = rewriter.create<emitc::CastOp>(
+ castOp.getLoc(), actualResultType, adaptor.getOperands());
+
+ // Cast back to the signless destination type only when it differs.
+ if (actualResultType != dstType) {
+ result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
+ }
+ rewriter.replaceOp(castOp, result);
return success();
}
@@ -179,7 +201,7 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
LogicalResult
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
+ // Vectors in particular are not supported
Type operandType = adaptor.getIn().getType();
if (!emitc::isSupportedIntegerType(operandType))
return rewriter.notifyMatchFailure(castOp,
@@ -193,8 +215,20 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");
- rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
- adaptor.getOperands());
+ // Convert to unsigned if it's the "ui" variant
+ // Signless is interpreted as signed, so no need to cast for "si"
+ Type actualOperandType = operandType;
+ if (isa<arith::UIToFPOp>(castOp)) {
+ actualOperandType =
+ rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
+ /*isSigned=*/false);
+ }
+ Value fpCastOperand = adaptor.getIn();
+ if (actualOperandType != operandType) {
+ fpCastOperand = rewriter.template create<emitc::CastOp>(
+ castOp.getLoc(), actualOperandType, fpCastOperand);
+ }
+ rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
return success();
}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
index f45b6306b0292b..26f9261183144e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
@@ -13,7 +13,12 @@ func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
// CHECK: emitc.cast %arg1 : f64 to i16
%3 = arith.fptosi %arg1 : f64 to i16
- // CHECK: emitc.cast %arg0 : f32 to i32
+ // CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32
+ // CHECK: emitc.cast %[[CAST0]] : ui32 to i32
%4 = arith.fptoui %arg0 : f32 to i32
+
+ // CHECK: %[[CAST1:.*]] = emitc.cast %arg1 : f64 to ui16
+ // CHECK: emitc.cast %[[CAST1]] : ui16 to i16
+ %5 = arith.fptoui %arg1 : f64 to i16
return
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir
@@ -46,3 +46,19 @@ func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
%t = arith.sitofp %arg0 : i32 to f16
return %t: f16
}
+
+// -----
+
+func.func @arith_cast_f32_to_i1(%arg0: f32) -> i1 {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : f32 to i1
+ return %t: i1
+}
+
+// -----
+
+func.func @arith_cast_f32_to_i1_unsigned(%arg0: f32) -> i1 {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptoui'}}
+ %t = arith.fptoui %arg0 : f32 to i1
+ return %t: i1
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 406aa254ecfee1..e4175f90f56fa7 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -103,7 +103,8 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
// CHECK: emitc.cast %arg1 : i64 to f32
%1 = arith.sitofp %arg1 : i64 to f32
- // CHECK: emitc.cast %arg0 : i8 to f32
+ // CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
+ // CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
%2 = arith.uitofp %arg0 : i8 to f32
return
More information about the Mlir-commits
mailing list