[Mlir-commits] [mlir] [mlir] Support emit fp16 and bf16 type to cpp (PR #105803)
Jianjian Guan
llvmlistbot at llvm.org
Tue Aug 27 00:23:28 PDT 2024
https://github.com/jacquesguan updated https://github.com/llvm/llvm-project/pull/105803
>From 5cd62855aa76472505f02912f1bd2c88ebc6e01d Mon Sep 17 00:00:00 2001
From: Jianjian GUAN <jacquesguan at me.com>
Date: Fri, 23 Aug 2024 16:57:00 +0800
Subject: [PATCH] [mlir] Support emit fp16 and bf16 type to cpp
---
mlir/docs/Dialects/emitc.md | 4 +++
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 5 ++++
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 28 +++++++++++++++----
.../arith-to-emitc-unsupported.mlir | 21 +++++++-------
.../MemRefToEmitC/memref-to-emitc-failed.mlir | 4 +--
mlir/test/Target/Cpp/const.mlir | 8 ++++++
mlir/test/Target/Cpp/types.mlir | 4 +++
7 files changed, 55 insertions(+), 19 deletions(-)
diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md
index 4b0394606e4a24..743d70959f3d8e 100644
--- a/mlir/docs/Dialects/emitc.md
+++ b/mlir/docs/Dialects/emitc.md
@@ -12,6 +12,10 @@ The following convention is followed:
operation, C++20 is required.
* If `ssize_t` is used, then the code requires the POSIX header `sys/types.h`
or any of the C++ headers in which the type is defined.
+* If `_Float16` is used, the code requires a compiler that supports the C
+ additional floating types (ISO/IEC TS 18661-3).
+* If `__bf16` is used, the code requires a compiler that supports it, such as
+ GCC or Clang.
* Else the generated code is compatible with C99.
These restrictions are neither inherent to the EmitC dialect itself nor to the
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e6f1618cc26116..fdc21d6c6e24b9 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -116,6 +116,11 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
+ case 16: {
+ if (llvm::isa<Float16Type, BFloat16Type>(type))
+ return true;
+ return false;
+ }
case 32:
case 64:
return true;
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index c043582b7be9c6..30657d8fccb154 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1258,6 +1258,12 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
val.toString(strValue, 0, 0, false);
os << strValue;
switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
+ case llvm::APFloatBase::S_IEEEhalf:
+ os << "f16";
+ break;
+ case llvm::APFloatBase::S_BFloat:
+ os << "bf16";
+ break;
case llvm::APFloatBase::S_IEEEsingle:
os << "f";
break;
@@ -1277,17 +1283,19 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
// Print floating point attributes.
if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
- if (!isa<Float32Type, Float64Type>(fAttr.getType())) {
- return emitError(loc,
- "expected floating point attribute to be f32 or f64");
+ if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
+ fAttr.getType())) {
+ return emitError(
+ loc, "expected floating point attribute to be f16, bf16, f32 or f64");
}
printFloat(fAttr.getValue());
return success();
}
if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
- if (!isa<Float32Type, Float64Type>(dense.getElementType())) {
- return emitError(loc,
- "expected floating point attribute to be f32 or f64");
+ if (!isa<Float16Type, BFloat16Type, Float32Type, Float64Type>(
+ dense.getElementType())) {
+ return emitError(
+ loc, "expected floating point attribute to be f16, bf16, f32 or f64");
}
os << '{';
interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
@@ -1640,6 +1648,14 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
}
if (auto fType = dyn_cast<FloatType>(type)) {
switch (fType.getWidth()) {
+ case 16: {
+ if (llvm::isa<Float16Type>(type))
+ return (os << "_Float16"), success();
+ else if (llvm::isa<BFloat16Type>(type))
+ return (os << "__bf16"), success();
+ else
+ return emitError(loc, "cannot emit float type ") << type;
+ }
case 32:
return (os << "float"), success();
case 64:
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index ef0e71ee8673b7..b86690461dc269 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -15,36 +15,35 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
}
// -----
-
-func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
+func.func @arith_cast_f80(%arg0: f80) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
- %t = arith.fptosi %arg0 : bf16 to i32
+ %t = arith.fptosi %arg0 : f80 to i32
return %t: i32
}
// -----
-func.func @arith_cast_f16(%arg0: f16) -> i32 {
+func.func @arith_cast_f128(%arg0: f128) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
- %t = arith.fptosi %arg0 : f16 to i32
+ %t = arith.fptosi %arg0 : f128 to i32
return %t: i32
}
// -----
-func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
+func.func @arith_cast_to_f80(%arg0: i32) -> f80 {
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
- %t = arith.sitofp %arg0 : i32 to bf16
- return %t: bf16
+ %t = arith.sitofp %arg0 : i32 to f80
+ return %t: f80
}
// -----
-func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
+func.func @arith_cast_to_f128(%arg0: i32) -> f128 {
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
- %t = arith.sitofp %arg0 : i32 to f16
- return %t: f16
+ %t = arith.sitofp %arg0 : i32 to f128
+ return %t: f128
}
// -----
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 836d8aedefc1f0..dee9cc97a14493 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -46,9 +46,9 @@ memref.global "nested" constant @nested_global : memref<3x7xf32>
// -----
-func.func @unsupported_type_f16() {
+func.func @unsupported_type_f128() {
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
- %0 = memref.alloca() : memref<4xf16>
+ %0 = memref.alloca() : memref<4xf128>
return
}
diff --git a/mlir/test/Target/Cpp/const.mlir b/mlir/test/Target/Cpp/const.mlir
index 3658455d669438..d3656f830c48c3 100644
--- a/mlir/test/Target/Cpp/const.mlir
+++ b/mlir/test/Target/Cpp/const.mlir
@@ -11,6 +11,8 @@ func.func @emitc_constant() {
%c6 = "emitc.constant"(){value = 2 : index} : () -> index
%c7 = "emitc.constant"(){value = 2.0 : f32} : () -> f32
%f64 = "emitc.constant"(){value = 4.0 : f64} : () -> f64
+ %f16 = "emitc.constant"(){value = 2.0 : f16} : () -> f16
+ %bf16 = "emitc.constant"(){value = 4.0 : bf16} : () -> bf16
%c8 = "emitc.constant"(){value = dense<0> : tensor<i32>} : () -> tensor<i32>
%c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex>
%c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
@@ -26,6 +28,8 @@ func.func @emitc_constant() {
// CPP-DEFAULT-NEXT: size_t [[V6:[^ ]*]] = 2;
// CPP-DEFAULT-NEXT: float [[V7:[^ ]*]] = 2.000000000e+00f;
// CPP-DEFAULT-NEXT: double [[F64:[^ ]*]] = 4.00000000000000000e+00;
+// CPP-DEFAULT-NEXT: _Float16 [[F16:[^ ]*]] = 2.00000e+00f16;
+// CPP-DEFAULT-NEXT: __bf16 [[BF16:[^ ]*]] = 4.0000e+00bf16;
// CPP-DEFAULT-NEXT: Tensor<int32_t> [[V8:[^ ]*]] = {0};
// CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]] = {0, 1};
// CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
@@ -40,6 +44,8 @@ func.func @emitc_constant() {
// CPP-DECLTOP-NEXT: size_t [[V6:[^ ]*]];
// CPP-DECLTOP-NEXT: float [[V7:[^ ]*]];
// CPP-DECLTOP-NEXT: double [[F64:[^ ]*]];
+// CPP-DECLTOP-NEXT: _Float16 [[F16:[^ ]*]];
+// CPP-DECLTOP-NEXT: __bf16 [[BF16:[^ ]*]];
// CPP-DECLTOP-NEXT: Tensor<int32_t> [[V8:[^ ]*]];
// CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]];
// CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]];
@@ -52,6 +58,8 @@ func.func @emitc_constant() {
// CPP-DECLTOP-NEXT: [[V6]] = 2;
// CPP-DECLTOP-NEXT: [[V7]] = 2.000000000e+00f;
// CPP-DECLTOP-NEXT: [[F64]] = 4.00000000000000000e+00;
+// CPP-DECLTOP-NEXT: [[F16]] = 2.00000e+00f16;
+// CPP-DECLTOP-NEXT: [[BF16]] = 4.0000e+00bf16;
// CPP-DECLTOP-NEXT: [[V8]] = {0};
// CPP-DECLTOP-NEXT: [[V9]] = {0, 1};
// CPP-DECLTOP-NEXT: [[V10]] = {0.0e+00f, 1.000000000e+00f, 2.000000000e+00f, 3.000000000e+00f};
diff --git a/mlir/test/Target/Cpp/types.mlir b/mlir/test/Target/Cpp/types.mlir
index deda383b3b0a72..e7f935c7374382 100644
--- a/mlir/test/Target/Cpp/types.mlir
+++ b/mlir/test/Target/Cpp/types.mlir
@@ -22,6 +22,10 @@ func.func @ptr_types() {
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i32>]} : () -> ()
// CHECK-NEXT: f<int64_t*>();
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i64>]} : () -> ()
+ // CHECK-NEXT: f<_Float16*>();
+ emitc.call_opaque "f"() {template_args = [!emitc.ptr<f16>]} : () -> ()
+ // CHECK-NEXT: f<__bf16*>();
+ emitc.call_opaque "f"() {template_args = [!emitc.ptr<bf16>]} : () -> ()
// CHECK-NEXT: f<float*>();
emitc.call_opaque "f"() {template_args = [!emitc.ptr<f32>]} : () -> ()
// CHECK-NEXT: f<double*>();
More information about the Mlir-commits
mailing list