[Mlir-commits] [mlir] [mlir][VectorOps] Support string literals in `vector.print` (PR #68695)

Jon Roelofs llvmlistbot at llvm.org
Mon May 13 09:11:00 PDT 2024


================
@@ -2477,12 +2478,18 @@ def Vector_TransposeOp :
 }
 
 def Vector_PrintOp :
-  Vector_Op<"print", []>,
+  Vector_Op<"print", [
+    PredOpTrait<
+      "`source` or `punctuation` are not set when printing strings",
+      CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
----------------
jroelofs wrote:

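I'll post a complete patch for this shortly, but I was thinking of something along the lines of the following, which folds the requested punctuation into the string literal and prints it with a single call: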
```
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b66b55ae8d57..fd5aec1982eb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1531,8 +1531,28 @@ public:
 
     auto punct = printOp.getPunctuation();
     if (auto stringLiteral = printOp.getStringLiteral()) {
+      std::string str;
+      llvm::raw_string_ostream punctuatedLiteral(str);
+      if (punct == PrintPunctuation::Open)
+        punctuatedLiteral << "( ";
+      punctuatedLiteral << stringLiteral->str();
+      switch (punct) {
+      case PrintPunctuation::Close:
+        punctuatedLiteral << " )";
+        break;
+      case PrintPunctuation::Comma:
+        punctuatedLiteral << ", ";
+        break;
+      case PrintPunctuation::NewLine:
+        punctuatedLiteral << '\n';
+        break;
+      case PrintPunctuation::Open:
+      case PrintPunctuation::NoPunctuation:
+        break;
+      }
       LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
-                               *stringLiteral, *getTypeConverter());
+                               punctuatedLiteral.str(), *getTypeConverter(),
+                               /*addNewLine=*/false);
     } else if (punct != PrintPunctuation::NoPunctuation) {
       emitCall(rewriter, printOp->getLoc(), [&] {
         switch (punct) {
```

so that string literals can be combined with any punctuation, as in this integration test:
```
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
index 78d6609ccaf9..b47c5b38f783 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-str.mlir
@@ -8,6 +8,16 @@
 func.func @entry() {
    // CHECK: Hello, World!
    vector.print str "Hello, World!"
+
+   // CHECK-NEXT: Nice to meet you ( finally, today, and in this place )
+   vector.print str "Nice to meet you " punctuation <no_punctuation>
+   vector.print str "finally" punctuation <open>
+   vector.print punctuation <comma>
+   vector.print str "today" punctuation <comma>
+   vector.print str "and in " punctuation <no_punctuation>
+   vector.print str "this place" punctuation <close>
+   vector.print punctuation <newline>
+
    // CHECK-NEXT: Bye!
    vector.print str "Bye!"
    return
```



https://github.com/llvm/llvm-project/pull/68695


More information about the Mlir-commits mailing list