[Mlir-commits] [mlir] 02834e1 - [mlir][ODS] Get rid of limitations in rewriters generator
Vladislav Vinogradov
llvmlistbot at llvm.org
Thu Mar 18 02:23:14 PDT 2021
Author: Vladislav Vinogradov
Date: 2021-03-18T12:21:06+03:00
New Revision: 02834e1bd94602bb3d1c603fd9fb874eb0e75290
URL: https://github.com/llvm/llvm-project/commit/02834e1bd94602bb3d1c603fd9fb874eb0e75290
DIFF: https://github.com/llvm/llvm-project/commit/02834e1bd94602bb3d1c603fd9fb874eb0e75290.diff
LOG: [mlir][ODS] Get rid of limitations in rewriters generator
Do not limit the number of arguments in a rewriter pattern.
Introduce a separate `FmtStrVecObject` class to handle
formatting of a variadic `std::string` array.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D97839
Added:
Modified:
mlir/include/mlir/TableGen/Format.h
mlir/lib/TableGen/Format.cpp
mlir/test/mlir-tblgen/rewriter-indexing.td
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h
index 18a7a6f985b8..441e05c29f26 100644
--- a/mlir/include/mlir/TableGen/Format.h
+++ b/mlir/include/mlir/TableGen/Format.h
@@ -186,6 +186,20 @@ template <typename Tuple> class FmtObject : public FmtObjectBase {
}
};
+/// Formats text like `FmtObject`, but takes its replacement parameters as a
+/// runtime-sized list of strings instead of a compile-time tuple, so the
+/// number of `$N` placeholders a caller can fill is not fixed by a template.
+///
+/// Each string is copied into `parameters`, and `adapters` (inherited from
+/// `FmtObjectBase`) holds pointers into that storage. Those pointers have to
+/// be re-derived whenever the storage moves, which is why the object is
+/// move-constructible only and not assignable.
+class FmtStrVecObject : public FmtObjectBase {
+public:
+  /// Type of the format adapter built around one owned `std::string`.
+  using StrFormatAdapter =
+      decltype(llvm::detail::build_format_adapter(std::declval<std::string>()));
+
+  FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
+                  ArrayRef<std::string> params);
+  FmtStrVecObject(FmtStrVecObject const &that) = delete;
+  FmtStrVecObject(FmtStrVecObject &&that);
+
+  // A member-wise assignment would not re-derive the `adapters` pointers, so
+  // both assignment operators are spelled out as deleted (they were already
+  // implicitly unavailable because of the user-declared constructors above).
+  FmtStrVecObject &operator=(FmtStrVecObject const &that) = delete;
+  FmtStrVecObject &operator=(FmtStrVecObject &&that) = delete;
+
+private:
+  /// Owned copies of the replacement strings, wrapped in format adapters.
+  SmallVector<StrFormatAdapter, 16> parameters;
+};
+
/// Formats text by substituting placeholders in format string with replacement
/// parameters.
///
@@ -234,6 +248,11 @@ inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
}
+/// Overload of `tgfmt` taking a runtime-sized list of string parameters:
+/// `$0`, `$1`, ... in `fmt` are substituted with `params[0]`, `params[1]`, ...
+///
+/// NOTE(review): the variadic `tgfmt` template above is an exact match for a
+/// single container argument (e.g. a `SmallVector<std::string>`), so callers
+/// may need to convert to `ArrayRef<std::string>` explicitly to be sure this
+/// overload is the one selected — confirm on all supported compilers.
+inline FmtStrVecObject tgfmt(StringRef fmt, const FmtContext *ctx,
+                             ArrayRef<std::string> params) {
+  return FmtStrVecObject{fmt, ctx, params};
+}
+
} // end namespace tblgen
} // end namespace mlir
diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp
index 7d17a0aef3f9..10834510b767 100644
--- a/mlir/lib/TableGen/Format.cpp
+++ b/mlir/lib/TableGen/Format.cpp
@@ -173,3 +173,22 @@ void FmtObjectBase::format(raw_ostream &s) const {
adapters[repl.index]->format(s, /*Options=*/"");
}
}
+
+FmtStrVecObject::FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
+                                 ArrayRef<std::string> params)
+    : FmtObjectBase(fmt, ctx, params.size()) {
+  // Own a copy of every replacement string, each wrapped in a format adapter.
+  parameters.reserve(params.size());
+  for (const std::string &param : params) {
+    std::string owned = param;
+    parameters.push_back(llvm::detail::build_format_adapter(std::move(owned)));
+  }
+
+  // Publish the adapters to FmtObjectBase by address. This is done only once
+  // `parameters` is fully populated, since growing it could reallocate and
+  // invalidate previously taken pointers.
+  adapters.reserve(parameters.size());
+  for (StrFormatAdapter &adapter : parameters)
+    adapters.push_back(&adapter);
+}
+
+FmtStrVecObject::FmtStrVecObject(FmtStrVecObject &&that)
+    : FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
+  // Moving `parameters` can relocate the adapters (e.g. out of SmallVector
+  // inline storage), so pointers taken from `that` cannot be reused; rebuild
+  // them against this object's own storage.
+  // NOTE(review): assumes the FmtObjectBase move constructor leaves
+  // `adapters` empty; otherwise entries would be duplicated — confirm.
+  adapters.reserve(parameters.size());
+  for (size_t idx = 0, count = parameters.size(); idx != count; ++idx)
+    adapters.push_back(&parameters[idx]);
+}
diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
index a6b403285765..cbdeff9c743d 100644
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -58,3 +58,30 @@ def test2 : Pat<(COp $attr1, $op1, $attr2, (AOp $op2)),
def test3 : Pat<(BOp $attr, (AOp:$a $input)),
(BOp $attr, (AOp $input), (location $a))>;
+// Op with ten integer operands: more than the eight NativeCodeCall arguments
+// the rewriter generator used to accept.
+def DOp : NS_Op<"d_op", []> {
+  let arguments = (ins AnyInteger:$v1, AnyInteger:$v2, AnyInteger:$v3,
+                       AnyInteger:$v4, AnyInteger:$v5, AnyInteger:$v6,
+                       AnyInteger:$v7, AnyInteger:$v8, AnyInteger:$v9,
+                       AnyInteger:$v10);
+  let results = (outs AnyInteger);
+}
+
+// NativeCodeCall used on the result side of test4; it forwards ten
+// arguments through the `$0`..`$9` placeholders.
+def NativeBuilder : NativeCodeCall<[{
+ nativeCall($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9)
+ }]>;
+
+// Verify that a pattern passing more than eight DAG arguments to a
+// NativeCodeCall is accepted and that every argument is substituted.
+// CHECK: struct test4 : public ::mlir::RewritePattern {
+// CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
+def test4 : Pat<
+  (DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
+  (NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 5781870e0df7..7ee05f2114a6 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -251,12 +251,8 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
// TODO(suderman): iterate through arguments, determine their types, output
// names.
- SmallVector<std::string, 8> capture(8);
- if (tree.getNumArgs() > 8) {
- PrintFatalError(loc,
- "unsupported NativeCodeCall matcher argument numbers: " +
- Twine(tree.getNumArgs()));
- }
+ SmallVector<std::string, 8> capture;
+ capture.push_back(opName.str());
raw_indented_ostream::DelimitedScope scope(os);
@@ -274,7 +270,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
}
}
- capture[i] = std::move(argName);
+ capture.push_back(std::move(argName));
}
bool hasLocationDirective;
@@ -282,21 +278,20 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
auto fmt = tree.getNativeCodeTemplate();
- auto nativeCodeCall = std::string(tgfmt(
- fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
- capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
+ auto nativeCodeCall =
+ std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), capture));
os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n";
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto name = tree.getArgName(i);
if (!name.empty() && name != "_") {
- os << formatv("{0} = {1};\n", name, capture[i]);
+ os << formatv("{0} = {1};\n", name, capture[i + 1]);
}
}
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
- std::string argName = capture[i];
+ std::string argName = capture[i + 1];
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@@ -915,29 +910,26 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
LLVM_DEBUG(llvm::dbgs() << '\n');
auto fmt = tree.getNativeCodeTemplate();
- // TODO: replace formatv arguments with the exact specified args.
- SmallVector<std::string, 8> attrs(8);
- if (tree.getNumArgs() > 8) {
- PrintFatalError(loc,
- "unsupported NativeCodeCall replace argument numbers: " +
- Twine(tree.getNumArgs()));
- }
+
+ SmallVector<std::string, 16> attrs;
+
bool hasLocationDirective;
std::string locToUse;
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
if (tree.isNestedDagArg(i)) {
- attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
+ attrs.push_back(
+ handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
} else {
- attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
+ attrs.push_back(
+ handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
}
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
<< " replacement: " << attrs[i] << "\n");
}
- return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
- attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
- attrs[6], attrs[7]));
+
+ return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs));
}
int PatternEmitter::getNodeValueCount(DagNode node) {
More information about the Mlir-commits
mailing list