[Mlir-commits] [mlir] [mlir][emitc] Arith to EmitC: handle floating-point<->integer conversions (PR #87614)

Corentin Ferry llvmlistbot at llvm.org
Fri Apr 19 04:58:53 PDT 2024


https://github.com/cferry-AMD updated https://github.com/llvm/llvm-project/pull/87614

>From e7e14b24e9a5cb96be0507be11c31f126aa7c9e2 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 27 Mar 2024 08:58:15 +0000
Subject: [PATCH] [mlir][emitc] Arith to EmitC: handle FP<->Integer conversions

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 94 ++++++++++++++++++-
 .../arith-to-emitc-unsupported.mlir           | 48 ++++++++++
 .../ArithToEmitC/arith-to-emitc.mlir          | 36 +++++++
 3 files changed, 177 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 9b2544276ce474..195d4d39cbdbe7 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -201,6 +201,94 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
   }
 };
 
+// Floating-point to integer conversions (arith.fptosi / arith.fptoui).
+template <typename CastOp>
+class FtoICastOpConversion : public OpConversionPattern<CastOp> {
+public:
+  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<CastOp>(typeConverter, context) {}
+
+  LogicalResult
+  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Vectors in particular are not supported
+    Type operandType = adaptor.getIn().getType();
+    if (!emitc::isSupportedFloatType(operandType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast source type");
+
+    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+    if (!emitc::isSupportedIntegerType(dstType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast destination type");
+
+    // Convert to unsigned if it's the "ui" variant
+    // Signless is interpreted as signed, so no need to cast for "si"
+    // The unsigned intermediate takes the destination's bit width: using the
+    // source float's width would truncate e.g. an f32 -> i64 conversion.
+    Type actualResultType = dstType;
+    if (isa<arith::FPToUIOp>(castOp))
+      actualResultType =
+          rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
+                                  /*isSigned=*/false);
+
+    Value result = rewriter.create<emitc::CastOp>(
+        castOp.getLoc(), actualResultType, adaptor.getOperands());
+
+    if (isa<arith::FPToUIOp>(castOp))
+      result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
+    rewriter.replaceOp(castOp, result);
+
+    return success();
+  }
+};
+
+// Integer to floating-point conversions (arith.sitofp / arith.uitofp),
+// lowered to emitc.cast.
+template <typename CastOp>
+class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
+public:
+  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<CastOp>(typeConverter, context) {}
+
+  LogicalResult
+  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only supported integer sources are handled; vectors in particular are not
+    Value source = adaptor.getIn();
+    Type srcType = source.getType();
+    if (!emitc::isSupportedIntegerType(srcType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast source type");
+
+    Type resultType = this->getTypeConverter()->convertType(castOp.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+    if (!emitc::isSupportedFloatType(resultType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast destination type");
+
+    // Signless is interpreted as signed, so only the "ui" variant needs its
+    // operand cast to an unsigned integer of the same width first.
+    if (isa<arith::UIToFPOp>(castOp)) {
+      Type unsignedType = rewriter.getIntegerType(
+          srcType.getIntOrFloatBitWidth(), /*isSigned=*/false);
+      if (unsignedType != srcType)
+        source = rewriter.create<emitc::CastOp>(castOp.getLoc(), unsignedType,
+                                                source);
+    }
+
+    // Emit the integer to floating-point cast itself.
+    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, resultType, source);
+
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -222,7 +310,11 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
     IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
     CmpIOpConversion,
-    SelectOpConversion
+    SelectOpConversion,
+    ItoFCastOpConversion<arith::SIToFPOp>,
+    ItoFCastOpConversion<arith::UIToFPOp>,
+    FtoICastOpConversion<arith::FPToSIOp>,
+    FtoICastOpConversion<arith::FPToUIOp>
   >(typeConverter, ctx);
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
new file mode 100644
index 00000000000000..39b56882853a77
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
+// Tensor operand: legalization of the cast is expected to fail.
+func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
+  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+  %t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
+  return %t: tensor<5xi32>
+}
+
+// -----
+// Vector operand: legalization of the cast is expected to fail.
+func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
+  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+  %t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
+  return %t: vector<5xi32>
+}
+
+// -----
+// bf16 source: legalization is expected to fail (presumably an unsupported cast source type).
+func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+  %t = arith.fptosi %arg0 : bf16 to i32
+  return %t: i32
+}
+
+// -----
+// f16 source: legalization is expected to fail (presumably an unsupported cast source type).
+func.func @arith_cast_f16(%arg0: f16) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+  %t = arith.fptosi %arg0 : f16 to i32
+  return %t: i32
+}
+
+
+// -----
+// bf16 result: legalization is expected to fail (presumably an unsupported cast destination type).
+func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
+  // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
+  %t = arith.sitofp %arg0 : i32 to bf16
+  return %t: bf16
+}
+
+// -----
+// f16 result: legalization is expected to fail (presumably an unsupported cast destination type).
+func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
+  // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
+  %t = arith.sitofp %arg0 : i32 to f16
+  return %t: f16
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 46b407177b46aa..79fecd61494d0d 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -141,3 +141,39 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
   
   return
 }
+
+// -----
+// fptosi is checked as a single emitc.cast; fptoui as a cast to ui32 followed by a cast to i32.
+func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
+  // CHECK: emitc.cast %arg0 : f32 to i32
+  %0 = arith.fptosi %arg0 : f32 to i32
+
+  // CHECK: emitc.cast %arg1 : f64 to i32
+  %1 = arith.fptosi %arg1 : f64 to i32
+
+  // CHECK: emitc.cast %arg0 : f32 to i16
+  %2 = arith.fptosi %arg0 : f32 to i16
+
+  // CHECK: emitc.cast %arg1 : f64 to i16
+  %3 = arith.fptosi %arg1 : f64 to i16
+
+  // CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32
+  // CHECK: emitc.cast %[[CAST0]] : ui32 to i32
+  %4 = arith.fptoui %arg0 : f32 to i32
+
+  return
+}
+// sitofp is checked as a single emitc.cast; uitofp as a cast to ui8 followed by a cast to f32.
+func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
+  // CHECK: emitc.cast %arg0 : i8 to f32
+  %0 = arith.sitofp %arg0 : i8 to f32
+
+  // CHECK: emitc.cast %arg1 : i64 to f32
+  %1 = arith.sitofp %arg1 : i64 to f32
+
+  // CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
+  // CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
+  %2 = arith.uitofp %arg0 : i8 to f32
+
+  return
+}



More information about the Mlir-commits mailing list