[flang-commits] [flang] [mlir][VectorOps] Support string literals in `vector.print` (PR #68695)

Benjamin Maxwell via flang-commits flang-commits at lists.llvm.org
Tue Oct 10 08:29:06 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/68695

>From e521859869a75e9ff7117c2a37b3d1cb55ef7654 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 10 Oct 2023 11:11:33 +0000
Subject: [PATCH 1/2] [mlir][ODS] Omit printing default-valued attributes in
 oilists

This makes default-valued attributes in oilists match the behaviour of
optional attributes (which are omitted when they hold their default value
of `none`). This allows for concise assembly formats without a custom
printer.

An extra print of " " is also removed. This does not change any existing
uses of oilists, but if the parameter before the oilist was optional, it
would previously add an extra space.
---
 flang/test/Lower/OpenMP/FIR/atomic-read.f90   |  2 +-
 flang/test/Lower/OpenMP/FIR/critical.f90      |  2 +-
 flang/test/Lower/OpenMP/critical.f90          |  2 +-
 .../OpenMPToLLVM/convert-to-llvmir.mlir       |  4 +--
 mlir/test/Dialect/OpenMP/ops.mlir             | 10 ++++----
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        | 25 +++++++++++++------
 6 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/flang/test/Lower/OpenMP/FIR/atomic-read.f90 b/flang/test/Lower/OpenMP/FIR/atomic-read.f90
index ff2b651953f2abc..0079b347fac8de6 100644
--- a/flang/test/Lower/OpenMP/FIR/atomic-read.f90
+++ b/flang/test/Lower/OpenMP/FIR/atomic-read.f90
@@ -14,7 +14,7 @@
 !CHECK: %[[VAR_X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
 !CHECK: %[[VAR_Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
 !CHECK: omp.atomic.read %[[VAR_X]] = %[[VAR_Y]] memory_order(acquire)  hint(uncontended) : !fir.ref<i32>, i32
-!CHECK: omp.atomic.read %[[VAR_A]] = %[[VAR_B]] memory_order(relaxed) hint(none)  : !fir.ref<!fir.char<1>>, !fir.char<1>
+!CHECK: omp.atomic.read %[[VAR_A]] = %[[VAR_B]] memory_order(relaxed) : !fir.ref<!fir.char<1>>, !fir.char<1>
 !CHECK: omp.atomic.read %[[VAR_C]] = %[[VAR_D]] memory_order(seq_cst)  hint(contended) : !fir.ref<!fir.logical<4>>, !fir.logical<4>
 !CHECK: omp.atomic.read %[[VAR_E]] = %[[VAR_F]] hint(speculative) : !fir.ref<!fir.char<1,8>>, !fir.char<1,8>
 !CHECK: omp.atomic.read %[[VAR_G]] = %[[VAR_H]] hint(nonspeculative) : !fir.ref<f32>, f32
diff --git a/flang/test/Lower/OpenMP/FIR/critical.f90 b/flang/test/Lower/OpenMP/FIR/critical.f90
index c6ac818fe21aa6e..b86729f8a98e370 100644
--- a/flang/test/Lower/OpenMP/FIR/critical.f90
+++ b/flang/test/Lower/OpenMP/FIR/critical.f90
@@ -2,7 +2,7 @@
 !RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefix="OMPDialect"
 !RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | tco | FileCheck %s --check-prefix="LLVMIR"
 
-!OMPDialect: omp.critical.declare @help2 hint(none)
+!OMPDialect: omp.critical.declare @help2
 !OMPDialect: omp.critical.declare @help1 hint(contended)
 
 subroutine omp_critical()
diff --git a/flang/test/Lower/OpenMP/critical.f90 b/flang/test/Lower/OpenMP/critical.f90
index 9fbd172df96421c..5a4d2e4815df49e 100644
--- a/flang/test/Lower/OpenMP/critical.f90
+++ b/flang/test/Lower/OpenMP/critical.f90
@@ -1,6 +1,6 @@
 !RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
 
-!CHECK: omp.critical.declare @help2 hint(none)
+!CHECK: omp.critical.declare @help2
 !CHECK: omp.critical.declare @help1 hint(contended)
 
 subroutine omp_critical()
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 1df27dd9957e594..881d738b413ef15 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -90,7 +90,7 @@ func.func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4:
 // CHECK-LABEL: @atomic_write
 // CHECK: (%[[ARG0:.*]]: !llvm.ptr<i32>)
 // CHECK: %[[VAL0:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: omp.atomic.write %[[ARG0]] = %[[VAL0]] hint(none) memory_order(relaxed) : !llvm.ptr<i32>, i32
+// CHECK: omp.atomic.write %[[ARG0]] = %[[VAL0]] memory_order(relaxed) : !llvm.ptr<i32>, i32
 func.func @atomic_write(%a: !llvm.ptr<i32>) -> () {
   %1 = arith.constant 1 : i32
   omp.atomic.write %a = %1 hint(none) memory_order(relaxed) : !llvm.ptr<i32>, i32
@@ -474,4 +474,4 @@ llvm.func @_QPtarget_map_with_bounds(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<arr
     omp.terminator
   }
   llvm.return
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 13cbea6c9923c22..27c31824c0506df 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -831,7 +831,7 @@ omp.critical.declare @mutex6 hint(contended, nonspeculative)
 omp.critical.declare @mutex7 hint(uncontended, speculative)
 // CHECK: omp.critical.declare @mutex8 hint(contended, speculative)
 omp.critical.declare @mutex8 hint(contended, speculative)
-// CHECK: omp.critical.declare @mutex9 hint(none)
+// CHECK: omp.critical.declare @mutex9
 omp.critical.declare @mutex9 hint(none)
 // CHECK: omp.critical.declare @mutex10
 omp.critical.declare @mutex10
@@ -909,7 +909,7 @@ func.func @omp_atomic_read(%v: memref<i32>, %x: memref<i32>) {
   omp.atomic.read %v = %x hint(nonspeculative, contended) : memref<i32>, i32
   // CHECK: omp.atomic.read %[[v]] = %[[x]] memory_order(seq_cst) hint(contended, speculative) : memref<i32>, i32
   omp.atomic.read %v = %x hint(speculative, contended) memory_order(seq_cst) : memref<i32>, i32
-  // CHECK: omp.atomic.read %[[v]] = %[[x]] memory_order(seq_cst) hint(none) : memref<i32>, i32
+  // CHECK: omp.atomic.read %[[v]] = %[[x]] memory_order(seq_cst) : memref<i32>, i32
   omp.atomic.read %v = %x hint(none) memory_order(seq_cst) : memref<i32>, i32
   return
 }
@@ -927,7 +927,7 @@ func.func @omp_atomic_write(%addr : memref<i32>, %val : i32) {
   omp.atomic.write %addr = %val memory_order(relaxed) : memref<i32>, i32
   // CHECK: omp.atomic.write %[[ADDR]] = %[[VAL]] hint(uncontended, speculative) : memref<i32>, i32
   omp.atomic.write %addr = %val hint(speculative, uncontended) : memref<i32>, i32
-  // CHECK: omp.atomic.write %[[ADDR]] = %[[VAL]] hint(none) : memref<i32>, i32
+  // CHECK: omp.atomic.write %[[ADDR]] = %[[VAL]] : memref<i32>, i32
   omp.atomic.write %addr = %val hint(none) : memref<i32>, i32
   return
 }
@@ -1004,7 +1004,7 @@ func.func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>,
     omp.yield(%const:i32)
   }
 
-  // CHECK: omp.atomic.update hint(none) %[[X]] : memref<i32>
+  // CHECK: omp.atomic.update %[[X]] : memref<i32>
   // CHECK-NEXT: (%[[XVAL:.*]]: i32):
   // CHECK-NEXT:   %[[NEWVAL:.*]] = llvm.add %[[XVAL]], %[[EXPR]] : i32
   // CHECK-NEXT:   omp.yield(%[[NEWVAL]] : i32)
@@ -1181,7 +1181,7 @@ func.func @omp_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
     omp.atomic.write %x = %expr : memref<i32>, i32
   }
 
-  // CHECK: omp.atomic.capture hint(none) {
+  // CHECK: omp.atomic.capture {
   // CHECK-NEXT: omp.atomic.update %[[x]] : memref<i32>
   // CHECK-NEXT: (%[[xval:.*]]: i32):
   // CHECK-NEXT:   %[[newval:.*]] = llvm.add %[[xval]], %[[expr]] : i32
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index bdb97866a47fc9d..240780b0abf1156 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2009,6 +2009,15 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
           "  }\n";
 }
 
+static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
+                                    AttributeVariable &attrElement) {
+  FmtContext fctx;
+  Attribute attr = attrElement.getVar()->attr;
+  fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
+  body << " && " << op.getGetterName(attrElement.getVar()->name) << "Attr() != "
+       << tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue());
+}
+
 /// Generate the check for the anchor of an optional group.
 static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
                                           const Operator &op,
@@ -2042,12 +2051,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
         if (attr.hasDefaultValue()) {
           // Consider a default-valued attribute as present if it's not the
           // default value.
-          FmtContext fctx;
-          fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
-          body << " && " << op.getGetterName(element->getVar()->name)
-               << "Attr() != "
-               << tgfmt(attr.getConstBuilderTemplate(), &fctx,
-                        attr.getDefaultValue());
+          genNonDefaultValueCheck(body, op, *element);
           return;
         }
         llvm_unreachable("attribute must be optional or default-valued");
@@ -2158,7 +2162,6 @@ void OperationFormat::genElementPrinter(FormatElement *element,
 
   // Emit the OIList
   if (auto *oilist = dyn_cast<OIListElement>(element)) {
-    genLiteralPrinter(" ", body, shouldEmitSpace, lastWasPunctuation);
     for (auto clause : oilist->getClauses()) {
       LiteralElement *lelement = std::get<0>(clause);
       ArrayRef<FormatElement *> pelement = std::get<1>(clause);
@@ -2170,8 +2173,14 @@ void OperationFormat::genElementPrinter(FormatElement *element,
       for (VariableElement *var : vars) {
         TypeSwitch<FormatElement *>(var)
             .Case([&](AttributeVariable *attrEle) {
-              body << " || " << op.getGetterName(attrEle->getVar()->name)
+              body << " || (" << op.getGetterName(attrEle->getVar()->name)
                    << "Attr()";
+              Attribute attr = attrEle->getVar()->attr;
+              if (attr.hasDefaultValue()) {
+                // Don't print default-valued attributes.
+                genNonDefaultValueCheck(body, op, *attrEle);
+              }
+              body << ")";
             })
             .Case([&](OperandVariable *ele) {
               if (ele->getVar()->isVariadic()) {

>From 6d96a5527d59c9262701ea98bee223488a84bb37 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 10 Oct 2023 11:38:12 +0000
Subject: [PATCH 2/2] [mlir][VectorOps] Support string literals in
 `vector.print`

Printing strings within integration tests is currently quite annoyingly
verbose, and can't be tucked into shared helpers as the types depend on
the length of the string:

```
llvm.mlir.global internal constant @hello_world("Hello, World!\0")

func.func @entry() {
  %0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>>
  %1 = llvm.mlir.constant(0 : index) : i64
  %2 = llvm.getelementptr %0[%1, %1]
    : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
  return
}
```

So this patch adds a simple extension to `vector.print` to simplify
this:
```
func.func @entry() {
   // Print a vector of characters ;)
   vector.print str "Hello, World!"
   return
}
```

Most of the logic for this is now shared with `cf.assert`, which already
does something similar.
---
 .../Conversion/LLVMCommon/PrintCallHelper.h   | 36 ++++++++++
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 37 +++++++++--
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   | 49 +-------------
 mlir/lib/Conversion/LLVMCommon/CMakeLists.txt |  1 +
 .../Conversion/LLVMCommon/PrintCallHelper.cpp | 66 +++++++++++++++++++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  6 +-
 .../VectorToLLVM/vector-to-llvm.mlir          | 14 ++++
 .../Dialect/Vector/CPU/test-hello-world.mlir  | 10 +++
 8 files changed, 168 insertions(+), 51 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
 create mode 100644 mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
new file mode 100644
index 000000000000000..7e26858589f2756
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -0,0 +1,36 @@
+//===- PrintCallHelper.h - LLVM print call helpers --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_LLVMCOMMON_PRINTCALLHELPER_H_
+#define MLIR_CONVERSION_LLVMCOMMON_PRINTCALLHELPER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+class Location;
+class ModuleOp;
+class OpBuilder;
+class LLVMTypeConverter;
+
+namespace LLVM {
+
+/// Generate IR that prints `string` by calling the function returned by
+/// `lookupOrCreatePrintStrFn`. The string is stored, zero-terminated, in a
+/// private constant global named `symbolName` (with a `_N` suffix appended
+/// if that name is already taken in `moduleOp`). The call is inserted at
+/// `builder`'s current insertion point; the global at the start of `moduleOp`.
+void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+                        StringRef symbolName, StringRef string,
+                        const LLVMTypeConverter &typeConverter);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_LLVMCOMMON_PRINTCALLHELPER_H_
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..2b60055ca9db94b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/BuiltinAttributes.td"
 
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
@@ -2476,12 +2477,18 @@ def Vector_TransposeOp :
 }
 
 def Vector_PrintOp :
-  Vector_Op<"print", []>,
+  Vector_Op<"print", [
+    PredOpTrait<
+      "`source` and `punctuation` must not be set when printing strings",
+      CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
+    >,
+  ]>,
   Arguments<(ins Optional<Type<Or<[
     AnyVectorOfAnyRank.predicate,
     AnyInteger.predicate, Index.predicate, AnyFloat.predicate
   ]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
-                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
+                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
+                OptionalAttr<Builtin_StringAttr>:$stringLiteral)
   > {
   let summary = "print operation (for testing and debugging)";
   let description = [{
@@ -2520,6 +2527,13 @@ def Vector_PrintOp :
     ```mlir
     vector.print punctuation <newline>
     ```
+
+    Additionally, to aid with debugging and testing, `vector.print` can also
+    print constant strings:
+
+    ```mlir
+    vector.print str "Hello, World!"
+    ```
   }];
   let extraClassDeclaration = [{
     Type getPrintType() {
@@ -2528,11 +2542,26 @@ def Vector_PrintOp :
   }];
   let builders = [
     OpBuilder<(ins "PrintPunctuation":$punctuation), [{
-      build($_builder, $_state, {}, punctuation);
+      build($_builder, $_state, {}, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source), [{
+      build($_builder, $_state, source, PrintPunctuation::NewLine);
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
+      build($_builder, $_state, source, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::llvm::StringRef":$string), [{
+      build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
     }]>,
   ];
 
-  let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
+  let assemblyFormat = [{
+      ($source^ `:` type($source))?
+        oilist(
+            `str` $stringLiteral
+          | `punctuation` $punctuation)
+        attr-dict
+    }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index a4f146bbe475cc6..6b7647b038f1d94 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
 
 #define PASS_NAME "convert-cf-to-llvm"
 
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
-  std::string prefix = "assert_msg_";
-  int counter = 0;
-  while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
-    ++counter;
-  return prefix + std::to_string(counter);
-}
-
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
-                           StringRef msg,
-                           const LLVMTypeConverter &typeConverter) {
-  auto ip = builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(moduleOp.getBody());
-  MLIRContext *ctx = builder.getContext();
-
-  // Create a zero-terminated byte representation and allocate global symbol.
-  SmallVector<uint8_t> elementVals;
-  elementVals.append(msg.begin(), msg.end());
-  elementVals.push_back(0);
-  auto dataAttrType = RankedTensorType::get(
-      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
-  auto dataAttr =
-      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
-  auto arrayTy =
-      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
-  std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
-  auto globalOp = builder.create<LLVM::GlobalOp>(
-      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
-      dataAttr);
-
-  // Emit call to `printStr` in runtime library.
-  builder.restoreInsertionPoint(ip);
-  auto msgAddr = builder.create<LLVM::AddressOfOp>(
-      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
-  SmallVector<LLVM::GEPArg> indices(1, 0);
-  Value gep = builder.create<LLVM::GEPOp>(
-      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
-      indices);
-  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
-      moduleOp, typeConverter.useOpaquePointers());
-  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
-                               gep);
-}
-
 namespace {
 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
 /// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
+    LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+                             *getTypeConverter());
     if (abortOnFailedAssert) {
       // Insert the `abort` declaration if necessary.
       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 091cd539f0ae014..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   LoweringOptions.cpp
   MemRefBuilder.cpp
   Pattern.cpp
+  PrintCallHelper.cpp
   StructBuilder.cpp
   TypeConverter.cpp
   VectorPattern.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
new file mode 100644
index 000000000000000..487abb435d10ad7
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -0,0 +1,66 @@
+//===- PrintCallHelper.cpp - LLVM print call helpers ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+
+/// Returns `symbolName` if it is free in `moduleOp`, otherwise
+/// `<symbolName>_<N>` for the smallest free N >= 0. The counter is local on
+/// purpose: a `static` one would be unsynchronized state shared by all calls.
+static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
+                                            StringRef symbolName) {
+  std::string uniqueName = symbolName.str();
+  for (int counter = 0; moduleOp.lookupSymbol(uniqueName); ++counter)
+    uniqueName = symbolName.str() + "_" + std::to_string(counter);
+  return uniqueName;
+}
+
+void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
+                                    ModuleOp moduleOp, StringRef symbolName,
+                                    StringRef string,
+                                    const LLVMTypeConverter &typeConverter) {
+  // The global goes at the start of the module; the call is emitted at the
+  // caller's insertion point, which is restored below.
+  auto ip = builder.saveInsertionPoint();
+  builder.setInsertionPointToStart(moduleOp.getBody());
+
+  // Create a zero-terminated byte representation and allocate global symbol.
+  SmallVector<uint8_t> elementVals;
+  elementVals.append(string.begin(), string.end());
+  elementVals.push_back(0);
+  auto dataAttrType = RankedTensorType::get(
+      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+  auto dataAttr =
+      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+  auto arrayTy =
+      LLVM::LLVMArrayType::get(builder.getI8Type(), elementVals.size());
+  auto globalOp = builder.create<LLVM::GlobalOp>(
+      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+      ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+
+  // Emit a call to the function returned by `lookupOrCreatePrintStrFn`.
+  builder.restoreInsertionPoint(ip);
+  auto msgAddr = builder.create<LLVM::AddressOfOp>(
+      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
+  SmallVector<LLVM::GEPArg> indices(1, 0);
+  Value gep = builder.create<LLVM::GEPOp>(
+      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+      indices);
+  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+      moduleOp, typeConverter.useOpaquePointers());
+  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+                               gep);
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..4af58653c8227ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     }
 
     auto punct = printOp.getPunctuation();
-    if (punct != PrintPunctuation::NoPunctuation) {
+    if (auto stringLiteral = printOp.getStringLiteral()) {
+      LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+                               *stringLiteral, *getTypeConverter());
+    } else if (punct != PrintPunctuation::NoPunctuation) {
       emitCall(rewriter, printOp->getLoc(), [&] {
         switch (punct) {
         case PrintPunctuation::Close:
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9aa4d735681f576..65b3a78e295f0c4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {
 
 // -----
 
+// CHECK-LABEL: module {
+// CHECK: llvm.func @puts(!llvm.ptr)
+// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]](dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.array<14 x i8>
+// CHECK: @vector_print_string
+//       CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
+//       CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+//       CHECK-NEXT: llvm.call @puts(%[[STR_PTR]]) : (!llvm.ptr) -> ()
+func.func @vector_print_string() {
+  vector.print str "Hello, World!"
+  return
+}
+
+// -----
+
 func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
   %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
   return %0 : vector<2xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
new file mode 100644
index 000000000000000..c4076e65151ac72
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+   // CHECK: Hello, World!
+   vector.print str "Hello, World!"
+   return
+}



More information about the flang-commits mailing list