[Mlir-commits] [mlir] [mlir][emitc] Arith to EmitC: handle floating-point<->integer conversions (PR #87614)

Corentin Ferry llvmlistbot at llvm.org
Mon Apr 15 06:49:50 PDT 2024


https://github.com/cferry-AMD updated https://github.com/llvm/llvm-project/pull/87614

>From 133a8ba87f3e3fa25e60542714334df935dba585 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 27 Mar 2024 08:58:15 +0000
Subject: [PATCH 1/2] [mlir][emitc] Arith to EmitC: handle FP<->Integer
 conversions

---
 .../Conversion/ArithToEmitC/ArithToEmitC.h    |  3 +-
 mlir/include/mlir/Conversion/Passes.td        | 17 ++++
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 85 ++++++++++++++++++-
 .../ArithToEmitC/ArithToEmitCPass.cpp         |  4 +-
 .../arith-to-emitc-cast-truncate.mlir         | 20 +++++
 .../arith-to-emitc-cast-unsupported.mlir      | 48 +++++++++++
 .../arith-to-emitc-unsupported.mlir           |  7 ++
 .../ArithToEmitC/arith-to-emitc.mlir          | 15 ++++
 8 files changed, 194 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
 create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir
 create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
index 9cb43689d1ce64..32d039e9c89185 100644
--- a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -14,7 +14,8 @@ class RewritePatternSet;
 class TypeConverter;
 
 void populateArithToEmitCPatterns(TypeConverter &typeConverter,
-                                  RewritePatternSet &patterns);
+                                  RewritePatternSet &patterns,
+                                  bool optionFloatToIntTruncates);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab95..029cbd7aec2819 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -139,7 +139,24 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
 
 def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
   let summary = "Convert Arith dialect to EmitC dialect";
+  let description = [{
+    This pass converts `arith` dialect operations to `emitc`.
+
+    The semantics of floating-point to integer conversions `arith.fptosi` and
+    `arith.fptoui` require rounding towards zero. C and C++ specify that such
+    casts truncate (the fractional part is discarded), but lowering them stays
+    opt-in in case the target toolchain does not follow this behavior.
+    
+    If casts can be guaranteed to use round-to-zero, use the 
+    `float-to-int-truncates` flag to allow conversion of `arith.fptosi` and
+    `arith.fptoui` operations.
+  }];
   let dependentDialects = ["emitc::EmitCDialect"];
+  let options = [
+    Option<"floatToIntTruncates", "float-to-int-truncates", "bool",
+           /*default=*/"false",
+           "Whether the behavior of float-to-int cast in emitc is truncation">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index db493c1294ba2d..311978ea6c40e0 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -128,6 +128,78 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
   }
 };
 
+// Float-to-integer casts (fptosi/fptoui); gated by floatToIntTruncates.
+template <typename CastOp>
+class FtoICastOpConversion : public OpConversionPattern<CastOp> {
+private:
+  bool floatToIntTruncates; // Pass option; must be true to convert.
+
+public:
+  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context,
+                       bool optionFloatToIntTruncates)
+      : OpConversionPattern<CastOp>(typeConverter, context),
+        floatToIntTruncates(optionFloatToIntTruncates) {}
+
+  LogicalResult
+  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The source must be a floating-point type supported by EmitC.
+    Type operandType = adaptor.getIn().getType();
+    if (!emitc::isSupportedFloatType(operandType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast source type");
+    // Refuse unless the user asserted that EmitC casts round towards zero.
+    if (!floatToIntTruncates)
+      return rewriter.notifyMatchFailure(
+          castOp, "conversion currently requires EmitC casts to use truncation "
+                  "as rounding mode");
+
+    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+    if (!emitc::isSupportedIntegerType(dstType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast destination type");
+
+    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
+                                               adaptor.getOperands());
+
+    return success();
+  }
+};
+
+// Integer-to-float casts (sitofp/uitofp), lowered to emitc.cast.
+template <typename CastOp>
+class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
+public:
+  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<CastOp>(typeConverter, context) {}
+
+  LogicalResult
+  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type operandType = adaptor.getIn().getType();
+    if (!emitc::isSupportedIntegerType(operandType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast source type");
+
+    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+    // The converted result must be a floating-point type EmitC supports.
+    if (!emitc::isSupportedFloatType(dstType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast destination type");
+
+    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
+                                               adaptor.getOperands());
+
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -135,7 +207,8 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
 //===----------------------------------------------------------------------===//
 
 void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
-                                        RewritePatternSet &patterns) {
+                                        RewritePatternSet &patterns,
+                                        bool optionFloatToIntTruncates) {
   MLIRContext *ctx = patterns.getContext();
 
   // clang-format off
@@ -148,7 +221,13 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
     IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
     IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
-    SelectOpConversion
-  >(typeConverter, ctx);
+    SelectOpConversion,
+    ItoFCastOpConversion<arith::SIToFPOp>,
+    ItoFCastOpConversion<arith::UIToFPOp>
+  >(typeConverter, ctx)
+  .add<
+    FtoICastOpConversion<arith::FPToSIOp>,
+    FtoICastOpConversion<arith::FPToUIOp>
+  >(typeConverter, ctx, optionFloatToIntTruncates);
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
index 45a088ed144f17..546bbfe2082eff 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -29,6 +29,8 @@ using namespace mlir;
 namespace {
 struct ConvertArithToEmitC
     : public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
+  using Base::Base;
+
   void runOnOperation() override;
 };
 } // namespace
@@ -44,7 +46,7 @@ void ConvertArithToEmitC::runOnOperation() {
   TypeConverter typeConverter;
   typeConverter.addConversion([](Type type) { return type; });
 
-  populateArithToEmitCPatterns(typeConverter, patterns);
+  populateArithToEmitCPatterns(typeConverter, patterns, floatToIntTruncates);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
new file mode 100644
index 00000000000000..f45b6306b0292b
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt -split-input-file --pass-pipeline="builtin.module(convert-arith-to-emitc{float-to-int-truncates})" %s | FileCheck %s
+
+func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
+  // CHECK: emitc.cast %arg0 : f32 to i32
+  %0 = arith.fptosi %arg0 : f32 to i32
+
+  // CHECK: emitc.cast %arg1 : f64 to i32
+  %1 = arith.fptosi %arg1 : f64 to i32
+
+  // CHECK: emitc.cast %arg0 : f32 to i16
+  %2 = arith.fptosi %arg0 : f32 to i16
+
+  // CHECK: emitc.cast %arg1 : f64 to i16
+  %3 = arith.fptosi %arg1 : f64 to i16
+
+  // CHECK: emitc.cast %arg0 : f32 to i32
+  %4 = arith.fptoui %arg0 : f32 to i32
+  
+  return
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir
new file mode 100644
index 00000000000000..34fc9f3dffc0c8
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -split-input-file --pass-pipeline="builtin.module(convert-arith-to-emitc{float-to-int-truncates})" -verify-diagnostics %s
+
+func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
+  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+  %t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
+  return %t: tensor<5xi32>
+}
+
+// -----
+
+func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
+  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+  %t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
+  return %t: vector<5xi32>
+}
+
+// -----
+
+func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+  %t = arith.fptosi %arg0 : bf16 to i32
+  return %t: i32
+}
+
+// -----
+
+func.func @arith_cast_f16(%arg0: f16) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+  %t = arith.fptosi %arg0 : f16 to i32
+  return %t: i32
+}
+
+
+// -----
+
+func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
+  // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
+  %t = arith.sitofp %arg0 : i32 to bf16
+  return %t: bf16
+}
+
+// -----
+
+func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
+  // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
+  %t = arith.sitofp %arg0 : i32 to f16
+  return %t: f16
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
new file mode 100644
index 00000000000000..bbec664100564b
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
+
+func.func @arith_cast_f32(%arg0: f32) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
+  %t = arith.fptosi %arg0 : f32 to i32
+  return %t: i32
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 76ba518577ab8e..406aa254ecfee1 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -93,3 +93,18 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -
   %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
   return
 }
+
+// -----
+
+func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
+  // CHECK: emitc.cast %arg0 : i8 to f32
+  %0 = arith.sitofp %arg0 : i8 to f32
+
+  // CHECK: emitc.cast %arg1 : i64 to f32
+  %1 = arith.sitofp %arg1 : i64 to f32
+
+  // CHECK: emitc.cast %arg0 : i8 to f32
+  %2 = arith.uitofp %arg0 : i8 to f32
+
+  return
+}

>From db3765b7252ef606f73e1ccac52ed101f4961741 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 15 Apr 2024 09:28:03 +0200
Subject: [PATCH 2/2] Merge pull request #160 from Xilinx/corentin.fix_itofp

[FXML-4281] Fix signedness behavior of unsigned integer <-> floating-point conversions
---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 36 ++++++++++++++++---
 .../arith-to-emitc-cast-truncate.mlir         |  3 +-
 .../ArithToEmitC/arith-to-emitc.mlir          |  3 +-
 3 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 311978ea6c40e0..dee110dbd79323 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -162,8 +162,22 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
       return rewriter.notifyMatchFailure(castOp,
                                          "unsupported cast destination type");
 
-    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
-                                               adaptor.getOperands());
+    // For the "ui" variant, cast to an unsigned type of the destination width
+    // first; signless is interpreted as signed, so "si" needs no extra cast.
+    Type actualResultType = dstType;
+    if (isa<arith::FPToUIOp>(castOp)) {
+      actualResultType =
+          rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
+                                  /*isSigned=*/false);
+    }
+
+    Value result = rewriter.create<emitc::CastOp>(
+        castOp.getLoc(), actualResultType, adaptor.getOperands());
+
+    if (actualResultType != dstType) {
+      result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
+    }
+    rewriter.replaceOp(castOp, result);
 
     return success();
   }
@@ -179,7 +193,7 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
   LogicalResult
   matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
+    // Only integer types EmitC supports are accepted (e.g. not vectors).
     Type operandType = adaptor.getIn().getType();
     if (!emitc::isSupportedIntegerType(operandType))
       return rewriter.notifyMatchFailure(castOp,
@@ -193,8 +207,20 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
       return rewriter.notifyMatchFailure(castOp,
                                          "unsupported cast destination type");
 
-    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
-                                               adaptor.getOperands());
+    // For "ui", first reinterpret the operand as unsigned of the same width;
+    // signless is interpreted as signed, so "si" needs no extra cast.
+    Type actualOperandType = operandType;
+    if (isa<arith::UIToFPOp>(castOp)) {
+      actualOperandType =
+          rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
+                                  /*isSigned=*/false);
+    }
+    Value fpCastOperand = adaptor.getIn();
+    if (actualOperandType != operandType) {
+      fpCastOperand = rewriter.template create<emitc::CastOp>(
+          castOp.getLoc(), actualOperandType, fpCastOperand);
+    }
+    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
 
     return success();
   }
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
index f45b6306b0292b..26f9261183144e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir
@@ -13,7 +13,8 @@ func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
   // CHECK: emitc.cast %arg1 : f64 to i16
   %3 = arith.fptosi %arg1 : f64 to i16
 
-  // CHECK: emitc.cast %arg0 : f32 to i32
+  // CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32
+  // CHECK: emitc.cast %[[CAST0]] : ui32 to i32
   %4 = arith.fptoui %arg0 : f32 to i32
   
   return
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 406aa254ecfee1..e4175f90f56fa7 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -103,7 +103,8 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
   // CHECK: emitc.cast %arg1 : i64 to f32
   %1 = arith.sitofp %arg1 : i64 to f32
 
-  // CHECK: emitc.cast %arg0 : i8 to f32
+  // CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
+  // CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
   %2 = arith.uitofp %arg0 : i8 to f32
 
   return



More information about the Mlir-commits mailing list