[Mlir-commits] [mlir] 657f60a - [mlir][vector] add support for printing f16 and bf16

Aart Bik llvmlistbot at llvm.org
Fri Mar 3 08:58:34 PST 2023


Author: Aart Bik
Date: 2023-03-03T08:58:25-08:00
New Revision: 657f60a07b2d382dd8580dd8a6111ea5c6e2d889

URL: https://github.com/llvm/llvm-project/commit/657f60a07b2d382dd8580dd8a6111ea5c6e2d889
DIFF: https://github.com/llvm/llvm-project/commit/657f60a07b2d382dd8580dd8a6111ea5c6e2d889.diff

LOG: [mlir][vector] add support for printing f16 and bf16

Love it or hate it, the vector.print operation was the very
first operation that actually made "end-to-end" CHECK integration
testing possible for MLIR. This revision adds support for the
floating-point types f16 and bf16, which were less common until
recently but have become important.

This will become useful for accelerator-specific testing (e.g., on NVIDIA GPUs).

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D145207

Added: 
    mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
    mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
    mlir/lib/ExecutionEngine/Float16bits.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 39b35fedcac68..17aa9a3c831c2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -34,6 +34,8 @@ class LLVMFuncOp;
 /// of the libc).
 LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);  // takes raw i16 bits
+LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp); // takes raw i16 bits
 LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp,

diff  --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index e7798b2136af0..7b7e894421b40 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -469,6 +469,8 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline();
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printF16(uint16_t bits);  // raw f16 bits
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printBF16(uint16_t bits); // raw bf16 bits
 
 //===----------------------------------------------------------------------===//
 // Small runtime support library for timing execution and printing GFLOPS

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d1b78bf626973..f705284845a2d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1466,16 +1466,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     PrintConversion conversion = PrintConversion::None;
     VectorType vectorType = printType.dyn_cast<VectorType>();
     Type eltType = vectorType ? vectorType.getElementType() : printType;
+    auto parent = printOp->getParentOfType<ModuleOp>();
     Operation *printer;
     if (eltType.isF32()) {
-      printer =
-          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
+      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
     } else if (eltType.isF64()) {
-      printer =
-          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
+      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
+    } else if (eltType.isF16()) {
+      conversion = PrintConversion::Bitcast16; // bits!
+      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
+    } else if (eltType.isBF16()) {
+      conversion = PrintConversion::Bitcast16; // bits!
+      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
     } else if (eltType.isIndex()) {
-      printer =
-          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
+      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
       // Integers need a zero or sign extension on the operand
       // (depending on the source type) as well as a signed or
@@ -1485,8 +1489,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
         if (width <= 64) {
           if (width < 64)
             conversion = PrintConversion::ZeroExt64;
-          printer = LLVM::lookupOrCreatePrintU64Fn(
-              printOp->getParentOfType<ModuleOp>());
+          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
         } else {
           return failure();
         }
@@ -1499,8 +1502,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
             conversion = PrintConversion::ZeroExt64;
           else if (width < 64)
             conversion = PrintConversion::SignExt64;
-          printer = LLVM::lookupOrCreatePrintI64Fn(
-              printOp->getParentOfType<ModuleOp>());
+          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
         } else {
           return failure();
         }
@@ -1515,8 +1517,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
               conversion);
     emitCall(rewriter, printOp->getLoc(),
-             LLVM::lookupOrCreatePrintNewlineFn(
-                 printOp->getParentOfType<ModuleOp>()));
+             LLVM::lookupOrCreatePrintNewlineFn(parent));
     rewriter.eraseOp(printOp);
     return success();
   }
@@ -1526,7 +1527,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     // clang-format off
     None,
     ZeroExt64,
-    SignExt64
+    SignExt64,
+    Bitcast16
     // clang-format on
   };
 
@@ -1546,6 +1548,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
         value = rewriter.create<arith::ExtSIOp>(
             loc, IntegerType::get(rewriter.getContext(), 64), value);
         break;
+      case PrintConversion::Bitcast16:
+        value = rewriter.create<LLVM::BitcastOp>(
+            loc, IntegerType::get(rewriter.getContext(), 16), value);
+        break;
       case PrintConversion::None:
         break;
       }
@@ -1553,10 +1559,9 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
       return;
     }
 
-    emitCall(rewriter, loc,
-             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
-    Operation *printComma =
-        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
+    auto parent = op->getParentOfType<ModuleOp>();
+    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
+    Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);
 
     if (rank <= 1) {
       auto reducedType = vectorType.getElementType();
@@ -1570,9 +1575,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
         if (d != dim - 1)
           emitCall(rewriter, loc, printComma);
       }
-      emitCall(
-          rewriter, loc,
-          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
+      emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
       return;
     }
 
@@ -1587,8 +1590,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
       if (d != dim - 1)
         emitCall(rewriter, loc, printComma);
     }
-    emitCall(rewriter, loc,
-             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
+    emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
   }
 
   // Helper to emit a call.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 237e576a1d9ef..aef3a5a87e9bf 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -26,6 +26,8 @@ using namespace mlir::LLVM;
 /// part of the libc).
 static constexpr llvm::StringRef kPrintI64 = "printI64";
 static constexpr llvm::StringRef kPrintU64 = "printU64";
+static constexpr llvm::StringRef kPrintF16 = "printF16";   // arg: raw f16 bits
+static constexpr llvm::StringRef kPrintBF16 = "printBF16"; // arg: raw bf16 bits
 static constexpr llvm::StringRef kPrintF32 = "printF32";
 static constexpr llvm::StringRef kPrintF64 = "printF64";
 static constexpr llvm::StringRef kPrintStr = "puts";
@@ -67,6 +69,18 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintF16, // void printF16(i16)
+                          IntegerType::get(moduleOp->getContext(), 16), // raw f16 bits
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(moduleOp, kPrintBF16, // void printBF16(i16)
+                          IntegerType::get(moduleOp->getContext(), 16), // raw bf16 bits
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF32,
                           Float32Type::get(moduleOp->getContext()),

diff  --git a/mlir/lib/ExecutionEngine/Float16bits.cpp b/mlir/lib/ExecutionEngine/Float16bits.cpp
index 189286b9f13a8..38a05fe86bbdd 100644
--- a/mlir/lib/ExecutionEngine/Float16bits.cpp
+++ b/mlir/lib/ExecutionEngine/Float16bits.cpp
@@ -192,4 +192,16 @@ extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) {
   return __truncsfbf2(static_cast<float>(d));
 }
 
+// C-runner print helpers: each takes a 16-bit float as its raw bit pattern.
+extern "C" void printF16(uint16_t bits) {
+  f16 f;
+  std::memcpy(&f, &bits, sizeof(f16)); // reinterpret the bits as f16
+  std::cout << f;                      // no trailing newline
+}
+extern "C" void printBF16(uint16_t bits) { // bits: raw bf16 pattern
+  bf16 f;
+  std::memcpy(&f, &bits, sizeof(bf16)); // reinterpret the bits as bf16
+  std::cout << f;                       // no trailing newline
+}
+
 #endif // MLIR_FLOAT16_DEFINE_FUNCTIONS

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir
new file mode 100644
index 0000000000000..eeee363d246f3
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+//
+// Test vector.print on vectors of f64, f32, f16, and bf16 elements.
+//
+func.func @entry() {
+  %0 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf64>
+  vector.print %0 : vector<5xf64>
+  // CHECK: ( -1000, -1.1, 0, 1.1, 1000 )
+
+  %1 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf32>
+  vector.print %1 : vector<5xf32>
+  // CHECK: ( -1000, -1.1, 0, 1.1, 1000 )
+
+  %2 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf16>
+  vector.print %2 : vector<5xf16>
+  // CHECK: ( -1000, -1.09961, 0, 1.09961, 1000 )
+
+  %3 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xbf16>
+  vector.print %3 : vector<5xbf16>
+  // CHECK: ( -1000, -1.10156, 0, 1.10156, 1000 )
+
+  return
+}


        


More information about the Mlir-commits mailing list