[Mlir-commits] [mlir] 02834e1 - [mlir][ODS] Get rid of limitations in rewriters generator

Vladislav Vinogradov llvmlistbot at llvm.org
Thu Mar 18 02:23:14 PDT 2021


Author: Vladislav Vinogradov
Date: 2021-03-18T12:21:06+03:00
New Revision: 02834e1bd94602bb3d1c603fd9fb874eb0e75290

URL: https://github.com/llvm/llvm-project/commit/02834e1bd94602bb3d1c603fd9fb874eb0e75290
DIFF: https://github.com/llvm/llvm-project/commit/02834e1bd94602bb3d1c603fd9fb874eb0e75290.diff

LOG: [mlir][ODS] Get rid of limitations in rewriters generator

Remove the eight-argument limit on `NativeCodeCall` in rewriter patterns.

Introduce a separate `FmtStrVecObject` class to handle
formatting with a variable-length `std::string` array.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D97839

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/Format.h
    mlir/lib/TableGen/Format.cpp
    mlir/test/mlir-tblgen/rewriter-indexing.td
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h
index 18a7a6f985b8..441e05c29f26 100644
--- a/mlir/include/mlir/TableGen/Format.h
+++ b/mlir/include/mlir/TableGen/Format.h
@@ -186,6 +186,20 @@ template <typename Tuple> class FmtObject : public FmtObjectBase {
   }
 };
 
+class FmtStrVecObject : public FmtObjectBase { // tgfmt over a std::string array.
+public:
+  using StrFormatAdapter = // Adapter wrapping one std::string.
+      decltype(llvm::detail::build_format_adapter(std::declval<std::string>()));
+
+  FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
+                  ArrayRef<std::string> params); // Copies each string.
+  FmtStrVecObject(FmtStrVecObject const &that) = delete;
+  FmtStrVecObject(FmtStrVecObject &&that); // Rebuilds `adapters`.
+
+private:
+  SmallVector<StrFormatAdapter, 16> parameters; // `adapters` points in here.
+};
+
 /// Formats text by substituting placeholders in format string with replacement
 /// parameters.
 ///
@@ -234,6 +248,11 @@ inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
           llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
 }
 
+inline FmtStrVecObject tgfmt(StringRef fmt, const FmtContext *ctx,
+                             ArrayRef<std::string> params) {
+  return FmtStrVecObject(fmt, ctx, params); // Positional $N reads params[N].
+}
+
 } // end namespace tblgen
 } // end namespace mlir
 

diff  --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp
index 7d17a0aef3f9..10834510b767 100644
--- a/mlir/lib/TableGen/Format.cpp
+++ b/mlir/lib/TableGen/Format.cpp
@@ -173,3 +173,22 @@ void FmtObjectBase::format(raw_ostream &s) const {
     adapters[repl.index]->format(s, /*Options=*/"");
   }
 }
+
+FmtStrVecObject::FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
+                                 ArrayRef<std::string> params)
+    : FmtObjectBase(fmt, ctx, params.size()) {
+  parameters.reserve(params.size());
+  for (std::string p : params) // Copy, then move into the adapter.
+    parameters.push_back(llvm::detail::build_format_adapter(std::move(p)));
+
+  adapters.reserve(parameters.size());
+  for (auto &p : parameters) // One pointer per parameter, in order.
+    adapters.push_back(&p);
+}
+
+FmtStrVecObject::FmtStrVecObject(FmtStrVecObject &&that)
+    : FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
+  adapters.reserve(parameters.size());
+  for (auto &p : parameters) // Point at our own moved-in parameters.
+    adapters.push_back(&p);
+}

diff  --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
index a6b403285765..cbdeff9c743d 100644
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -58,3 +58,30 @@ def test2 : Pat<(COp $attr1, $op1, $attr2, (AOp $op2)),
 def test3 : Pat<(BOp $attr, (AOp:$a $input)),
                 (BOp $attr, (AOp $input), (location $a))>;
 
+def DOp : NS_Op<"d_op", []> {
+  let arguments = (ins
+    AnyInteger:$v1,
+    AnyInteger:$v2,
+    AnyInteger:$v3,
+    AnyInteger:$v4,
+    AnyInteger:$v5,
+    AnyInteger:$v6,
+    AnyInteger:$v7,
+    AnyInteger:$v8,
+    AnyInteger:$v9,
+    AnyInteger:$v10
+  );
+
+  let results = (outs AnyInteger);
+}
+
+def NativeBuilder :
+  NativeCodeCall<[{
+    nativeCall($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9)
+  }]>;
+
+// Check a Pattern passing more than eight DAG arguments to a NativeCodeCall.
+// CHECK: struct test4 : public ::mlir::RewritePattern {
+// CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
+def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
+                (NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 5781870e0df7..7ee05f2114a6 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -251,12 +251,8 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
 
   // TODO(suderman): iterate through arguments, determine their types, output
   // names.
-  SmallVector<std::string, 8> capture(8);
-  if (tree.getNumArgs() > 8) {
-    PrintFatalError(loc,
-                    "unsupported NativeCodeCall matcher argument numbers: " +
-                        Twine(tree.getNumArgs()));
-  }
+  SmallVector<std::string, 8> capture;
+  capture.push_back(opName.str());
 
   raw_indented_ostream::DelimitedScope scope(os);
 
@@ -274,7 +270,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
       }
     }
 
-    capture[i] = std::move(argName);
+    capture.push_back(std::move(argName));
   }
 
   bool hasLocationDirective;
@@ -282,21 +278,20 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
   auto fmt = tree.getNativeCodeTemplate();
-  auto nativeCodeCall = std::string(tgfmt(
-      fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
-      capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
+  auto nativeCodeCall =
+      std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), capture));
 
   os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n";
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     auto name = tree.getArgName(i);
     if (!name.empty() && name != "_") {
-      os << formatv("{0} = {1};\n", name, capture[i]);
+      os << formatv("{0} = {1};\n", name, capture[i + 1]);
     }
   }
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
-    std::string argName = capture[i];
+    std::string argName = capture[i + 1];
 
     // Handle nested DAG construct first
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@@ -915,29 +910,26 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
   LLVM_DEBUG(llvm::dbgs() << '\n');
 
   auto fmt = tree.getNativeCodeTemplate();
-  // TODO: replace formatv arguments with the exact specified args.
-  SmallVector<std::string, 8> attrs(8);
-  if (tree.getNumArgs() > 8) {
-    PrintFatalError(loc,
-                    "unsupported NativeCodeCall replace argument numbers: " +
-                        Twine(tree.getNumArgs()));
-  }
+
+  SmallVector<std::string, 16> attrs;
+
   bool hasLocationDirective;
   std::string locToUse;
   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
     if (tree.isNestedDagArg(i)) {
-      attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
+      attrs.push_back(
+          handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
     } else {
-      attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+      attrs.push_back(
+          handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
     }
     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
                             << " replacement: " << attrs[i] << "\n");
   }
-  return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
-                           attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
-                           attrs[6], attrs[7]));
+
+  return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs));
 }
 
 int PatternEmitter::getNodeValueCount(DagNode node) {


        


More information about the Mlir-commits mailing list