[Mlir-commits] [mlir] [mlir] Support emit fp16 and bf16 type to cpp (PR #105803)

Jianjian Guan llvmlistbot at llvm.org
Fri Aug 23 01:58:54 PDT 2024


https://github.com/jacquesguan created https://github.com/llvm/llvm-project/pull/105803

None

>From 19f7d82a59be6511e3a00a2d355d5348faea6e7d Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Fri, 23 Aug 2024 16:57:00 +0800
Subject: [PATCH] [mlir] Support emit fp16 and bf16 type to cpp

---
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |  1 +
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        |  5 +++
 .../arith-to-emitc-unsupported.mlir           | 33 -------------------
 .../MemRefToEmitC/memref-to-emitc-failed.mlir |  8 -----
 mlir/test/Target/Cpp/types.mlir               |  4 +++
 5 files changed, 10 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e6f1618cc26116..8555e82002d56b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -116,6 +116,7 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
 bool mlir::emitc::isSupportedFloatType(Type type) {
   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
     switch (floatType.getWidth()) {
+    case 16:
     case 32:
     case 64:
       return true;
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index c043582b7be9c6..aa45e7c9d7f757 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1640,6 +1640,11 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
   }
   if (auto fType = dyn_cast<FloatType>(type)) {
     switch (fType.getWidth()) {
+    case 16:
+      if (llvm::isa<Float16Type>(type))
+        return (os << "_Float16"), success();
+      else
+        return (os << "__bf16"), success();
     case 32:
       return (os << "float"), success();
     case 64:
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index ef0e71ee8673b7..b3eebaf8a1ef1e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -16,39 +16,6 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
 
 // -----
 
-func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
-  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
-  %t = arith.fptosi %arg0 : bf16 to i32
-  return %t: i32
-}
-
-// -----
-
-func.func @arith_cast_f16(%arg0: f16) -> i32 {
-  // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
-  %t = arith.fptosi %arg0 : f16 to i32
-  return %t: i32
-}
-
-
-// -----
-
-func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
-  // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
-  %t = arith.sitofp %arg0 : i32 to bf16
-  return %t: bf16
-}
-
-// -----
-
-func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
-  // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
-  %t = arith.sitofp %arg0 : i32 to f16
-  return %t: f16
-}
-
-// -----
-
 func.func @arith_cast_fptosi_i1(%arg0: f32) -> i1 {
   // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
   %t = arith.fptosi %arg0 : f32 to i1
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 836d8aedefc1f0..14977bfb3e2fd9 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -46,14 +46,6 @@ memref.global "nested" constant @nested_global : memref<3x7xf32>
 
 // -----
 
-func.func @unsupported_type_f16() {
-  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
-  %0 = memref.alloca() : memref<4xf16>
-  return
-}
-
-// -----
-
 func.func @unsupported_type_i4() {
   // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
   %0 = memref.alloca() : memref<4xi4>
diff --git a/mlir/test/Target/Cpp/types.mlir b/mlir/test/Target/Cpp/types.mlir
index deda383b3b0a72..e7f935c7374382 100644
--- a/mlir/test/Target/Cpp/types.mlir
+++ b/mlir/test/Target/Cpp/types.mlir
@@ -22,6 +22,10 @@ func.func @ptr_types() {
   emitc.call_opaque "f"() {template_args = [!emitc.ptr<i32>]} : () -> ()
   // CHECK-NEXT: f<int64_t*>();
   emitc.call_opaque "f"() {template_args = [!emitc.ptr<i64>]} : () -> ()
+  // CHECK-NEXT: f<_Float16*>();
+  emitc.call_opaque "f"() {template_args = [!emitc.ptr<f16>]} : () -> ()
+  // CHECK-NEXT: f<__bf16*>();
+  emitc.call_opaque "f"() {template_args = [!emitc.ptr<bf16>]} : () -> ()
   // CHECK-NEXT: f<float*>();
   emitc.call_opaque "f"() {template_args = [!emitc.ptr<f32>]} : () -> ()
   // CHECK-NEXT: f<double*>();



More information about the Mlir-commits mailing list