[Mlir-commits] [mlir] [mlir][VectorOps] Support string literals in `vector.print` (PR #68695)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Oct 19 10:40:12 PDT 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/68695
>From 0a7043b1532c1dba476d795c9ba1b46015bb9eb8 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 10 Oct 2023 11:38:12 +0000
Subject: [PATCH 1/5] [mlir][VectorOps] Support string literals in
`vector.print`
Printing strings within integration tests is currently quite annoyingly
verbose, and can't be tucked into shared helpers as the types depend on
the length of the string:
```
llvm.mlir.global internal constant @hello_world("Hello, World!\0")
func.func @entry() {
%0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
: (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
return
}
```
So this patch adds a simple extension to `vector.print` to simplify
this:
```
func.func @entry() {
// Print a vector of characters ;)
vector.print str "Hello, World!"
return
}
```
Most of the logic for this is now shared with `cf.assert`, which already
does something similar.
---
.../Conversion/LLVMCommon/PrintCallHelper.h | 36 ++++++++++
.../mlir/Dialect/Vector/IR/VectorOps.td | 37 +++++++++--
.../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 49 +-------------
mlir/lib/Conversion/LLVMCommon/CMakeLists.txt | 1 +
.../Conversion/LLVMCommon/PrintCallHelper.cpp | 66 +++++++++++++++++++
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 6 +-
.../VectorToLLVM/vector-to-llvm.mlir | 14 ++++
.../Dialect/Vector/CPU/test-hello-world.mlir | 10 +++
8 files changed, 168 insertions(+), 51 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
create mode 100644 mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
new file mode 100644
index 000000000000000..7e26858589f2756
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -0,0 +1,36 @@
+
+//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_LLVMCOMMON_PRINTCALLHELPER_H_
+#define MLIR_CONVERSION_LLVMCOMMON_PRINTCALLHELPER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+class LLVMTypeConverter;
+
+namespace LLVM {
+
+/// Generate IR that prints the given string to stdout.
+void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+ StringRef symbolName, StringRef string,
+ const LLVMTypeConverter &typeConverter);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_LLVMCOMMON_PRINTCALLHELPER_H_
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..2b60055ca9db94b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/BuiltinAttributes.td"
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
@@ -2476,12 +2477,18 @@ def Vector_TransposeOp :
}
def Vector_PrintOp :
- Vector_Op<"print", []>,
+ Vector_Op<"print", [
+ PredOpTrait<
+ "`source` or `punctuation` are not set printing strings",
+ CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
+ >,
+ ]>,
Arguments<(ins Optional<Type<Or<[
AnyVectorOfAnyRank.predicate,
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
- "::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
+ "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
+ OptionalAttr<Builtin_StringAttr>:$stringLiteral)
> {
let summary = "print operation (for testing and debugging)";
let description = [{
@@ -2520,6 +2527,13 @@ def Vector_PrintOp :
```mlir
vector.print punctuation <newline>
```
+
+ Additionally, to aid with debugging and testing `vector.print` can also
+ print constant strings:
+
+ ```mlir
+ vector.print str "Hello, World!"
+ ```
}];
let extraClassDeclaration = [{
Type getPrintType() {
@@ -2528,11 +2542,26 @@ def Vector_PrintOp :
}];
let builders = [
OpBuilder<(ins "PrintPunctuation":$punctuation), [{
- build($_builder, $_state, {}, punctuation);
+ build($_builder, $_state, {}, punctuation, {});
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$source), [{
+ build($_builder, $_state, source, PrintPunctuation::NewLine);
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
+ build($_builder, $_state, source, punctuation, {});
+ }]>,
+ OpBuilder<(ins "::llvm::StringRef":$string), [{
+ build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
}]>,
];
- let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
+ let assemblyFormat = [{
+ ($source^ `:` type($source))?
+ oilist(
+ `str` $stringLiteral
+ | `punctuation` $punctuation)
+ attr-dict
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index a4f146bbe475cc6..6b7647b038f1d94 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
#define PASS_NAME "convert-cf-to-llvm"
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
- std::string prefix = "assert_msg_";
- int counter = 0;
- while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
- ++counter;
- return prefix + std::to_string(counter);
-}
-
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
- StringRef msg,
- const LLVMTypeConverter &typeConverter) {
- auto ip = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(moduleOp.getBody());
- MLIRContext *ctx = builder.getContext();
-
- // Create a zero-terminated byte representation and allocate global symbol.
- SmallVector<uint8_t> elementVals;
- elementVals.append(msg.begin(), msg.end());
- elementVals.push_back(0);
- auto dataAttrType = RankedTensorType::get(
- {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
- auto dataAttr =
- DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
- auto arrayTy =
- LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
- std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
- auto globalOp = builder.create<LLVM::GlobalOp>(
- loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
- dataAttr);
-
- // Emit call to `printStr` in runtime library.
- builder.restoreInsertionPoint(ip);
- auto msgAddr = builder.create<LLVM::AddressOfOp>(
- loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
- SmallVector<LLVM::GEPArg> indices(1, 0);
- Value gep = builder.create<LLVM::GEPOp>(
- loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
- indices);
- Operation *printer = LLVM::lookupOrCreatePrintStrFn(
- moduleOp, typeConverter.useOpaquePointers());
- builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
- gep);
-}
-
namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
- createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
+ LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+ *getTypeConverter());
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 091cd539f0ae014..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
LoweringOptions.cpp
MemRefBuilder.cpp
Pattern.cpp
+ PrintCallHelper.cpp
StructBuilder.cpp
TypeConverter.cpp
VectorPattern.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
new file mode 100644
index 000000000000000..487abb435d10ad7
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -0,0 +1,66 @@
+
+//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace llvm;
+
+static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
+ StringRef symbolName) {
+ int counter = 0; // Not `static`: a process-wide counter is racy/nondeterministic.
+ std::string uniqueName = std::string(symbolName);
+ while (moduleOp.lookupSymbol(uniqueName)) { // name, name_0, name_1, ...
+ uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
+ }
+ return uniqueName;
+}
+
+void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
+ ModuleOp moduleOp, StringRef symbolName,
+ StringRef string,
+ const LLVMTypeConverter &typeConverter) {
+ auto ip = builder.saveInsertionPoint();
+ builder.setInsertionPointToStart(moduleOp.getBody());
+ MLIRContext *ctx = builder.getContext();
+
+ // Create a zero-terminated byte representation and allocate global symbol.
+ SmallVector<uint8_t> elementVals;
+ elementVals.append(string.begin(), string.end());
+ elementVals.push_back(0);
+ auto dataAttrType = RankedTensorType::get(
+ {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+ auto dataAttr =
+ DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+ auto arrayTy =
+ LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+ auto globalOp = builder.create<LLVM::GlobalOp>(
+ loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+ ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+
+ // Emit call to `printStr` in runtime library.
+ builder.restoreInsertionPoint(ip);
+ auto msgAddr = builder.create<LLVM::AddressOfOp>(
+ loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
+ SmallVector<LLVM::GEPArg> indices(1, 0);
+ Value gep = builder.create<LLVM::GEPOp>(
+ loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+ indices);
+ Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+ moduleOp, typeConverter.useOpaquePointers());
+ builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+ gep);
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..4af58653c8227ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}
auto punct = printOp.getPunctuation();
- if (punct != PrintPunctuation::NoPunctuation) {
+ if (auto stringLiteral = printOp.getStringLiteral()) {
+ LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+ *stringLiteral, *getTypeConverter());
+ } else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
case PrintPunctuation::Close:
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9aa4d735681f576..65b3a78e295f0c4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
// -----
+// CHECK-LABEL: module {
+// CHECK: llvm.func @puts(!llvm.ptr)
+// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8>
+// CHECK: @vector_print_string
+// CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
+// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+// CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> ()
+func.func @vector_print_string() {
+ vector.print str "Hello, World!"
+ return
+}
+
+// -----
+
func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
%0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
return %0 : vector<2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
new file mode 100644
index 000000000000000..c4076e65151ac72
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+ // CHECK: Hello, World!
+ vector.print str "Hello, World!"
+ return
+}
>From c1594360dc1b4dc03e94e9b2bcbd49893fa2a146 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 11 Oct 2023 15:01:35 +0000
Subject: [PATCH 2/5] Fixups
---
.../mlir/Conversion/LLVMCommon/PrintCallHelper.h | 8 +-------
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
.../Conversion/LLVMCommon/PrintCallHelper.cpp | 4 +---
mlir/test/Dialect/Vector/invalid.mlir | 16 ++++++++++++++++
...test-hello-world.mlir => test-print-str.mlir} | 4 ++++
5 files changed, 23 insertions(+), 11 deletions(-)
rename mlir/test/Integration/Dialect/Vector/CPU/{test-hello-world.mlir => test-print-str.mlir} (71%)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 7e26858589f2756..457cd98ca3dc2c8 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -1,5 +1,4 @@
-
-//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===//
+//===- PrintCallHelper.h - Helper to emit runtime print calls ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -15,12 +14,7 @@
namespace mlir {
-class Location;
-class ModuleOp;
class OpBuilder;
-class Operation;
-class Type;
-class ValueRange;
class LLVMTypeConverter;
namespace LLVM {
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2b60055ca9db94b..f946d124fb2fa5e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2479,7 +2479,7 @@ def Vector_TransposeOp :
def Vector_PrintOp :
Vector_Op<"print", [
PredOpTrait<
- "`source` or `punctuation` are not set printing strings",
+ "`source` or `punctuation` are not set when printing strings",
CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
>,
]>,
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 487abb435d10ad7..40b9382452fbb45 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -1,5 +1,4 @@
-
-//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===//
+//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
-#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5967a8d69bbfcc0..1664ddde7e48d76 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1016,6 +1016,22 @@ func.func private @print_needs_vector(%arg0: tensor<8xf32>) {
// -----
+func.func @cannot_print_string_with_punctuation_set() {
+ // expected-error at +1 {{`source` or `punctuation` are not set when printing strings}}
+ vector.print str "Whoops!" punctuation <comma>
+ return
+}
+
+// -----
+
+func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
+ // expected-error at +1 {{`source` or `punctuation` are not set when printing strings}}
+ vector.print %vec: vector<[4]xf32> str "Yay!"
+ return
+}
+
+// -----
+
func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
similarity index 71%
rename from mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
rename to mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
index c4076e65151ac72..4a11987121b3308 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
@@ -3,8 +3,12 @@
// RUN: -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s
+/// This tests printing (multiple) string literals works.
+
func.func @entry() {
// CHECK: Hello, World!
vector.print str "Hello, World!"
+ // CHECK-NEXT: Bye!
+ vector.print str "Bye!"
return
}
>From cfd10aaa9ee8ed95b72ef65a42205482b459978f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 07:59:32 +0000
Subject: [PATCH 3/5] Use `printCStr()` rather than `puts()`
PrintCallHelper is the only user of this, so we can safely switch.
---
mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h | 3 ++-
mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 7 +++++--
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 2 +-
.../Integration/Dialect/Vector/CPU/test-print-str.mlir | 2 +-
4 files changed, 9 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 457cd98ca3dc2c8..ca30553e5de3806 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -22,7 +22,8 @@ namespace LLVM {
/// Generate IR that prints the given string to stdout.
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
StringRef symbolName, StringRef string,
- const LLVMTypeConverter &typeConverter);
+ const LLVMTypeConverter &typeConverter,
+ bool addNewline = true);
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 40b9382452fbb45..03dbc65240a1a54 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -30,7 +30,8 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
ModuleOp moduleOp, StringRef symbolName,
StringRef string,
- const LLVMTypeConverter &typeConverter) {
+ const LLVMTypeConverter &typeConverter,
+ bool addNewline) {
auto ip = builder.saveInsertionPoint();
builder.setInsertionPointToStart(moduleOp.getBody());
MLIRContext *ctx = builder.getContext();
@@ -38,7 +39,9 @@ void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
// Create a zero-terminated byte representation and allocate global symbol.
SmallVector<uint8_t> elementVals;
elementVals.append(string.begin(), string.end());
- elementVals.push_back(0);
+ if (addNewline)
+ elementVals.push_back('\n');
+ elementVals.push_back('\0');
auto dataAttrType = RankedTensorType::get(
{static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
auto dataAttr =
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index aef3a5a87e9bfe3..55a644bca31733d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,7 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
-static constexpr llvm::StringRef kPrintStr = "puts";
+static constexpr llvm::StringRef kPrintStr = "printCString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
index 4a11987121b3308..78d6609ccaf9a9d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -test-lower-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils | \
// RUN: FileCheck %s
/// This tests printing (multiple) string literals works.
>From 6baba389c2e50c3cf17791df079d2199c4db72fe Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 08:19:50 +0000
Subject: [PATCH 4/5] Fixup test checks and naming
---
mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h | 4 ++--
mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 2 +-
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 8 ++++----
mlir/test/Conversion/ControlFlowToLLVM/assert.mlir | 4 ++--
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 8 ++++----
mlir/test/Integration/Dialect/ControlFlow/assert.mlir | 2 +-
6 files changed, 14 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 17aa9a3c831c2e0..4a86edfdf8e1a02 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -38,8 +38,8 @@ LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp,
- bool opaquePointers);
+LLVM::LLVMFuncOp lookupOrCreatePrintCStringFn(ModuleOp moduleOp,
+ bool opaquePointers);
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 03dbc65240a1a54..4017fd9ad8c017a 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -60,7 +60,7 @@ void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
Value gep = builder.create<LLVM::GEPOp>(
loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
indices);
- Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+ Operation *printer = LLVM::lookupOrCreatePrintCStringFn(
moduleOp, typeConverter.useOpaquePointers());
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
gep);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 55a644bca31733d..228d85d96cd4fc5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,7 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
-static constexpr llvm::StringRef kPrintStr = "printCString";
+static constexpr llvm::StringRef kPrintCString = "printCString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
@@ -107,9 +107,9 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context,
return getCharPtr(context, opaquePointers);
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp,
- bool opaquePointers) {
- return lookupOrCreateFn(moduleOp, kPrintStr,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCStringFn(ModuleOp moduleOp,
+ bool opaquePointers) {
+ return lookupOrCreateFn(moduleOp, kPrintCString,
getCharPtr(moduleOp->getContext(), opaquePointers),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
index dc5ba0680acb2e1..1642a6fb5bb9bdb 100644
--- a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
+++ b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
@@ -10,7 +10,7 @@ func.func @main() {
return
}
-// CHECK: llvm.func @puts(!llvm.ptr)
+// CHECK: llvm.func @printCString(!llvm.ptr)
// CHECK-LABEL: @main
// CHECK: llvm.cond_br %{{.*}}, ^{{.*}}, ^[[FALSE_BRANCH:[[:alnum:]]+]]
@@ -18,4 +18,4 @@ func.func @main() {
// CHECK: ^[[FALSE_BRANCH]]:
// CHECK: %[[ADDRESS_OF:.*]] = llvm.mlir.addressof @{{.*}} : !llvm.ptr{{$}}
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDRESS_OF]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{[0-9]+}} x i8>
-// CHECK: llvm.call @puts(%[[GEP]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.call @printCString(%[[GEP]]) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 65b3a78e295f0c4..ef7260c5bb57ab7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1069,12 +1069,12 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
// -----
// CHECK-LABEL: module {
-// CHECK: llvm.func @puts(!llvm.ptr)
-// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8>
+// CHECK: llvm.func @printCString(!llvm.ptr)
+// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]]({{.*}})
// CHECK: @vector_print_string
// CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
-// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
-// CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> ()
+// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr
+// CHECK-NEXT: llvm.call @printCString(%[[STR_PTR]]) : (!llvm.ptr) -> ()
func.func @vector_print_string() {
vector.print str "Hello, World!"
return
diff --git a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
index 42130250daf1b6a..63ce092818627d9 100644
--- a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
+++ b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -test-cf-assert \
// RUN: -convert-func-to-llvm | \
-// RUN: mlir-cpu-runner -e main -entry-point-result=void | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils | \
// RUN: FileCheck %s
func.func @main() {
>From 92b2ce2c03642761da9ed219c16b789c12de51dd Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 17:36:12 +0000
Subject: [PATCH 5/5] Add printString to CRunnerUtils and use that instead
I've also weakly defined this in RunnerUtils so linking either or both
gives you printString; this avoids the need to update a bunch of tests
that use cf.assert.
---
mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h | 4 ++--
mlir/include/mlir/ExecutionEngine/CRunnerUtils.h | 1 +
mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 2 +-
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 8 ++++----
mlir/lib/ExecutionEngine/CRunnerUtils.cpp | 1 +
mlir/lib/ExecutionEngine/RunnerUtils.cpp | 6 +++++-
mlir/test/Conversion/ControlFlowToLLVM/assert.mlir | 4 ++--
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 4 ++--
8 files changed, 18 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 4a86edfdf8e1a02..c0806b64d25f3a6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -38,8 +38,8 @@ LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCStringFn(ModuleOp moduleOp,
- bool opaquePointers);
+LLVM::LLVMFuncOp lookupOrCreatePrintStringFn(ModuleOp moduleOp,
+ bool opaquePointers);
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index e8f429463cb0b9b..76b04145b482e4a 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -465,6 +465,7 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 4017fd9ad8c017a..8fecd4ca6c298d6 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -60,7 +60,7 @@ void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
Value gep = builder.create<LLVM::GEPOp>(
loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
indices);
- Operation *printer = LLVM::lookupOrCreatePrintCStringFn(
+ Operation *printer = LLVM::lookupOrCreatePrintStringFn(
moduleOp, typeConverter.useOpaquePointers());
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
gep);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 228d85d96cd4fc5..83540c83df3d1a9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -30,7 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
-static constexpr llvm::StringRef kPrintCString = "printCString";
+static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
@@ -107,9 +107,9 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context,
return getCharPtr(context, opaquePointers);
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCStringFn(ModuleOp moduleOp,
- bool opaquePointers) {
- return lookupOrCreateFn(moduleOp, kPrintCString,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(ModuleOp moduleOp,
+ bool opaquePointers) {
+ return lookupOrCreateFn(moduleOp, kPrintString,
getCharPtr(moduleOp->getContext(), opaquePointers),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
index 7f800e8ea96ff4f..fcfd89ecaf488fb 100644
--- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
@@ -51,6 +51,7 @@ extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); }
extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); }
extern "C" void printF32(float f) { fprintf(stdout, "%g", f); }
extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); }
+extern "C" void printString(char const *s) { fputs(s, stdout); }
extern "C" void printOpen() { fputs("( ", stdout); }
extern "C" void printClose() { fputs(" )", stdout); }
extern "C" void printComma() { fputs(", ", stdout); }
diff --git a/mlir/lib/ExecutionEngine/RunnerUtils.cpp b/mlir/lib/ExecutionEngine/RunnerUtils.cpp
index ccf5309487637e7..b0cc31ec926545f 100644
--- a/mlir/lib/ExecutionEngine/RunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/RunnerUtils.cpp
@@ -158,7 +158,11 @@ extern "C" void printMemrefC64(int64_t rank, void *ptr) {
_mlir_ciface_printMemrefC64(&descriptor);
}
-extern "C" void printCString(char *str) { printf("%s", str); }
+extern "C" void printCString(char *str) { fputs(str, stdout); }
+// Weakly defined so both RunnerUtils and CRunnerUtils can provide printString.
+extern "C" void __attribute__((weak)) printString(char const *str) {
+ printCString(const_cast<char *>(str));
+}
extern "C" void _mlir_ciface_printMemref0dF32(StridedMemRefType<float, 0> *M) {
impl::printMemRef(*M);
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
index 1642a6fb5bb9bdb..a432cdfee2e691b 100644
--- a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
+++ b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
@@ -10,7 +10,7 @@ func.func @main() {
return
}
-// CHECK: llvm.func @printCString(!llvm.ptr)
+// CHECK: llvm.func @printString(!llvm.ptr)
// CHECK-LABEL: @main
// CHECK: llvm.cond_br %{{.*}}, ^{{.*}}, ^[[FALSE_BRANCH:[[:alnum:]]+]]
@@ -18,4 +18,4 @@ func.func @main() {
// CHECK: ^[[FALSE_BRANCH]]:
// CHECK: %[[ADDRESS_OF:.*]] = llvm.mlir.addressof @{{.*}} : !llvm.ptr{{$}}
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDRESS_OF]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{[0-9]+}} x i8>
-// CHECK: llvm.call @printCString(%[[GEP]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.call @printString(%[[GEP]]) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ef7260c5bb57ab7..05733214bc3ae80 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1069,12 +1069,12 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
// -----
// CHECK-LABEL: module {
-// CHECK: llvm.func @printCString(!llvm.ptr)
+// CHECK: llvm.func @printString(!llvm.ptr)
// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]]({{.*}})
// CHECK: @vector_print_string
// CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr
-// CHECK-NEXT: llvm.call @printCString(%[[STR_PTR]]) : (!llvm.ptr) -> ()
+// CHECK-NEXT: llvm.call @printString(%[[STR_PTR]]) : (!llvm.ptr) -> ()
func.func @vector_print_string() {
vector.print str "Hello, World!"
return
More information about the Mlir-commits
mailing list