[Mlir-commits] [mlir] [mlir] Support emit fp16 and bf16 type to cpp (PR #105803)
Jianjian Guan
llvmlistbot at llvm.org
Fri Aug 23 01:58:54 PDT 2024
https://github.com/jacquesguan created https://github.com/llvm/llvm-project/pull/105803
None
>From 19f7d82a59be6511e3a00a2d355d5348faea6e7d Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Fri, 23 Aug 2024 16:57:00 +0800
Subject: [PATCH] [mlir] Support emit fp16 and bf16 type to cpp
---
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 1 +
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 5 +++
.../arith-to-emitc-unsupported.mlir | 33 -------------------
.../MemRefToEmitC/memref-to-emitc-failed.mlir | 8 -----
mlir/test/Target/Cpp/types.mlir | 4 +++
5 files changed, 10 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e6f1618cc26116..8555e82002d56b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -116,6 +116,7 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
+ case 16:
case 32:
case 64:
return true;
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index c043582b7be9c6..aa45e7c9d7f757 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1640,6 +1640,11 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
}
if (auto fType = dyn_cast<FloatType>(type)) {
switch (fType.getWidth()) {
+ case 16:
+ if (llvm::isa<Float16Type>(type))
+ return (os << "_Float16"), success();
+ else
+ return (os << "__bf16"), success();
case 32:
return (os << "float"), success();
case 64:
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index ef0e71ee8673b7..b3eebaf8a1ef1e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -16,39 +16,6 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
// -----
-func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
- // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
- %t = arith.fptosi %arg0 : bf16 to i32
- return %t: i32
-}
-
-// -----
-
-func.func @arith_cast_f16(%arg0: f16) -> i32 {
- // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
- %t = arith.fptosi %arg0 : f16 to i32
- return %t: i32
-}
-
-
-// -----
-
-func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
- // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
- %t = arith.sitofp %arg0 : i32 to bf16
- return %t: bf16
-}
-
-// -----
-
-func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
- // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
- %t = arith.sitofp %arg0 : i32 to f16
- return %t: f16
-}
-
-// -----
-
func.func @arith_cast_fptosi_i1(%arg0: f32) -> i1 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : f32 to i1
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 836d8aedefc1f0..14977bfb3e2fd9 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -46,14 +46,6 @@ memref.global "nested" constant @nested_global : memref<3x7xf32>
// -----
-func.func @unsupported_type_f16() {
- // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
- %0 = memref.alloca() : memref<4xf16>
- return
-}
-
-// -----
-
func.func @unsupported_type_i4() {
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
%0 = memref.alloca() : memref<4xi4>
diff --git a/mlir/test/Target/Cpp/types.mlir b/mlir/test/Target/Cpp/types.mlir
index deda383b3b0a72..e7f935c7374382 100644
--- a/mlir/test/Target/Cpp/types.mlir
+++ b/mlir/test/Target/Cpp/types.mlir
@@ -22,6 +22,10 @@ func.func @ptr_types() {
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i32>]} : () -> ()
// CHECK-NEXT: f<int64_t*>();
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i64>]} : () -> ()
+ // CHECK-NEXT: f<_Float16*>();
+ emitc.call_opaque "f"() {template_args = [!emitc.ptr<f16>]} : () -> ()
+ // CHECK-NEXT: f<__bf16*>();
+ emitc.call_opaque "f"() {template_args = [!emitc.ptr<bf16>]} : () -> ()
// CHECK-NEXT: f<float*>();
emitc.call_opaque "f"() {template_args = [!emitc.ptr<f32>]} : () -> ()
// CHECK-NEXT: f<double*>();
More information about the Mlir-commits
mailing list