[Mlir-commits] [mlir] 657f60a - [mlir][vector] add support for printing f16 and bf16
Aart Bik
llvmlistbot at llvm.org
Fri Mar 3 08:58:34 PST 2023
Author: Aart Bik
Date: 2023-03-03T08:58:25-08:00
New Revision: 657f60a07b2d382dd8580dd8a6111ea5c6e2d889
URL: https://github.com/llvm/llvm-project/commit/657f60a07b2d382dd8580dd8a6111ea5c6e2d889
DIFF: https://github.com/llvm/llvm-project/commit/657f60a07b2d382dd8580dd8a6111ea5c6e2d889.diff
LOG: [mlir][vector] add support for printing f16 and bf16
Love it or hate it, the vector.print operation was the very
first operation that actually made "end-to-end" CHECK integration
testing possible for MLIR. This revision adds support for
the (until recently) less common but important floating-point
types f16 and bf16.
This will become useful for accelerator-specific testing (e.g., NVIDIA GPUs).
Reviewed By: wrengr
Differential Revision: https://reviews.llvm.org/D145207
Added:
mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
mlir/lib/ExecutionEngine/Float16bits.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 39b35fedcac68..17aa9a3c831c2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -34,6 +34,8 @@ class LLVMFuncOp;
/// of the libc).
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp,
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index e7798b2136af0..7b7e894421b40 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -469,6 +469,8 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printF16(uint16_t bits); // raw f16 bit pattern
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printBF16(uint16_t bits); // raw bf16 bit pattern
//===----------------------------------------------------------------------===//
// Small runtime support library for timing execution and printing GFLOPS
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d1b78bf626973..f705284845a2d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1466,16 +1466,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
PrintConversion conversion = PrintConversion::None;
VectorType vectorType = printType.dyn_cast<VectorType>();
Type eltType = vectorType ? vectorType.getElementType() : printType;
+ auto parent = printOp->getParentOfType<ModuleOp>();
Operation *printer;
if (eltType.isF32()) {
- printer =
- LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
+ printer = LLVM::lookupOrCreatePrintF32Fn(parent);
} else if (eltType.isF64()) {
- printer =
- LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
+ printer = LLVM::lookupOrCreatePrintF64Fn(parent);
+ } else if (eltType.isF16()) {
+ conversion = PrintConversion::Bitcast16; // bits!
+ printer = LLVM::lookupOrCreatePrintF16Fn(parent);
+ } else if (eltType.isBF16()) {
+ conversion = PrintConversion::Bitcast16; // bits!
+ printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
} else if (eltType.isIndex()) {
- printer =
- LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
+ printer = LLVM::lookupOrCreatePrintU64Fn(parent);
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1485,8 +1489,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
- printer = LLVM::lookupOrCreatePrintU64Fn(
- printOp->getParentOfType<ModuleOp>());
+ printer = LLVM::lookupOrCreatePrintU64Fn(parent);
} else {
return failure();
}
@@ -1499,8 +1502,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
- printer = LLVM::lookupOrCreatePrintI64Fn(
- printOp->getParentOfType<ModuleOp>());
+ printer = LLVM::lookupOrCreatePrintI64Fn(parent);
} else {
return failure();
}
@@ -1515,8 +1517,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
conversion);
emitCall(rewriter, printOp->getLoc(),
- LLVM::lookupOrCreatePrintNewlineFn(
- printOp->getParentOfType<ModuleOp>()));
+ LLVM::lookupOrCreatePrintNewlineFn(parent));
rewriter.eraseOp(printOp);
return success();
}
@@ -1526,7 +1527,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// clang-format off
None,
ZeroExt64,
- SignExt64
+ SignExt64,
+ Bitcast16
// clang-format on
};
@@ -1546,6 +1548,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
value = rewriter.create<arith::ExtSIOp>(
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
+ case PrintConversion::Bitcast16:
+ value = rewriter.create<LLVM::BitcastOp>(
+ loc, IntegerType::get(rewriter.getContext(), 16), value);
+ break;
case PrintConversion::None:
break;
}
@@ -1553,10 +1559,9 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return;
}
- emitCall(rewriter, loc,
- LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
- Operation *printComma =
- LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
+ auto parent = op->getParentOfType<ModuleOp>();
+ emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
+ Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);
if (rank <= 1) {
auto reducedType = vectorType.getElementType();
@@ -1570,9 +1575,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
- emitCall(
- rewriter, loc,
- LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
+ emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
return;
}
@@ -1587,8 +1590,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
- emitCall(rewriter, loc,
- LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
+ emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
}
// Helper to emit a call.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 237e576a1d9ef..aef3a5a87e9bf 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -26,6 +26,8 @@ using namespace mlir::LLVM;
/// part of the libc).
static constexpr llvm::StringRef kPrintI64 = "printI64";
static constexpr llvm::StringRef kPrintU64 = "printU64";
+static constexpr llvm::StringRef kPrintF16 = "printF16";
+static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintStr = "puts";
@@ -67,6 +69,18 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintF16,
+ IntegerType::get(moduleOp->getContext(), 16), // i16 carrying raw f16 bits
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintBF16,
+ IntegerType::get(moduleOp->getContext(), 16), // i16 carrying raw bf16 bits
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
diff --git a/mlir/lib/ExecutionEngine/Float16bits.cpp b/mlir/lib/ExecutionEngine/Float16bits.cpp
index 189286b9f13a8..38a05fe86bbdd 100644
--- a/mlir/lib/ExecutionEngine/Float16bits.cpp
+++ b/mlir/lib/ExecutionEngine/Float16bits.cpp
@@ -192,4 +192,16 @@ extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) {
return __truncsfbf2(static_cast<float>(d));
}
+// CRunner print entry points built on the local f16/bf16 types; `bits` holds the raw 16-bit pattern, which is copied into the type and streamed to std::cout.
+extern "C" void printF16(uint16_t bits) {
+ f16 f;
+ std::memcpy(&f, &bits, sizeof(f16));
+ std::cout << f;
+}
+extern "C" void printBF16(uint16_t bits) {
+ bf16 f;
+ std::memcpy(&f, &bits, sizeof(bf16));
+ std::cout << f;
+}
+
#endif // MLIR_FLOAT16_DEFINE_FUNCTIONS
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir
new file mode 100644
index 0000000000000..eeee363d246f3
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+// Test vector.print on vectors of the floating-point types f64, f32, f16, bf16.
+// Note that 1.1 is not exactly representable in f16/bf16, so those types
+// print it as 1.09961 and 1.10156, respectively.
+func.func @entry() {
+ %0 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf64>
+ vector.print %0 : vector<5xf64>
+ // CHECK: ( -1000, -1.1, 0, 1.1, 1000 )
+
+ %1 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf32>
+ vector.print %1 : vector<5xf32>
+ // CHECK: ( -1000, -1.1, 0, 1.1, 1000 )
+
+ %2 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf16>
+ vector.print %2 : vector<5xf16>
+ // CHECK: ( -1000, -1.09961, 0, 1.09961, 1000 )
+
+ %3 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xbf16>
+ vector.print %3 : vector<5xbf16>
+ // CHECK: ( -1000, -1.10156, 0, 1.10156, 1000 )
+
+ return
+}
More information about the Mlir-commits
mailing list