[Mlir-commits] [mlir] 78530ce - Add indented raw_ostream class

Jacques Pienaar llvmlistbot at llvm.org
Sat Oct 3 08:53:57 PDT 2020


Author: Jacques Pienaar
Date: 2020-10-03T08:53:43-07:00
New Revision: 78530ce65375fa02bc96019e5cc9d73db8adaca4

URL: https://github.com/llvm/llvm-project/commit/78530ce65375fa02bc96019e5cc9d73db8adaca4
DIFF: https://github.com/llvm/llvm-project/commit/78530ce65375fa02bc96019e5cc9d73db8adaca4.diff

LOG: Add indented raw_ostream class

This class simplifies keeping track of the indentation while emitting. The current indentation is prefixed to every new line (when not at the start of a line, text is emitted unchanged). Add a simple DelimitedScope helper that makes it easy to have the C++ scope match the emitted scope.

Use this in the op doc generator and the rewrite generator.

Differential Revision: https://reviews.llvm.org/D84107

Added: 
    mlir/include/mlir/Support/IndentedOstream.h
    mlir/lib/Support/IndentedOstream.cpp
    mlir/unittests/Support/CMakeLists.txt
    mlir/unittests/Support/IndentedOstreamTest.cpp

Modified: 
    mlir/lib/Support/CMakeLists.txt
    mlir/tools/mlir-tblgen/CMakeLists.txt
    mlir/tools/mlir-tblgen/OpDocGen.cpp
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Support/IndentedOstream.h b/mlir/include/mlir/Support/IndentedOstream.h
new file mode 100644
index 000000000000..20161c1f3898
--- /dev/null
+++ b/mlir/include/mlir/Support/IndentedOstream.h
@@ -0,0 +1,130 @@
+//===- IndentedOstream.h - raw ostream wrapper to indent --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// raw_ostream subclass that keeps track of indentation for textual output
+// where indentation helps readability.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_INDENTEDOSTREAM_H_
+#define MLIR_SUPPORT_INDENTEDOSTREAM_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <algorithm>
+
+namespace mlir {
+
+/// raw_ostream subclass that simplifies indenting a sequence of code.
+///
+/// Every line written through this stream is prefixed with the current
+/// indentation, while writes that continue a partially emitted line are
+/// forwarded without a prefix. The stream itself is unbuffered.
+class raw_indented_ostream : public raw_ostream {
+public:
+  explicit raw_indented_ostream(llvm::raw_ostream &os) : os(os) {
+    SetUnbuffered();
+  }
+
+  /// RAII helper that emits `open` and increases the indentation when
+  /// constructed, then decreases the indentation and emits `close` when
+  /// destroyed, so that the C++ scope matches the emitted scope.
+  ///
+  /// The scope is movable but not copyable: a moved-from scope is disarmed,
+  /// so the unindent/close pair is emitted exactly once even if the
+  /// compiler does not elide the copy when returning from scope().
+  struct DelimitedScope {
+    explicit DelimitedScope(raw_indented_ostream &os, StringRef open = "",
+                            StringRef close = "")
+        : os(os), open(open), close(close) {
+      os << open;
+      os.indent();
+    }
+    DelimitedScope(DelimitedScope &&other)
+        : os(other.os), open(other.open), close(other.close),
+          active(other.active) {
+      other.active = false;
+    }
+    DelimitedScope(const DelimitedScope &) = delete;
+    DelimitedScope &operator=(const DelimitedScope &) = delete;
+    ~DelimitedScope() {
+      if (!active)
+        return;
+      os.unindent();
+      os << close;
+    }
+
+    /// The stream being indented; public so that a temporary scope can be
+    /// written through, e.g. `os.scope("{\n", "\n}\n").os << ...`.
+    raw_indented_ostream &os;
+
+  private:
+    llvm::StringRef open, close;
+
+    /// False once this scope has been moved from.
+    bool active = true;
+  };
+
+  /// Returns a DelimitedScope: `open` is emitted and the indentation
+  /// increased now; both are undone when the returned scope is destroyed.
+  DelimitedScope scope(StringRef open = "", StringRef close = "") {
+    return DelimitedScope(*this, open, close);
+  }
+
+  /// Re-indents by removing the leading whitespace of the first non-empty
+  /// line from every line of the string, skipping over empty lines at the
+  /// start.
+  raw_indented_ostream &reindent(StringRef str);
+
+  /// Increases the indentation by one level and returns this stream.
+  raw_indented_ostream &indent() {
+    currentIndent += indentSize;
+    return *this;
+  }
+
+  /// Decreases the indentation by one level, never going below zero, and
+  /// returns this stream.
+  raw_indented_ostream &unindent() {
+    currentIndent = std::max(0, currentIndent - indentSize);
+    return *this;
+  }
+
+  /// Emits `with` spaces and sets the indentation of the stream to `with`.
+  raw_indented_ostream &indent(int with) {
+    os.indent(with);
+    atStartOfLine = false;
+    currentIndent = with;
+    return *this;
+  }
+
+private:
+  void write_impl(const char *ptr, size_t size) override;
+
+  /// Return the current position within the stream, not counting the bytes
+  /// currently in the buffer.
+  uint64_t current_pos() const override { return os.tell(); }
+
+  /// Constant indent added/removed.
+  static constexpr int indentSize = 2;
+
+  // Tracker for current indentation.
+  int currentIndent = 0;
+
+  // The leading whitespace of the string being printed, if reindent is used.
+  int leadingWs = 0;
+
+  // Tracks whether at start of line and so indent is required or not.
+  bool atStartOfLine = true;
+
+  // The underlying raw_ostream.
+  raw_ostream &os;
+};
+
+} // namespace mlir
+#endif // MLIR_SUPPORT_INDENTEDOSTREAM_H_

diff  --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt
index bdba99057172..16584e082109 100644
--- a/mlir/lib/Support/CMakeLists.txt
+++ b/mlir/lib/Support/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   FileUtilities.cpp
+  IndentedOstream.cpp
   MlirOptMain.cpp
   StorageUniquer.cpp
   ToolUtilities.cpp
@@ -27,3 +28,10 @@ add_mlir_library(MLIROptLib
   MLIRParser
   MLIRSupport
   )
+
+# This doesn't use add_mlir_library as it is used in mlir-tblgen and else
+# mlir-tblgen ends up depending on mlir-generic-headers.
+add_llvm_library(MLIRSupportIdentedOstream
+  IndentedOstream.cpp
+
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Support)

diff  --git a/mlir/lib/Support/IndentedOstream.cpp b/mlir/lib/Support/IndentedOstream.cpp
new file mode 100644
index 000000000000..bb3feef6c445
--- /dev/null
+++ b/mlir/lib/Support/IndentedOstream.cpp
@@ -0,0 +1,83 @@
+//===- IndentedOstream.cpp - raw ostream wrapper to indent ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// raw_ostream subclass that keeps track of indentation for textual output
+// where indentation helps readability.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/IndentedOstream.h"
+
+#include <algorithm>
+
+using namespace mlir;
+
+/// Prints `str` with the leading whitespace of its first non-blank line
+/// removed from every line. Blank lines before the first non-blank line are
+/// skipped, and every emitted line is prefixed with the current indentation.
+raw_indented_ostream &mlir::raw_indented_ostream::reindent(StringRef str) {
+  StringRef remaining = str;
+  // Find the leading whitespace of the first non-blank line.
+  while (!remaining.empty()) {
+    auto split = remaining.split('\n');
+    size_t indent = split.first.find_first_not_of(" \t");
+    if (indent != StringRef::npos) {
+      leadingWs = static_cast<int>(indent);
+      break;
+    }
+    remaining = split.second;
+  }
+  // Print, skipping the blank lines at the start. write_impl consults
+  // `leadingWs` while `remaining` is written; reset it afterwards.
+  *this << remaining;
+  leadingWs = 0;
+  return *this;
+}
+
+/// Writes `size` bytes at `ptr` to the underlying stream. Text that starts a
+/// new line is prefixed with the current indentation, and lines containing
+/// only whitespace are emitted as a bare newline.
+void mlir::raw_indented_ostream::write_impl(const char *ptr, size_t size) {
+  StringRef str(ptr, size);
+
+  // Removes up to `leadingWs` characters of leading whitespace (set by
+  // reindent). Only whitespace is dropped: a line indented less than the
+  // reference line keeps all of its text rather than losing its first
+  // characters.
+  auto stripLeading = [this](StringRef line) {
+    size_t firstNonWs = line.find_first_not_of(" \t");
+    return line.substr(std::min(static_cast<size_t>(leadingWs), firstNonWs));
+  };
+
+  // Prints one line fragment, indenting it if it begins a new line.
+  auto print = [&](StringRef line) {
+    if (atStartOfLine)
+      os.indent(currentIndent);
+    os << stripLeading(line);
+  };
+
+  while (!str.empty()) {
+    size_t idx = str.find('\n');
+    if (idx == StringRef::npos) {
+      // Trailing fragment without a newline; later writes continue this line.
+      if (!stripLeading(str).empty()) {
+        print(str);
+        atStartOfLine = false;
+      }
+      break;
+    }
+
+    StringRef line = str.take_front(idx);
+    // Print empty new line without spaces if line only has spaces.
+    if (!line.ltrim().empty())
+      print(line);
+    os << '\n';
+    atStartOfLine = true;
+    str = str.drop_front(idx + 1);
+  }
+}

diff  --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 46b9d81115c9..df004adb1bed 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -25,6 +25,7 @@ add_tablegen(mlir-tblgen MLIR
 set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
 target_link_libraries(mlir-tblgen
   PRIVATE
+  MLIRSupportIdentedOstream
   MLIRTableGen)
 
 mlir_check_all_link_libraries(mlir-tblgen)

diff  --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index df78556c1c77..ff6a29039763 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DocGenUtilities.h"
+#include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
@@ -35,39 +36,8 @@ using mlir::tblgen::Operator;
 // in a way the user wanted but has some additional indenting due to being
 // nested in the op definition.
 void mlir::tblgen::emitDescription(StringRef description, raw_ostream &os) {
-  // Determine the minimum number of spaces in a line.
-  size_t min_indent = -1;
-  StringRef remaining = description;
-  while (!remaining.empty()) {
-    auto split = remaining.split('\n');
-    size_t indent = split.first.find_first_not_of(" \t");
-    if (indent != StringRef::npos)
-      min_indent = std::min(indent, min_indent);
-    remaining = split.second;
-  }
-
-  // Print out the description indented.
-  os << "\n";
-  remaining = description;
-  bool printed = false;
-  while (!remaining.empty()) {
-    auto split = remaining.split('\n');
-    if (split.second.empty()) {
-      // Skip last line with just spaces.
-      if (split.first.ltrim().empty())
-        break;
-    }
-    // Print empty new line without spaces if line only has spaces, unless no
-    // text has been emitted before.
-    if (split.first.ltrim().empty()) {
-      if (printed)
-        os << "\n";
-    } else {
-      os << split.first.substr(min_indent) << "\n";
-      printed = true;
-    }
-    remaining = split.second;
-  }
+  raw_indented_ostream ros(os);
+  ros.reindent(description.rtrim(" \t"));
 }
 
 // Emits `str` with trailing newline if not empty.
@@ -116,7 +86,7 @@ static void emitOpDoc(Operator op, raw_ostream &os) {
 
   // Emit the summary, syntax, and description if present.
   if (op.hasSummary())
-    os << "\n" << op.getSummary() << "\n";
+    os << "\n" << op.getSummary() << "\n\n";
   if (op.hasAssemblyFormat())
     emitAssemblyFormat(op.getOperationName(), op.getAssemblyFormat().trim(),
                        os);
@@ -228,7 +198,7 @@ static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
   }
 
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
-  for (auto dialectWithOps : dialectOps)
+  for (const auto &dialectWithOps : dialectOps)
     emitDialectDoc(dialectWithOps.first, dialectWithOps.second,
                    dialectTypes[dialectWithOps.first], os);
 }

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 9b2f35f56624..e16900227759 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
@@ -77,11 +78,11 @@ class PatternEmitter {
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an operand.
-  void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
+  void emitOperandMatch(DagNode tree, int argIndex, int depth);
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
-  void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
+  void emitAttributeMatch(DagNode tree, int argIndex, int depth);
 
   // Emits C++ for checking a match with a corresponding match failure
   // diagnostic.
@@ -184,7 +185,7 @@ class PatternEmitter {
   // The next unused ID for newly created values.
   unsigned nextValueId;
 
-  raw_ostream &os;
+  raw_indented_ostream os;
 
   // Format contexts containing placeholder substitutions.
   FmtContext fmtCtx;
@@ -225,8 +226,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
   // Skip the operand matching at depth 0 as the pattern rewriter already does.
   if (depth != 0) {
     // Skip if there is no defining operation (e.g., arguments to function).
-    os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n",
-                                 depth);
+    os << formatv("if (!castedOp{0})\n  return failure();\n", depth);
   }
   if (tree.getNumArgs() != op.getNumArgs()) {
     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
@@ -238,7 +238,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
   // If the operand's name is set, set to that variable.
   auto name = tree.getSymbol();
   if (!name.empty())
-    os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
+    os << formatv("{0} = castedOp{1};\n", name, depth);
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     auto opArg = op.getArg(i);
@@ -253,24 +253,23 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
           PrintFatalError(loc, error);
         }
       }
-      os.indent(indent) << "{\n";
+      os << "{\n";
 
-      os.indent(indent + 2) << formatv(
+      os.indent() << formatv(
           "auto *op{0} = "
           "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
           depth + 1, depth, i);
       emitOpMatch(argTree, depth + 1);
-      os.indent(indent + 2)
-          << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
-      os.indent(indent) << "}\n";
+      os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
+      os.unindent() << "}\n";
       continue;
     }
 
     // Next handle DAG leaf: operand or attribute
     if (opArg.is<NamedTypeConstraint *>()) {
-      emitOperandMatch(tree, i, depth, indent);
+      emitOperandMatch(tree, i, depth);
     } else if (opArg.is<NamedAttribute *>()) {
-      emitAttributeMatch(tree, i, depth, indent);
+      emitAttributeMatch(tree, i, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
     }
@@ -280,8 +279,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
                           << '\n');
 }
 
-void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
-                                      int indent) {
+void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
   auto matcher = tree.getArgAsLeaf(argIndex);
@@ -328,30 +326,28 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
         op.arg_begin(), op.arg_begin() + argIndex,
         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
 
-    os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
-                                 name, depth, argIndex - numPrevAttrs);
+    os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
+                  argIndex - numPrevAttrs);
   }
 }
 
-void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
-                                        int indent) {
+void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
   const auto &attr = namedAttr->attr;
 
-  os.indent(indent) << "{\n";
-  indent += 2;
-  os.indent(indent) << formatv(
-      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");"
+  os << "{\n";
+  os.indent() << formatv(
+      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
       "(void)tblgen_attr;\n",
       depth, attr.getStorageType(), namedAttr->name);
 
   // TODO: This should use getter method to avoid duplication.
   if (attr.hasDefaultValue()) {
-    os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
-                      << std::string(tgfmt(attr.getConstBuilderTemplate(),
-                                           &fmtCtx, attr.getDefaultValue()))
-                      << ";\n";
+    os << "if (!tblgen_attr) tblgen_attr = "
+       << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
+                            attr.getDefaultValue()))
+       << ";\n";
   } else if (attr.isOptional()) {
     // For a missing attribute that is optional according to definition, we
     // should just capture a mlir::Attribute() to signal the missing state.
@@ -387,27 +383,20 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
   auto name = tree.getArgName(argIndex);
   // `$_` is a special symbol to ignore op argument matching.
   if (!name.empty() && name != "_") {
-    os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
+    os << formatv("{0} = tblgen_attr;\n", name);
   }
 
-  indent -= 2;
-  os.indent(indent) << "}\n";
+  os.unindent() << "}\n";
 }
 
 void PatternEmitter::emitMatchCheck(
     int depth, const FmtObjectBase &matchFmt,
     const llvm::formatv_object_base &failureFmt) {
-  // {0} The match depth (used to get the operation that failed to match).
-  // {1} The format for the match string.
-  // {2} The format for the failure string.
-  const char *matchStr = R"(
-    if (!({1})) {
-      return rewriter.notifyMatchFailure(op{0}, [&](::mlir::Diagnostic &diag) {
-        diag << {2};
-      });
-    })";
-  os << llvm::formatv(matchStr, depth, matchFmt.str(), failureFmt.str())
-     << "\n";
+  os << "if (!(" << matchFmt.str() << "))";
+  os.scope("{\n", "\n}\n").os
+      << "return rewriter.notifyMatchFailure(op" << depth
+      << ", [&](::mlir::Diagnostic &diag) {\n  diag << " << failureFmt.str()
+      << ";\n});";
 }
 
 void PatternEmitter::emitMatchLogic(DagNode tree) {
@@ -491,7 +480,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
 
   // Emit RewritePattern for Pattern.
   auto locs = pattern.getLocation();
-  os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
+  os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
                 make_range(locs.rbegin(), locs.rend()));
   os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
   {0}(::mlir::MLIRContext *context)
@@ -509,44 +498,48 @@ void PatternEmitter::emit(StringRef rewriteName) {
   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
 
   // Emit matchAndRewrite() function.
-  os << R"(
-  ::mlir::LogicalResult
-  matchAndRewrite(::mlir::Operation *op0,
-                  ::mlir::PatternRewriter &rewriter) const override {
-)";
-
-  // Register all symbols bound in the source pattern.
-  pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
-
-  LLVM_DEBUG(
-      llvm::dbgs() << "start creating local variables for capturing matches\n");
-  os.indent(4) << "// Variables for capturing values and attributes used for "
-                  "creating ops\n";
-  // Create local variables for storing the arguments and results bound
-  // to symbols.
-  for (const auto &symbolInfoPair : symbolInfoMap) {
-    StringRef symbol = symbolInfoPair.getKey();
-    auto &info = symbolInfoPair.getValue();
-    os.indent(4) << info.getVarDecl(symbol);
+  {
+    auto classScope = os.scope();
+    os.reindent(R"(
+    ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
+        ::mlir::PatternRewriter &rewriter) const override {)")
+        << '\n';
+    {
+      auto functionScope = os.scope();
+
+      // Register all symbols bound in the source pattern.
+      pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << "start creating local variables for capturing matches\n");
+      os << "// Variables for capturing values and attributes used while "
+            "creating ops\n";
+      // Create local variables for storing the arguments and results bound
+      // to symbols.
+      for (const auto &symbolInfoPair : symbolInfoMap) {
+        StringRef symbol = symbolInfoPair.getKey();
+        auto &info = symbolInfoPair.getValue();
+        os << info.getVarDecl(symbol);
+      }
+      // TODO: capture ops with consistent numbering so that it can be
+      // reused for fused loc.
+      os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
+                    pattern.getSourcePattern().getNumOps());
+      LLVM_DEBUG(llvm::dbgs()
+                 << "done creating local variables for capturing matches\n");
+
+      os << "// Match\n";
+      os << "tblgen_ops[0] = op0;\n";
+      emitMatchLogic(sourceTree);
+
+      os << "\n// Rewrite\n";
+      emitRewriteLogic();
+
+      os << "return success();\n";
+    }
+    os << "};\n";
   }
-  // TODO: capture ops with consistent numbering so that it can be
-  // reused for fused loc.
-  os.indent(4) << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
-                          pattern.getSourcePattern().getNumOps());
-  LLVM_DEBUG(
-      llvm::dbgs() << "done creating local variables for capturing matches\n");
-
-  os.indent(4) << "// Match\n";
-  os.indent(4) << "tblgen_ops[0] = op0;\n";
-  emitMatchLogic(sourceTree);
-  os << "\n";
-
-  os.indent(4) << "// Rewrite\n";
-  emitRewriteLogic();
-
-  os.indent(4) << "return success();\n";
-  os << "  };\n";
-  os << "};\n";
+  os << "};\n\n";
 }
 
 void PatternEmitter::emitRewriteLogic() {
@@ -586,7 +579,7 @@ void PatternEmitter::emitRewriteLogic() {
     PrintFatalError(loc, error);
   }
 
-  os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({";
+  os << "auto odsLoc = rewriter.getFusedLoc({";
   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
   }
@@ -601,22 +594,21 @@ void PatternEmitter::emitRewriteLogic() {
     // we are handling auxiliary patterns so we want the side effect even if
     // NativeCodeCall is not replacing matched root op's results.
     if (resultTree.isNativeCodeCall())
-      os.indent(4) << val << ";\n";
+      os << val << ";\n";
   }
 
   if (numExpectedResults == 0) {
     assert(replStartIndex >= numResultPatterns &&
            "invalid auxiliary vs. replacement pattern division!");
     // No result to replace. Just erase the op.
-    os.indent(4) << "rewriter.eraseOp(op0);\n";
+    os << "rewriter.eraseOp(op0);\n";
   } else {
     // Process replacement result patterns.
-    os.indent(4)
-        << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
+    os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
     for (int i = replStartIndex; i < numResultPatterns; ++i) {
       DagNode resultTree = pattern.getResultPattern(i);
       auto val = handleResultPattern(resultTree, offsets[i], 0);
-      os.indent(4) << "\n";
+      os << "\n";
       // Resolve each symbol for all range use so that we can loop over them.
       // We need an explicit cast to `SmallVector` to capture the cases where
       // `{0}` resolves to an `Operation::result_range` as well as cases that
@@ -625,12 +617,11 @@ void PatternEmitter::emitRewriteLogic() {
       // TODO: Revisit the need for materializing a vector.
       os << symbolInfoMap.getAllRangeUse(
           val,
-          "    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{ "
-          "tblgen_repl_values.push_back(v); }",
+          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
+          "  tblgen_repl_values.push_back(v);\n}\n",
           "\n");
     }
-    os.indent(4) << "\n";
-    os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
+    os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
   }
 
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
@@ -879,9 +870,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   }
 
   // Create the local variable for this op.
-  os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
-                          valuePackName);
-  os.indent(4) << "{\n";
+  os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
+                valuePackName);
 
   // Right now ODS don't have general type inference support. Except a few
   // special cases listed below, DRR needs to supply types for all results
@@ -900,10 +890,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     createAggregateLocalVarsForOpArgs(tree, childNodeNames);
 
     // Then create the op.
-    os.indent(6) << formatv(
-        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n",
+    os.scope("", "\n}\n").os << formatv(
+        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
         valuePackName, resultOp.getQualCppClassName(), locToUse);
-    os.indent(4) << "}\n";
     return resultValue;
   }
 
@@ -920,11 +909,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     // aggregate-parameter builders.
     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
 
-    os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
-                            resultOp.getQualCppClassName(), locToUse);
+    os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
+                             resultOp.getQualCppClassName(), locToUse);
     supplyValuesForOpArgs(tree, childNodeNames);
-    os << "\n      );\n";
-    os.indent(4) << "}\n";
+    os << "\n  );\n}\n";
     return resultValue;
   }
 
@@ -938,20 +926,19 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
 
   // Then prepare the result types. We need to specify the types for all
   // results.
-  os.indent(6) << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
-                          "(void)tblgen_types;\n");
+  os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
+                         "(void)tblgen_types;\n");
   int numResults = resultOp.getNumResults();
   if (numResults != 0) {
     for (int i = 0; i < numResults; ++i)
-      os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{"
-                              "tblgen_types.push_back(v.getType()); }\n",
-                              resultIndex + i);
+      os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
+                    "  tblgen_types.push_back(v.getType());\n}\n",
+                    resultIndex + i);
   }
-  os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
-                          "tblgen_values, tblgen_attrs);\n",
-                          valuePackName, resultOp.getQualCppClassName(),
-                          locToUse);
-  os.indent(4) << "}\n";
+  os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
+                "tblgen_values, tblgen_attrs);\n",
+                valuePackName, resultOp.getQualCppClassName(), locToUse);
+  os.unindent() << "}\n";
   return resultValue;
 }
 
@@ -968,16 +955,15 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
     const auto *operand =
         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
-    if (!operand) {
-      // We do not need special handling for attributes.
+    // We do not need special handling for attributes.
+    if (!operand)
       continue;
-    }
 
+    raw_indented_ostream::DelimitedScope scope(os);
     std::string varName;
     if (operand->isVariadic()) {
       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
-      os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n",
-                              varName);
+      os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
       std::string range;
       if (node.isNestedDagArg(argIndex)) {
         range = childNodeNames[argIndex];
@@ -987,11 +973,11 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
       // Resolve the symbol for all range use so that we have a uniform way of
       // capturing the values.
       range = symbolInfoMap.getValueAndRangeUse(range);
-      os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range,
-                              varName);
+      os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
+                    varName);
     } else {
       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
-      os.indent(6) << formatv("::mlir::Value {0} = ", varName);
+      os << formatv("::mlir::Value {0} = ", varName);
       if (node.isNestedDagArg(argIndex)) {
         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
       } else {
@@ -1019,7 +1005,7 @@ void PatternEmitter::supplyValuesForOpArgs(
   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
        argIndex != numOpArgs; ++argIndex) {
     // Start each argument on its own line.
-    (os << ",\n").indent(8);
+    os << ",\n    ";
 
     Argument opArg = resultOp.getArg(argIndex);
     // Handle the case of operand first.
@@ -1060,14 +1046,16 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
     DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
   Operator &resultOp = node.getDialectOp(opMap);
 
-  os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> "
-                          "tblgen_values; (void)tblgen_values;\n");
-  os.indent(6) << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
-                          "tblgen_attrs; (void)tblgen_attrs;\n");
+  auto scope = os.scope();
+  os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
+                "tblgen_values; (void)tblgen_values;\n");
+  os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
+                "tblgen_attrs; (void)tblgen_attrs;\n");
 
   const char *addAttrCmd =
-      "if (auto tmpAttr = {1}) "
-      "tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), tmpAttr);\n";
+      "if (auto tmpAttr = {1}) {\n"
+      "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
+      "tmpAttr);\n}\n";
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
       // The argument in the op definition.
@@ -1076,14 +1064,14 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
         if (!subTree.isNativeCodeCall())
           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                                "for creating attribute");
-        os.indent(6) << formatv(addAttrCmd, opArgName,
-                                handleReplaceWithNativeCodeCall(subTree));
+        os << formatv(addAttrCmd, opArgName,
+                      handleReplaceWithNativeCodeCall(subTree));
       } else {
         auto leaf = node.getArgAsLeaf(argIndex);
         // The argument in the result DAG pattern.
         auto patArgName = node.getArgName(argIndex);
-        os.indent(6) << formatv(addAttrCmd, opArgName,
-                                handleOpArgument(leaf, patArgName));
+        os << formatv(addAttrCmd, opArgName,
+                      handleOpArgument(leaf, patArgName));
       }
       continue;
     }
@@ -1101,10 +1089,10 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
       // Resolve the symbol for all range use so that we have a uniform way of
       // capturing the values.
       range = symbolInfoMap.getValueAndRangeUse(range);
-      os.indent(6) << formatv(
-          "for (auto v : {0}) tblgen_values.push_back(v);\n", range);
+      os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
+                    range);
     } else {
-      os.indent(6) << formatv("tblgen_values.push_back(", varName);
+      os << formatv("tblgen_values.push_back(", varName);
       if (node.isNestedDagArg(argIndex)) {
         os << symbolInfoMap.getValueAndRangeUse(
             childNodeNames.lookup(argIndex));

diff  --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
new file mode 100644
index 000000000000..42a1c21261c4
--- /dev/null
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_unittest(MLIRSupportTests
+  IndentedOstreamTest.cpp
+)
+
+target_link_libraries(MLIRSupportTests
+  PRIVATE MLIRSupportIdentedOstream MLIRSupport)

diff  --git a/mlir/unittests/Support/IndentedOstreamTest.cpp b/mlir/unittests/Support/IndentedOstreamTest.cpp
new file mode 100644
index 000000000000..0271eb73e889
--- /dev/null
+++ b/mlir/unittests/Support/IndentedOstreamTest.cpp
@@ -0,0 +1,120 @@
+//===- IndentedOstreamTest.cpp - Indented raw ostream Tests ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/IndentedOstream.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+using ::testing::StrEq;
+
+// Text without newlines is forwarded unchanged.
+TEST(IndentedOstreamTest, SingleLine) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+  ros << 10;
+  ros.flush();
+  EXPECT_THAT(os.str(), StrEq("10"));
+}
+
+// With no indentation set, text and newlines pass through as written.
+TEST(IndentedOstreamTest, SimpleMultiLine) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+  ros << "a";
+  ros << "b";
+  ros << "\n";
+  ros << "c";
+  ros << "\n";
+  ros.flush();
+  EXPECT_THAT(os.str(), StrEq("ab\nc\n"));
+}
+
+// indent(n) emits n spaces right away and makes n the indentation used for
+// every following line.
+TEST(IndentedOstreamTest, SimpleMultiLineIndent) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+  ros.indent(2) << "a";
+  ros.indent(4) << "b";
+  ros << "\n";
+  ros << "c";
+  ros << "\n";
+  ros.flush();
+  EXPECT_THAT(os.str(), StrEq("  a    b\n    c\n"));
+}
+
+// DelimitedScope indents the lines written inside it and emits its
+// open/close delimiters; both the named-variable form and the inline
+// scope() form are checked against the same expected output.
+TEST(IndentedOstreamTest, SingleRegion) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+  ros << "before\n";
+  {
+    raw_indented_ostream::DelimitedScope scope(ros);
+    ros << "inside " << 10;
+    ros << "\n   two\n";
+    {
+      raw_indented_ostream::DelimitedScope scope(ros, "{\n", "\n}\n");
+      ros << "inner inner";
+    }
+  }
+  ros << "after";
+  ros.flush();
+  const auto *expected =
+      R"(before
+  inside 10
+     two
+  {
+    inner inner
+  }
+after)";
+  EXPECT_THAT(os.str(), StrEq(expected));
+
+  // Repeat the above with inline form.
+  str.clear();
+  ros << "before\n";
+  ros.scope().os << "inside " << 10 << "\n   two\n";
+  ros.scope().os.scope("{\n", "\n}\n").os << "inner inner";
+  ros << "after";
+  ros.flush();
+  EXPECT_THAT(os.str(), StrEq(expected));
+}
+
+// reindent() strips the leading whitespace of the first non-blank line from
+// every line, drops leading blank lines, and prints whitespace-only lines as
+// empty lines.
+TEST(IndentedOstreamTest, Reindent) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  raw_indented_ostream ros(os);
+
+  // String to print with some additional empty lines at the start and lines
+  // with just spaces.
+  const auto *desc = R"(
+       
+       
+         First line
+                 second line
+                 
+                 
+  )";
+  ros.reindent(desc);
+  ros.flush();
+  const auto *expected =
+      R"(First line
+        second line
+
+
+)";
+  EXPECT_THAT(os.str(), StrEq(expected));
+}


        


More information about the Mlir-commits mailing list