[Mlir-commits] [mlir] [mlir] Support emit fp16 and bf16 type to cpp (PR #105803)

Jianjian Guan llvmlistbot at llvm.org
Sun Aug 25 23:39:06 PDT 2024


https://github.com/jacquesguan updated https://github.com/llvm/llvm-project/pull/105803

>From 488bd1260924e4219258854e947087f196888a8b Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Fri, 23 Aug 2024 16:57:00 +0800
Subject: [PATCH] [mlir] Support emit fp16 and bf16 type to cpp

---
 mlir/docs/Dialects/emitc.md                   |  4 +++
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |  5 ++++
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 28 +++++++++++++++----
 .../arith-to-emitc-unsupported.mlir           | 21 +++++++-------
 .../MemRefToEmitC/memref-to-emitc-failed.mlir |  8 ------
 mlir/test/Target/Cpp/const.mlir               |  8 ++++++
 mlir/test/Target/Cpp/types.mlir               |  4 +++
 7 files changed, 53 insertions(+), 25 deletions(-)

diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md
index 4b0394606e4a24..d0a2ae15ec13c0 100644
--- a/mlir/docs/Dialects/emitc.md
+++ b/mlir/docs/Dialects/emitc.md
@@ -12,6 +12,10 @@ The following convention is followed:
     operation, C++20 is required.
 *   If `ssize_t` is used, then the code requires the POSIX header `sys/types.h`
     or any of the C++ headers in which the type is defined.
+*   If `_Float16` is used, the code requires support for the C additional
+    floating types.
+*   If `__bf16` is used, the code requires a compiler that supports it, such as
+    GCC or Clang.
 *   Else the generated code is compatible with C99.
 
 These restrictions are neither inherent to the EmitC dialect itself nor to the
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e6f1618cc26116..fdc21d6c6e24b9 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -116,6 +116,11 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
 bool mlir::emitc::isSupportedFloatType(Type type) {
   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
     switch (floatType.getWidth()) {
+    case 16: {
+      if (llvm::isa<Float16Type, BFloat16Type>(type))
+        return true;
+      return false;
+    }
     case 32:
     case 64:
       return true;
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index c043582b7be9c6..30657d8fccb154 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1258,6 +1258,12 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
       val.toString(strValue, 0, 0, false);
       os << strValue;
       switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
+      case llvm::APFloatBase::S_IEEEhalf:
+        os << "f16";
+        break;
+      case llvm::APFloatBase::S_BFloat:
+        os << "bf16";
+        break;
       case llvm::APFloatBase::S_IEEEsingle:
         os << "f";
         break;
@@ -1277,17 +1283,19 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
 
   // Print floating point attributes.
   if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
-    if (!isa<Float32Type, Float64Type>(fAttr.getType())) {
-      return emitError(loc,
-                       "expected floating point attribute to be f32 or f64");
+    if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
+            fAttr.getType())) {
+      return emitError(
+          loc, "expected floating point attribute to be f16, bf16, f32 or f64");
     }
     printFloat(fAttr.getValue());
     return success();
   }
   if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
-    if (!isa<Float32Type, Float64Type>(dense.getElementType())) {
-      return emitError(loc,
-                       "expected floating point attribute to be f32 or f64");
+    if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
+            dense.getElementType())) {
+      return emitError(
+          loc, "expected floating point attribute to be f16, bf16, f32 or f64");
     }
     os << '{';
     interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
@@ -1640,6 +1648,14 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
   }
   if (auto fType = dyn_cast<FloatType>(type)) {
     switch (fType.getWidth()) {
+    case 16: {
+      if (llvm::isa<Float16Type>(type))
+        return (os << "_Float16"), success();
+      else if (llvm::isa<BFloat16Type>(type))
+        return (os << "__bf16"), success();
+      else
+        return emitError(loc, "cannot emit float type ") << type;
+    }
     case 32:
       return (os << "float"), success();
     case 64:
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index ef0e71ee8673b7..b86690461dc269 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -15,36 +15,35 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
 }
 
 // -----
-
-func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
+func.func @arith_cast_f80(%arg0: f80) -> i32 {
   // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
-  %t = arith.fptosi %arg0 : bf16 to i32
+  %t = arith.fptosi %arg0 : f80 to i32
   return %t: i32
 }
 
 // -----
 
-func.func @arith_cast_f16(%arg0: f16) -> i32 {
+func.func @arith_cast_f128(%arg0: f128) -> i32 {
   // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
-  %t = arith.fptosi %arg0 : f16 to i32
+  %t = arith.fptosi %arg0 : f128 to i32
   return %t: i32
 }
 
 
 // -----
 
-func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
+func.func @arith_cast_to_f80(%arg0: i32) -> f80 {
   // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
-  %t = arith.sitofp %arg0 : i32 to bf16
-  return %t: bf16
+  %t = arith.sitofp %arg0 : i32 to f80
+  return %t: f80
 }
 
 // -----
 
-func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
+func.func @arith_cast_to_f128(%arg0: i32) -> f128 {
   // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
-  %t = arith.sitofp %arg0 : i32 to f16
-  return %t: f16
+  %t = arith.sitofp %arg0 : i32 to f128
+  return %t: f128
 }
 
 // -----
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 836d8aedefc1f0..14977bfb3e2fd9 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -46,14 +46,6 @@ memref.global "nested" constant @nested_global : memref<3x7xf32>
 
 // -----
 
-func.func @unsupported_type_f16() {
-  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
-  %0 = memref.alloca() : memref<4xf16>
-  return
-}
-
-// -----
-
 func.func @unsupported_type_i4() {
   // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
   %0 = memref.alloca() : memref<4xi4>
diff --git a/mlir/test/Target/Cpp/const.mlir b/mlir/test/Target/Cpp/const.mlir
index 3658455d669438..d3656f830c48c3 100644
--- a/mlir/test/Target/Cpp/const.mlir
+++ b/mlir/test/Target/Cpp/const.mlir
@@ -11,6 +11,8 @@ func.func @emitc_constant() {
   %c6 = "emitc.constant"(){value = 2 : index} : () -> index
   %c7 = "emitc.constant"(){value = 2.0 : f32} : () -> f32
   %f64 = "emitc.constant"(){value = 4.0 : f64} : () -> f64
+  %f16 = "emitc.constant"(){value = 2.0 : f16} : () -> f16
+  %bf16 = "emitc.constant"(){value = 4.0 : bf16} : () -> bf16
   %c8 = "emitc.constant"(){value = dense<0> : tensor<i32>} : () -> tensor<i32>
   %c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex>
   %c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
@@ -26,6 +28,8 @@ func.func @emitc_constant() {
 // CPP-DEFAULT-NEXT: size_t [[V6:[^ ]*]] = 2;
 // CPP-DEFAULT-NEXT: float [[V7:[^ ]*]] = 2.000000000e+00f;
 // CPP-DEFAULT-NEXT: double [[F64:[^ ]*]] = 4.00000000000000000e+00;
+// CPP-DEFAULT-NEXT: _Float16 [[F16:[^ ]*]] = 2.00000e+00f16;
+// CPP-DEFAULT-NEXT: __bf16 [[BF16:[^ ]*]] = 4.0000e+00bf16;
 // CPP-DEFAULT-NEXT: Tensor<int32_t> [[V8:[^ ]*]] = {0};
 // CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]] = {0, 1};
 // CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
@@ -40,6 +44,8 @@ func.func @emitc_constant() {
 // CPP-DECLTOP-NEXT: size_t [[V6:[^ ]*]];
 // CPP-DECLTOP-NEXT: float [[V7:[^ ]*]];
 // CPP-DECLTOP-NEXT: double [[F64:[^ ]*]];
+// CPP-DECLTOP-NEXT: _Float16 [[F16:[^ ]*]];
+// CPP-DECLTOP-NEXT: __bf16 [[BF16:[^ ]*]];
 // CPP-DECLTOP-NEXT: Tensor<int32_t> [[V8:[^ ]*]];
 // CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]];
 // CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]];
@@ -52,6 +58,8 @@ func.func @emitc_constant() {
 // CPP-DECLTOP-NEXT: [[V6]] = 2;
 // CPP-DECLTOP-NEXT: [[V7]] = 2.000000000e+00f;
 // CPP-DECLTOP-NEXT: [[F64]] = 4.00000000000000000e+00;
+// CPP-DECLTOP-NEXT: [[F16]] = 2.00000e+00f16;
+// CPP-DECLTOP-NEXT: [[BF16]] = 4.0000e+00bf16;
 // CPP-DECLTOP-NEXT: [[V8]] = {0};
 // CPP-DECLTOP-NEXT: [[V9]] = {0, 1};
 // CPP-DECLTOP-NEXT: [[V10]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
diff --git a/mlir/test/Target/Cpp/types.mlir b/mlir/test/Target/Cpp/types.mlir
index deda383b3b0a72..e7f935c7374382 100644
--- a/mlir/test/Target/Cpp/types.mlir
+++ b/mlir/test/Target/Cpp/types.mlir
@@ -22,6 +22,10 @@ func.func @ptr_types() {
   emitc.call_opaque "f"() {template_args = [!emitc.ptr<i32>]} : () -> ()
   // CHECK-NEXT: f<int64_t*>();
   emitc.call_opaque "f"() {template_args = [!emitc.ptr<i64>]} : () -> ()
+  // CHECK-NEXT: f<_Float16*>();
+  emitc.call_opaque "f"() {template_args = [!emitc.ptr<f16>]} : () -> ()
+  // CHECK-NEXT: f<__bf16*>();
+  emitc.call_opaque "f"() {template_args = [!emitc.ptr<bf16>]} : () -> ()
   // CHECK-NEXT: f<float*>();
   emitc.call_opaque "f"() {template_args = [!emitc.ptr<f32>]} : () -> ()
   // CHECK-NEXT: f<double*>();



More information about the Mlir-commits mailing list