[Mlir-commits] [mlir] [mlir][emitc] Support converting arith.extf and arith.truncf to emitc (PR #121184)

Jianjian Guan llvmlistbot at llvm.org
Tue Dec 31 00:13:36 PST 2024


https://github.com/jacquesguan updated https://github.com/llvm/llvm-project/pull/121184

>From 920b0d653f6db2374191e31c654e72f9d3fa43d4 Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Fri, 27 Dec 2024 15:19:13 +0800
Subject: [PATCH 1/2] [mlir][emitc] Support converting arith.extf and
 arith.truncf to emitc

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 35 ++++++++++++++++++-
 .../ArithToEmitC/arith-to-emitc.mlir          | 26 ++++++++++++++
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index ccbc1669b7a92a..e2fbac40517e0d 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -733,6 +733,37 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
   }
 };
 
+// Floating-point to floating-point conversions.
+template <typename CastOp>
+class FpCastOpConversion : public OpConversionPattern<CastOp> {
+public:
+  FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<CastOp>(typeConverter, context) {}
+
+  LogicalResult
+  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Vectors in particular are not supported
+    Type operandType = adaptor.getIn().getType();
+    if (!emitc::isSupportedFloatType(operandType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast source type");
+
+    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+    if (!emitc::isSupportedFloatType(dstType))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "unsupported cast destination type");
+
+    Value fpCastOperand = adaptor.getIn();
+    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
+
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -778,7 +809,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     ItoFCastOpConversion<arith::SIToFPOp>,
     ItoFCastOpConversion<arith::UIToFPOp>,
     FtoICastOpConversion<arith::FPToSIOp>,
-    FtoICastOpConversion<arith::FPToUIOp>
+    FtoICastOpConversion<arith::FPToUIOp>,
+    FpCastOpConversion<arith::ExtFOp>,
+    FpCastOpConversion<arith::TruncFOp>
   >(typeConverter, ctx);
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 1728c3a2557e07..434f8771d58c1e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -739,3 +739,29 @@ func.func @arith_divui_remui(%arg0: i32, %arg1: i32) -> i32 {
 
   return %div : i32
 }
+
+// -----
+
+func.func @arith_extf(%arg0: f16) -> f64 {
+  // CHECK-LABEL: arith_extf
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: f16)
+  // CHECK: %[[Extd0:.*]] = emitc.cast %[[Arg0]] : f16 to f32
+  %extd0 = arith.extf %arg0 : f16 to f32
+  // CHECK: %[[Extd1:.*]] = emitc.cast %[[Extd0]] : f32 to f64
+  %extd1 = arith.extf %extd0 : f32 to f64
+
+  return %extd1 : f64
+}
+
+// -----
+
+func.func @arith_truncf(%arg0: f64) -> f16 {
+  // CHECK-LABEL: arith_truncf
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
+  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+  %truncd0 = arith.truncf %arg0 : f64 to f32
+  // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
+  %truncd1 = arith.truncf %truncd0 : f32 to f16
+
+  return %truncd1 : f16
+}
\ No newline at end of file

>From 2ad6ce8d281337ca5c0e5d9ab8acf18bd55b6d43 Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Tue, 31 Dec 2024 16:12:27 +0800
Subject: [PATCH 2/2] Exit early when a rounding mode is specified

---
 mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp         | 6 ++++++
 .../ArithToEmitC/arith-to-emitc-unsupported.mlir          | 8 ++++++++
 mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir     | 2 +-
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index e2fbac40517e0d..1eb864f68d3619 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -748,6 +748,12 @@ class FpCastOpConversion : public OpConversionPattern<CastOp> {
     if (!emitc::isSupportedFloatType(operandType))
       return rewriter.notifyMatchFailure(castOp,
                                          "unsupported cast source type");
+    if (auto roundingModeOp =
+            dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
+      // Only supporting default rounding mode as of now.
+      if (roundingModeOp.getRoundingModeAttr())
+        return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
+    }
 
     Type dstType = this->getTypeConverter()->convertType(castOp.getType());
     if (!dstType)
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index b86690461dc269..b84dbf57a01b1a 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -149,3 +149,11 @@ func.func @arith_remui_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) -> vec
   %divui = arith.remui %arg0, %arg1 : vector<5xi32>
   return %divui: vector<5xi32>
 }
+
+// -----
+
+func.func @arith_truncf(%arg0: f64) -> f32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
+  %truncd = arith.truncf %arg0 to_nearest_away : f64 to f32
+  return %truncd : f32
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 434f8771d58c1e..4e3d1088beed93 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,4 +764,4 @@ func.func @arith_truncf(%arg0: f64) -> f16 {
   %truncd1 = arith.truncf %truncd0 : f32 to f16
 
   return %truncd1 : f16
-}
\ No newline at end of file
+}



More information about the Mlir-commits mailing list