[Mlir-commits] [mlir] [mlir][test] Update tests to use vector.print str (NFC) (PR #68973)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Oct 13 03:42:46 PDT 2023


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/68973

This cuts down on a fair amount of boilerplate.

Depends on: #68695

>From 1aac1e383a003dfa9c922cf1065625251a107f48 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 10 Oct 2023 11:38:12 +0000
Subject: [PATCH 1/3] [mlir][VectorOps] Support string literals in
 `vector.print`

Printing strings within integration tests is currently quite annoyingly
verbose, and can't be tucked into shared helpers as the types depend on
the length of the string:

```
llvm.mlir.global internal constant @hello_world("Hello, World!\0")

func.func @entry() {
  %0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>>
  %1 = llvm.mlir.constant(0 : index) : i64
  %2 = llvm.getelementptr %0[%1, %1]
    : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
  return
}
```

So this patch adds a simple extension to `vector.print` to simplify
this:
```
func.func @entry() {
   // Print a vector of characters ;)
   vector.print str "Hello, World!"
   return
}
```

Most of the logic for this is now shared with `cf.assert`, which already
does something similar.
---
 .../Conversion/LLVMCommon/PrintCallHelper.h   | 36 ++++++++++
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 37 +++++++++--
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   | 49 +-------------
 mlir/lib/Conversion/LLVMCommon/CMakeLists.txt |  1 +
 .../Conversion/LLVMCommon/PrintCallHelper.cpp | 66 +++++++++++++++++++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  6 +-
 .../VectorToLLVM/vector-to-llvm.mlir          | 14 ++++
 .../Dialect/Vector/CPU/test-hello-world.mlir  | 10 +++
 8 files changed, 168 insertions(+), 51 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
 create mode 100644 mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
new file mode 100644
index 000000000000000..7e26858589f2756
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -0,0 +1,36 @@
+
+//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+class LLVMTypeConverter;
+
+namespace LLVM {
+/// Generate IR that prints `string`, followed by a newline, to stdout.
+/// The bytes go in a private global named `symbolName` (suffixed if taken).
+void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+                        StringRef symbolName, StringRef string,
+                        const LLVMTypeConverter &typeConverter);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..2b60055ca9db94b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/BuiltinAttributes.td"
 
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
@@ -2476,12 +2477,18 @@ def Vector_TransposeOp :
 }
 
 def Vector_PrintOp :
-  Vector_Op<"print", []>,
+  Vector_Op<"print", [
+    PredOpTrait<
+      "`source` or `punctuation` are not set printing strings",
+      CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
+    >,
+  ]>,
   Arguments<(ins Optional<Type<Or<[
     AnyVectorOfAnyRank.predicate,
     AnyInteger.predicate, Index.predicate, AnyFloat.predicate
   ]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
-                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
+                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
+                OptionalAttr<Builtin_StringAttr>:$stringLiteral)
   > {
   let summary = "print operation (for testing and debugging)";
   let description = [{
@@ -2520,6 +2527,13 @@ def Vector_PrintOp :
     ```mlir
     vector.print punctuation <newline>
     ```
+
+    Additionally, to aid with debugging and testing, `vector.print` can also
+    print constant strings:
+
+    ```mlir
+    vector.print str "Hello, World!"
+    ```
   }];
   let extraClassDeclaration = [{
     Type getPrintType() {
@@ -2528,11 +2542,26 @@ def Vector_PrintOp :
   }];
   let builders = [
     OpBuilder<(ins "PrintPunctuation":$punctuation), [{
-      build($_builder, $_state, {}, punctuation);
+      build($_builder, $_state, {}, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source), [{
+      build($_builder, $_state, source, PrintPunctuation::NewLine);
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
+      build($_builder, $_state, source, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::llvm::StringRef":$string), [{
+      build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
     }]>,
   ];
 
-  let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
+  let assemblyFormat = [{
+      ($source^ `:` type($source))?
+        oilist(
+            `str` $stringLiteral
+          | `punctuation` $punctuation)
+        attr-dict
+    }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index a4f146bbe475cc6..6b7647b038f1d94 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
 
 #define PASS_NAME "convert-cf-to-llvm"
 
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
-  std::string prefix = "assert_msg_";
-  int counter = 0;
-  while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
-    ++counter;
-  return prefix + std::to_string(counter);
-}
-
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
-                           StringRef msg,
-                           const LLVMTypeConverter &typeConverter) {
-  auto ip = builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(moduleOp.getBody());
-  MLIRContext *ctx = builder.getContext();
-
-  // Create a zero-terminated byte representation and allocate global symbol.
-  SmallVector<uint8_t> elementVals;
-  elementVals.append(msg.begin(), msg.end());
-  elementVals.push_back(0);
-  auto dataAttrType = RankedTensorType::get(
-      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
-  auto dataAttr =
-      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
-  auto arrayTy =
-      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
-  std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
-  auto globalOp = builder.create<LLVM::GlobalOp>(
-      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
-      dataAttr);
-
-  // Emit call to `printStr` in runtime library.
-  builder.restoreInsertionPoint(ip);
-  auto msgAddr = builder.create<LLVM::AddressOfOp>(
-      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
-  SmallVector<LLVM::GEPArg> indices(1, 0);
-  Value gep = builder.create<LLVM::GEPOp>(
-      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
-      indices);
-  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
-      moduleOp, typeConverter.useOpaquePointers());
-  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
-                               gep);
-}
-
 namespace {
 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
 /// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
+    LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+                             *getTypeConverter());
     if (abortOnFailedAssert) {
       // Insert the `abort` declaration if necessary.
       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 091cd539f0ae014..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   LoweringOptions.cpp
   MemRefBuilder.cpp
   Pattern.cpp
+  PrintCallHelper.cpp
   StructBuilder.cpp
   TypeConverter.cpp
   VectorPattern.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
new file mode 100644
index 000000000000000..487abb435d10ad7
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -0,0 +1,66 @@
+
+//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace llvm;
+
+/// Returns `symbolName`, or `symbolName_<N>` for the first N not yet in use.
+static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
+                                            StringRef symbolName) {
+  int counter = 0; // Not static: keeps names deterministic and race-free.
+  std::string uniqueName = std::string(symbolName);
+  while (moduleOp.lookupSymbol(uniqueName))
+    uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
+  return uniqueName;
+}
+
+void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
+                                    ModuleOp moduleOp, StringRef symbolName,
+                                    StringRef string,
+                                    const LLVMTypeConverter &typeConverter) {
+  auto ip = builder.saveInsertionPoint(); // Restored before the call below.
+  builder.setInsertionPointToStart(moduleOp.getBody());
+  MLIRContext *ctx = builder.getContext();
+
+  // Store `string` + NUL terminator in a private constant i8-array global.
+  SmallVector<uint8_t> elementVals;
+  elementVals.append(string.begin(), string.end());
+  elementVals.push_back(0);
+  auto dataAttrType = RankedTensorType::get(
+      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+  auto dataAttr =
+      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+  auto arrayTy =
+      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+  auto globalOp = builder.create<LLVM::GlobalOp>(
+      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+      ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+
+  // Emit a call to the runtime string printer on the global's first byte.
+  builder.restoreInsertionPoint(ip);
+  auto msgAddr = builder.create<LLVM::AddressOfOp>(
+      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
+  SmallVector<LLVM::GEPArg> indices(1, 0); // One index, 0: array start.
+  Value gep = builder.create<LLVM::GEPOp>(
+      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+      indices);
+  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+      moduleOp, typeConverter.useOpaquePointers());
+  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+                               gep);
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..4af58653c8227ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     }
 
     auto punct = printOp.getPunctuation();
-    if (punct != PrintPunctuation::NoPunctuation) {
+    if (auto stringLiteral = printOp.getStringLiteral()) {
+      LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+                               *stringLiteral, *getTypeConverter());
+    } else if (punct != PrintPunctuation::NoPunctuation) {
       emitCall(rewriter, printOp->getLoc(), [&] {
         switch (punct) {
         case PrintPunctuation::Close:
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9aa4d735681f576..65b3a78e295f0c4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
 
 // -----
 
+// CHECK-LABEL: module {
+// CHECK: llvm.func @puts(!llvm.ptr)
+// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8>
+// CHECK: @vector_print_string
+//       CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
+//       CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+//       CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> ()
+func.func @vector_print_string() {
+  vector.print str "Hello, World!"
+  return
+}
+
+// -----
+
 func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
   return %0 : vector<2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
new file mode 100644
index 000000000000000..c4076e65151ac72
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+   // CHECK: Hello, World!
+   vector.print str "Hello, World!"
+   return
+}

>From 24d9b255eb695f2b4fa35e86a26b32b2882cdb7f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 11 Oct 2023 15:01:35 +0000
Subject: [PATCH 2/3] Fixups

---
 .../mlir/Conversion/LLVMCommon/PrintCallHelper.h |  8 +-------
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td |  2 +-
 .../Conversion/LLVMCommon/PrintCallHelper.cpp    |  4 +---
 mlir/test/Dialect/Vector/invalid.mlir            | 16 ++++++++++++++++
 ...test-hello-world.mlir => test-print-str.mlir} |  4 ++++
 5 files changed, 23 insertions(+), 11 deletions(-)
 rename mlir/test/Integration/Dialect/Vector/CPU/{test-hello-world.mlir => test-print-str.mlir} (71%)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 7e26858589f2756..457cd98ca3dc2c8 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -1,5 +1,4 @@
-
-//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===//
+//===- PrintCallHelper.h - Helper to emit runtime print calls ---*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -15,12 +14,7 @@
 
 namespace mlir {
 
-class Location;
-class ModuleOp;
 class OpBuilder;
-class Operation;
-class Type;
-class ValueRange;
 class LLVMTypeConverter;
 
 namespace LLVM {
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2b60055ca9db94b..f946d124fb2fa5e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2479,7 +2479,7 @@ def Vector_TransposeOp :
 def Vector_PrintOp :
   Vector_Op<"print", [
     PredOpTrait<
-      "`source` or `punctuation` are not set printing strings",
+      "`source` or `punctuation` are not set when printing strings",
       CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
     >,
   ]>,
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 487abb435d10ad7..40b9382452fbb45 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -1,5 +1,4 @@
-
-//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===//
+//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
-#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5967a8d69bbfcc0..1664ddde7e48d76 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1016,6 +1016,22 @@ func.func private @print_needs_vector(%arg0: tensor<8xf32>) {
 
 // -----
 
+func.func @cannot_print_string_with_punctuation_set() {
+  // expected-error at +1 {{`source` or `punctuation` are not set when printing strings}}
+  vector.print str "Whoops!" punctuation <comma>
+  return
+}
+
+// -----
+
+func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
+  // expected-error at +1 {{`source` or `punctuation` are not set when printing strings}}
+  vector.print %vec: vector<[4]xf32> str "Yay!"
+  return
+}
+
+// -----
+
 func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
similarity index 71%
rename from mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
rename to mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
index c4076e65151ac72..4a11987121b3308 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
@@ -3,8 +3,12 @@
 // RUN:   -shared-libs=%mlir_c_runner_utils | \
 // RUN: FileCheck %s
 
+/// This tests that printing (multiple) string literals works.
+
 func.func @entry() {
    // CHECK: Hello, World!
    vector.print str "Hello, World!"
+   // CHECK-NEXT: Bye!
+   vector.print str "Bye!"
    return
 }

>From 148c4f62752923bbf7f64e76fad6090625cfe962 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 11 Oct 2023 15:02:07 +0000
Subject: [PATCH 3/3] [mlir][test] Update tests to use `vector.print str` (NFC)

This cuts down on a fair amount of boilerplate.

Depends on: #68695
---
 ...ide-int-emulation-compare-results-i16.mlir | 14 +--------
 .../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir    | 14 +--------
 .../Dialect/Linalg/CPU/ArmSVE/fill-1d.mlir    | 14 +--------
 .../CPU/ArmSME/load-store-128-bit-tile.mlir   | 23 +++-----------
 .../Vector/CPU/ArmSME/test-load-vertical.mlir | 31 +++----------------
 .../CPU/ArmSME/test-outerproduct-f32.mlir     | 31 +++----------------
 .../CPU/ArmSME/test-outerproduct-f64.mlir     | 27 ++--------------
 .../Vector/CPU/ArmSME/test-transpose.mlir     | 31 +++----------------
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  | 27 ++--------------
 .../Vector/CPU/ArmSME/vector-load-store.mlir  | 30 +++---------------
 10 files changed, 27 insertions(+), 215 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
index 213cd4de1ea9313..15bafeda67403eb 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
@@ -15,18 +15,6 @@
 // Common Utility Functions
 //===----------------------------------------------------------------------===//
 
-llvm.mlir.global internal constant @str_mismatch("Mismatch\0A")
-func.func private @printCString(!llvm.ptr<i8>) -> ()
-// Prints 'Mismatch' to stdout.
-func.func @printMismatch() -> () {
-  %0 = llvm.mlir.addressof @str_mismatch : !llvm.ptr<array<9 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  func.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 // Prints both binary op operands and the first result. If the second result
 // does not match, prints the second result and a 'Mismatch' message.
 func.func @check_results(%lhs : i16, %rhs : i16, %res0 : i16, %res1 : i16) -> () {
@@ -38,7 +26,7 @@ func.func @check_results(%lhs : i16, %rhs : i16, %res0 : i16, %res1 : i16) -> ()
   %mismatch = arith.cmpi ne, %res0, %res1 : i16
   scf.if %mismatch -> () {
     vector.print %res1 : i16
-    func.call @printMismatch() : () -> ()
+    vector.print str "Mismatch"
   }
   return
 }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 08f14dfae3249f2..fc445eed0ab3216 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -13,15 +13,6 @@
 // RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
 // RUN: FileCheck %s
 
-func.func @printTestEnd() {
-  %0 = llvm.mlir.addressof @str_sme_end : !llvm.ptr<array<24 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<24 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 func.func @entry() {
   %c0 = arith.constant 0 : index
   %c4 = arith.constant 4 : index
@@ -104,7 +95,7 @@ func.func @entry() {
   }
 
   // CHECK: SME: END OF TEST OUTPUT
-  func.call @printTestEnd() : () -> ()
+  vector.print str "SME: END OF TEST OUTPUT"
 
   return
 }
@@ -114,6 +105,3 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   transform.structured.vectorize %0 vector_sizes [[4], [4]] : !transform.any_op
 }
-
-llvm.func @printCString(!llvm.ptr<i8>)
-llvm.mlir.global internal constant @str_sme_end("SME: END OF TEST OUTPUT\0A")
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/fill-1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/fill-1d.mlir
index c3f49b2f39cf137..2907f4ef80aeb24 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/fill-1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/fill-1d.mlir
@@ -2,15 +2,6 @@
 // RUN: %mcr_aarch64_cmd -e=entry -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
 // RUN: FileCheck %s
 
-func.func @printTestEnd() {
-  %0 = llvm.mlir.addressof @str_sve_end : !llvm.ptr<array<24 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<24 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 func.func @entry() {
   %c4 = arith.constant 4 : index
   %c0 = arith.constant 0 : index
@@ -41,7 +32,7 @@ func.func @entry() {
   }
 
   // CHECK: SVE: END OF TEST OUTPUT
-  func.call @printTestEnd() : () -> ()
+  vector.print str "SVE: END OF TEST OUTPUT"
 
   return
 }
@@ -51,6 +42,3 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
 }
-
-llvm.func @printCString(!llvm.ptr<i8>)
-llvm.mlir.global internal constant @str_sve_end("SVE: END OF TEST OUTPUT\0A")
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
index de1fff5bea3f8b7..78f1bede5a6a529 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -29,16 +29,6 @@ func.func @print_i8s(%bytes: memref<?xi8>, %len: index) {
   return
 }
 
-llvm.func @printCString(!llvm.ptr<i8>)
-
-func.func @print_str(%str: !llvm.ptr<array<17 x i8>>) {
-  %c0 = llvm.mlir.constant(0 : index) : i64
-  %str_bytes = llvm.getelementptr %str[%c0, %c0]
-    : (!llvm.ptr<array<17 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%str_bytes) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 func.func @vector_copy_i128(%src: memref<?x?xi128>, %dst: memref<?x?xi128>) {
   %c0 = arith.constant 0 : index
   %tile = vector.load %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
@@ -80,13 +70,13 @@ func.func @test_load_store_zaq0() {
 
   // CHECK-LABEL: INITIAL TILE A:
   // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
-  func.call @print_str(%init_a_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print str "INITIAL TILE A:"
   func.call @print_i8s(%tile_a_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
   vector.print punctuation <newline>
 
   // CHECK-LABEL: INITIAL TILE B:
   // CHECK: ( 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 )
-  func.call @print_str(%init_b_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print str "INITIAL TILE B:"
   func.call @print_i8s(%tile_b_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
   vector.print punctuation <newline>
 
@@ -95,19 +85,14 @@ func.func @test_load_store_zaq0() {
 
   // CHECK-LABEL: FINAL TILE A:
   // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
-  func.call @print_str(%final_a_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print str "FINAL TILE A:"
   func.call @print_i8s(%tile_a_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
   vector.print punctuation <newline>
 
   // CHECK-LABEL: FINAL TILE B:
   // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 )
-  func.call @print_str(%final_b_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print str "FINAL TILE B:"
   func.call @print_i8s(%tile_b_bytes, %zaq_size_bytes) : (memref<?xi8>, index) -> ()
 
   return
 }
-
-llvm.mlir.global internal constant @init_tile_a ("INITIAL TILE A:\0A\00")
-llvm.mlir.global internal constant @init_tile_b ("INITIAL TILE B:\0A\00")
-llvm.mlir.global internal constant @final_tile_a("  FINAL TILE A:\0A\00")
-llvm.mlir.global internal constant @final_tile_b("  FINAL TILE B:\0A\00")
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 8c7d8c954d38475..0b9e83b28a767c1 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -11,26 +11,6 @@
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
-llvm.func @printCString(!llvm.ptr<i8>)
-
-func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
-  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
-func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
-  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 func.func @entry() {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -75,12 +55,12 @@ func.func @entry() {
   // CHECK-NEXT: ( 2, 2, 2, 2
   // CHECK-NEXT: ( 3, 3, 3, 3
   // CHECK:      TILE END
-  func.call @printTileBegin() : () -> ()
+  vector.print str "TILE BEGIN"
   scf.for %i = %c0 to %za_s_size step %svl_s {
     %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
     vector.print %tileslice : vector<[4]xi32>
   }
-  func.call @printTileEnd() : () -> ()
+  vector.print str "TILE END"
 
   // 2. VERTICAL LAYOUT
   // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
@@ -92,12 +72,9 @@ func.func @entry() {
   // CHECK-NEXT: ( 0, 1, 2, 3
   // CHECK-NEXT: ( 0, 1, 2, 3
   // CHECK:      TILE END
-  func.call @printTileBegin() : () -> ()
+  vector.print str "TILE BEGIN"
   vector.print %0 : vector<[4]x[4]xi32>
-  func.call @printTileEnd() : () -> ()
+  vector.print str "TILE END"
 
   return
 }
-
-llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
-llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 4265ca0f599281c..38ba489e2fafb2c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -14,26 +14,6 @@
 // REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_4x4xf32
 // RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC
 
-llvm.func @printCString(!llvm.ptr<i8>)
-
-func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
-  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
-func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
-  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 func.func @test_outerproduct_no_accumulator_4x4xf32() {
   %c0 = arith.constant 0 : index
 
@@ -50,9 +30,9 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
   // WITHOUT-ACC-NEXT: ( 0, 2, 4, 6
   // WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
   // WITHOUT-ACC:      TILE END
-  func.call @printTileBegin() : () -> ()
+  vector.print str "TILE BEGIN"
   vector.print %tile : vector<[4]x[4]xf32>
-  func.call @printTileEnd() : () -> ()
+  vector.print str "TILE END"
 
   return
 }
@@ -75,12 +55,9 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
   // WITH-ACC-NEXT: ( 10, 12, 14, 16
   // WITH-ACC-NEXT: ( 10, 13, 16, 19
   // WITH-ACC:      TILE END
-  func.call @printTileBegin() : () -> ()
+  vector.print str "TILE BEGIN"
   vector.print %tile : vector<[4]x[4]xf32>
-  func.call @printTileEnd() : () -> ()
+  vector.print str "TILE END"
 
   return
 }
-
-llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
-llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index cb2c6b98a4eef3a..82f14595a24da2f 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -11,26 +11,6 @@
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
-llvm.func @printCString(!llvm.ptr<i8>)
-
-func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
-  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
-func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
-  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 func.func @test_outerproduct_with_accumulator_2x2xf64() {
   %f1 = arith.constant 1.0 : f64
   %f2 = arith.constant 2.0 : f64
@@ -50,12 +30,9 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
   // CHECK-NEXT: ( 12, 12
   // CHECK-NEXT: ( 12, 12
   // CHECK:      TILE END
-  func.call @printTileBegin() : () -> ()
+  vector.print str "TILE BEGIN"
   vector.print %tile : vector<[2]x[2]xf64>
-  func.call @printTileEnd() : () -> ()
+  vector.print str "TILE END"
 
   return
 }
-
-llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
-llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index 4bb9258098d98fd..65b930115e88895 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -11,26 +11,6 @@
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
-llvm.func @printCString(!llvm.ptr<i8>)
-
-func.func @printTileBegin() attributes { enable_arm_streaming_ignore }  {
-  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
-func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
-  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 func.func @entry() {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -77,9 +57,9 @@ func.func @entry() {
   // CHECK-NEXT: ( 2, 2, 2, 2
   // CHECK-NEXT: ( 3, 3, 3, 3
   // CHECK:      TILE END
-  func.call @printTileBegin() : () -> ()
+  vector.print str "TILE BEGIN"
   vector.print %tile : vector<[4]x[4]xi32>
-  func.call @printTileEnd() : () -> ()
+  vector.print str "TILE END"
 
   // Dump the transposed tile. The smallest SVL is 128-bits so the tile will be
   // at least 4x4xi32.
@@ -90,12 +70,9 @@ func.func @entry() {
   // CHECK-NEXT: ( 0, 1, 2, 3
   // CHECK-NEXT: ( 0, 1, 2, 3
   // CHECK:      TILE END
-  func.call @printTileBegin() : () -> ()
+  vector.print str "TILE BEGIN"
   vector.print %transposed_tile : vector<[4]x[4]xi32>
-  func.call @printTileEnd() : () -> ()
+  vector.print str "TILE END"
 
   return
 }
-
-llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
-llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index fe6ded71c1613fa..92031586b8cfc91 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -11,26 +11,6 @@
 // Integration test demonstrating filling a 32-bit element ZA tile with a
 // non-zero constant via vector to tile (MOVA) ops.
 
-llvm.func @printCString(!llvm.ptr<i8>)
-
-func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
-  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
-func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
-  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 func.func @entry() -> i32 {
   // Fill a tile with '123'. This will get lowered to a 1-d vector splat of
   // '123' and a loop that writes this vector to each tile slice in the ZA
@@ -46,13 +26,10 @@ func.func @entry() -> i32 {
   // CHECK-NEXT: ( 123, 123, 123, 123
   // CHECK-NEXT: ( 123, 123, 123, 123
   // CHECK:      TILE END
-  func.call @printTileBegin() : () -> ()
+  vector.print str "TILE BEGIN"
   vector.print %tile : vector<[4]x[4]xi32>
-  func.call @printTileEnd() : () -> ()
+  vector.print str "TILE END"
 
   %c0_i32 = arith.constant 0 : i32
   return %c0_i32 : i32
 }
-
-llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
-llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index d5c926ebd779f2a..adf1d365cb99823 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -16,8 +16,6 @@
 
 // Integration tests demonstrating load/store to/from SME ZA tile.
 
-llvm.func @printCString(!llvm.ptr<i8>)
-
 // This test verifies a 64-bit element ZA with FP64 data is correctly
 // loaded/stored to/from memory.
 func.func @za0_d_f64() -> i32 {
@@ -160,24 +158,6 @@ func.func @za0_d_f64() -> i32 {
   return %c0_i32 : i32
 }
 
-func.func @printTileBegin() {
-  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
-func.func @printTileEnd() {
-  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
-  %1 = llvm.mlir.constant(0 : index) : i64
-  %2 = llvm.getelementptr %0[%1, %1]
-    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
-  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
-  return
-}
-
 // This test loads two 32-bit element ZA tiles from memory and stores them back
 // to memory in reverse order. This verifies the memref indices for the vector
 // load and store are correctly preserved since the second tile is offset from
@@ -285,7 +265,7 @@ func.func @load_store_two_za_s_tiles() -> i32 {
   // CHECK-NEXT: ( 1, 1, 1, 1
   // CHECK-NEXT: ( 1, 1, 1, 1
   // CHECK:      TILE END
-  func.call @printTileBegin() : () -> ()
+  vector.print str "TILE BEGIN"
   scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
     %av = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
     vector.print %av : vector<[4]xi32>
@@ -293,14 +273,12 @@ func.func @load_store_two_za_s_tiles() -> i32 {
     %tileSizeMinusStep = arith.subi %size_of_tile, %svl_s : index
     %isNextTile = arith.cmpi eq, %i, %tileSizeMinusStep : index
     scf.if %isNextTile {
-      func.call @printTileEnd() : () -> ()
-      func.call @printTileBegin() : () -> ()
+      vector.print str "TILE END"
+      vector.print str "TILE BEGIN"
     }
   }
-  func.call @printTileEnd() : () -> ()
+  vector.print str "TILE END"
 
   return %c0_i32 : i32
 }
 
-llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
-llvm.mlir.global internal constant @str_tile_end("TILE END\0A")



More information about the Mlir-commits mailing list