[Mlir-commits] [mlir] 18e7dcb - [mlir][emitc] Arith to EmitC: handle floating-point<->integer conversions (#87614)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 3 04:47:21 PDT 2024
Author: Corentin Ferry
Date: 2024-05-03T13:47:16+02:00
New Revision: 18e7dcb7c5765d89e36af9a56354525efd685b83
URL: https://github.com/llvm/llvm-project/commit/18e7dcb7c5765d89e36af9a56354525efd685b83
DIFF: https://github.com/llvm/llvm-project/commit/18e7dcb7c5765d89e36af9a56354525efd685b83.diff
LOG: [mlir][emitc] Arith to EmitC: handle floating-point<->integer conversions (#87614)
Add support for floating-point to integer and integer to floating-point
conversions. Floating-point conversions to 1-bit integer types are not
handled at the moment, as these don't map directly to boolean
conversions (e.g. 0.5 must truncate to 0, whereas a boolean conversion
yields true).
Added:
mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
Modified:
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 9b2544276ce474..1447b182ccfdbc 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -201,6 +201,96 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
}
};
+// Floating-point to integer conversions (arith.fptosi / arith.fptoui).
+//
+// Lowers to an `emitc.cast` from the float operand to the converted integer
+// destination type. Signless integers are interpreted as signed, so for
+// `fptoui` the cast first targets an unsigned integer of the *destination*
+// width and the result is then cast back to the signless destination type.
+template <typename CastOp>
+class FtoICastOpConversion : public OpConversionPattern<CastOp> {
+public:
+  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<CastOp>(typeConverter, context) {}
+
+  LogicalResult
+  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type operandType = adaptor.getIn().getType();
+    if (!emitc::isSupportedFloatType(operandType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast source type");
+
+    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+    // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
+    // truncated to 0, whereas a boolean conversion would return true.
+    if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast destination type");
+
+    // Convert to unsigned if it's the "ui" variant. Signless is interpreted
+    // as signed, so no extra cast is needed for "si".
+    const bool isUnsigned = isa<arith::FPToUIOp>(castOp);
+    Type actualResultType = dstType;
+    if (isUnsigned) {
+      // The intermediate unsigned type must match the width of the integer
+      // *result*, not of the float operand: `fptoui f64 to i32` must go
+      // through ui32, not ui64.
+      actualResultType =
+          rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
+                                  /*isSigned=*/false);
+    }
+
+    Value result = rewriter.create<emitc::CastOp>(
+        castOp.getLoc(), actualResultType, adaptor.getIn());
+
+    // Cast the unsigned intermediate back to the signless destination type.
+    if (isUnsigned)
+      result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
+
+    rewriter.replaceOp(castOp, result);
+    return success();
+  }
+};
+
+// Integer to floating-point conversions (arith.sitofp / arith.uitofp).
+//
+// Lowers to an `emitc.cast` from the integer operand to the converted float
+// destination type. Signless integers are interpreted as signed, so for
+// `uitofp` the operand is first cast to an unsigned integer of the same width.
+template <typename CastOp>
+class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
+public:
+  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<CastOp>(typeConverter, context) {}
+
+  LogicalResult
+  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Vectors in particular are not supported
+    Type operandType = adaptor.getIn().getType();
+    if (!emitc::isSupportedIntegerType(operandType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast source type");
+
+    const bool isUnsigned = isa<arith::UIToFPOp>(castOp);
+
+    // Signed i1-to-float casts are not supported: `arith.sitofp` interprets an
+    // i1 set to 1 as the signed value -1 (yielding -1.0), whereas a boolean
+    // conversion would yield 1.0. This mirrors the float-to-i1 restriction.
+    if (!isUnsigned && operandType.isInteger(1))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported signed cast from i1");
+
+    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+    if (!emitc::isSupportedFloatType(dstType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast destination type");
+
+    // Convert to unsigned if it's the "ui" variant. Signless is interpreted
+    // as signed, so no extra cast is needed for "si".
+    Type actualOperandType = operandType;
+    if (isUnsigned) {
+      actualOperandType =
+          rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
+                                  /*isSigned=*/false);
+    }
+
+    Value fpCastOperand = adaptor.getIn();
+    if (actualOperandType != operandType) {
+      fpCastOperand = rewriter.create<emitc::CastOp>(
+          castOp.getLoc(), actualOperandType, fpCastOperand);
+    }
+    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
+
+    return success();
+  }
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -222,7 +312,11 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
CmpIOpConversion,
- SelectOpConversion
+ SelectOpConversion,
+ ItoFCastOpConversion<arith::SIToFPOp>,
+ ItoFCastOpConversion<arith::UIToFPOp>,
+ FtoICastOpConversion<arith::FPToSIOp>,
+ FtoICastOpConversion<arith::FPToUIOp>
>(typeConverter, ctx);
// clang-format on
}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
new file mode 100644
index 00000000000000..66dfa8fa3e157e
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
+// Each case below is expected to be rejected by the Arith-to-EmitC patterns.
+func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
+ return %t: tensor<5xi32>
+}
+
+// -----
+// Vector-typed casts are expected to stay unconverted.
+func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
+ return %t: vector<5xi32>
+}
+
+// -----
+// bf16 source: presumably rejected as an unsupported float type.
+func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : bf16 to i32
+ return %t: i32
+}
+
+// -----
+// f16 source: presumably rejected as an unsupported float type.
+func.func @arith_cast_f16(%arg0: f16) -> i32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : f16 to i32
+ return %t: i32
+}
+
+
+// -----
+// bf16 destination: presumably rejected as an unsupported float type.
+func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
+ // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
+ %t = arith.sitofp %arg0 : i32 to bf16
+ return %t: bf16
+}
+
+// -----
+// f16 destination: presumably rejected as an unsupported float type.
+func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
+ // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
+ %t = arith.sitofp %arg0 : i32 to f16
+ return %t: f16
+}
+
+// -----
+// Float-to-i1 is rejected: truncation and bool conversion differ on (0, 1).
+func.func @arith_cast_fptosi_i1(%arg0: f32) -> i1 {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+ %t = arith.fptosi %arg0 : f32 to i1
+ return %t: i1
+}
+
+// -----
+// Same float-to-i1 restriction for the unsigned variant.
+func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
+ // expected-error @+1 {{failed to legalize operation 'arith.fptoui'}}
+ %t = arith.fptoui %arg0 : f32 to i1
+ return %t: i1
+}
+
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 46b407177b46aa..79fecd61494d0d 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -141,3 +141,39 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
return
}
+
+// -----
+// CHECK-LABEL: arith_float_to_int_cast_ops
+func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
+  // CHECK: emitc.cast %arg0 : f32 to i32
+  %0 = arith.fptosi %arg0 : f32 to i32
+
+  // CHECK: emitc.cast %arg1 : f64 to i32
+  %1 = arith.fptosi %arg1 : f64 to i32
+
+  // CHECK: emitc.cast %arg0 : f32 to i16
+  %2 = arith.fptosi %arg0 : f32 to i16
+
+  // CHECK: emitc.cast %arg1 : f64 to i16
+  %3 = arith.fptosi %arg1 : f64 to i16
+
+  // CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32
+  // CHECK: emitc.cast %[[CAST0]] : ui32 to i32
+  %4 = arith.fptoui %arg0 : f32 to i32
+
+  return
+}
+// CHECK-LABEL: arith_int_to_float_cast_ops
+func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
+  // CHECK: emitc.cast %arg0 : i8 to f32
+  %0 = arith.sitofp %arg0 : i8 to f32
+
+  // CHECK: emitc.cast %arg1 : i64 to f32
+  %1 = arith.sitofp %arg1 : i64 to f32
+
+  // CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
+  // CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
+  %2 = arith.uitofp %arg0 : i8 to f32
+
+  return
+}
More information about the Mlir-commits
mailing list