[Mlir-commits] [mlir] d6b32e3 - [mlir][drr] Allow specifying string in location
Jacques Pienaar
llvmlistbot at llvm.org
Fri Apr 10 12:44:02 PDT 2020
Author: Jacques Pienaar
Date: 2020-04-10T12:43:22-07:00
New Revision: d6b32e39ae27e029f9e73a4d8b61755e369e428e
URL: https://github.com/llvm/llvm-project/commit/d6b32e39ae27e029f9e73a4d8b61755e369e428e
DIFF: https://github.com/llvm/llvm-project/commit/d6b32e39ae27e029f9e73a4d8b61755e369e428e.diff
LOG: [mlir][drr] Allow specifying string in location
Summary:
The string in the location is used to provide metadata for the fused location
or create a NamedLoc. This allows tagging individual locations to convey
additional rewrite information.
Differential Revision: https://reviews.llvm.org/D77840
Added:
Modified:
mlir/docs/DeclarativeRewrites.md
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index ca759f44d42c..2f2299f13c65 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -672,18 +672,24 @@ directive to provide finer control.
(location $symbol0, $symbol1, ...)
```
-where all `$symbol` should be bound previously in the pattern.
+where all `$symbol` should be bound previously in the pattern and one optional
+string may be specified as an attribute. The following locations are created:
+
+* If only 1 symbol is specified then that symbol's location is used;
+* If multiple are specified then a fused location is created;
+* If no symbol is specified then a string must be specified and a NameLoc is
+ created instead.
`location` must be used as the last argument to an op creation. For example,
```tablegen
def : Pat<(LocSrc1Op:$src1 (LocSrc2Op:$src2 ...),
- (LocDst1Op (LocDst2Op ..., (location $src2)))>;
+ (LocDst1Op (LocDst2Op ..., (location $src2)), (location "outer"))>;
```
In the above pattern, the generated `LocDst2Op` will use the matched location
-of `LocSrc2Op` while the root `LocDst1Op` node will still se the fused location
-of all source Ops.
+of `LocSrc2Op` while the root `LocDst1Op` node will use the named location
+`outer`.
### `replaceWithValue`
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index e7fa48dc8829..94b9cde9332a 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -77,6 +77,9 @@ class DagLeaf {
// Returns true if this DAG leaf is specifying an enum attribute case.
bool isEnumAttrCase() const;
+ // Returns true if this DAG leaf is specifying a string attribute.
+ bool isStringAttr() const;
+
// Returns this DAG leaf as a constraint. Asserts if fails.
Constraint getAsConstraint() const;
@@ -95,6 +98,10 @@ class DagLeaf {
// Precondition: isNativeCodeCall()
StringRef getNativeCodeTemplate() const;
+ // Returns the string associated with the leaf.
+ // Precondition: isStringAttr()
+ std::string getStringAttr() const;
+
void print(raw_ostream &os) const;
private:
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 5b547089efca..d832ea809247 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -56,6 +56,10 @@ bool tblgen::DagLeaf::isEnumAttrCase() const {
return isSubClassOf("EnumAttrCaseInfo");
}
+bool tblgen::DagLeaf::isStringAttr() const {
+ return isa<llvm::StringInit>(def) || isa<llvm::CodeInit>(def);
+}
+
tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
assert((isOperandMatcher() || isAttrMatcher()) &&
"the DAG leaf must be operand or attribute");
@@ -81,6 +85,10 @@ llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const {
return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
}
+std::string tblgen::DagLeaf::getStringAttr() const {
+ assert(isStringAttr() && "the DAG leaf must be string attribute");
+ return def->getAsUnquotedString();
+}
bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
return defInit->getDef()->isSubClassOf(superclass);
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 8eedd1ff6bb8..86ebfe4108a2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1018,8 +1018,9 @@ def : Pat<(TestLocationSrcOp:$res1
(TestLocationSrcOp:$res3 $input))),
(TestLocationDstOp
(TestLocationDstOp
- (TestLocationDstOp $input, (location $res1))),
- (location $res2, $res3))>;
+ (TestLocationDstOp $input, (location $res1)),
+ (location "named")),
+ (location "fused", $res2, $res3))>;
//===----------------------------------------------------------------------===//
// Test Legalization
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index d06cdba3ae42..50ec1688ddcc 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -17,11 +17,8 @@ func @verifyDesignatedLoc(%arg0 : i32) -> i32 {
%2 = "test.loc_src"(%1) : (i32) -> i32 loc("loc1")
// CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("loc1")
- // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused[
- // CHECK-SAME: "loc1"
- // CHECK-SAME: "loc3"
- // CHECK-SAME: "loc2"
- // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused["loc2", "loc3"])
+ // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("named")
+ // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused<"fused">["loc2", "loc3"])
return %1 : i32
}
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 73a85259e9ed..c61a99e53b9d 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -112,8 +112,8 @@ class PatternEmitter {
// Returns the symbol of the old value serving as the replacement.
StringRef handleReplaceWithValue(DagNode tree);
- // Returns the symbol of the value whose location to use.
- std::string handleUseLocationOf(DagNode tree);
+ // Returns the location value to use.
+ std::string handleLocationDirective(DagNode tree);
// Emits the C++ statement to build a new op out of the given DAG `tree` and
// returns the variable name that this op is assigned to. If the root op in
@@ -681,27 +681,53 @@ StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
return tree.getArgName(0);
}
-std::string PatternEmitter::handleUseLocationOf(DagNode tree) {
+std::string PatternEmitter::handleLocationDirective(DagNode tree) {
assert(tree.isLocationDirective());
auto lookUpArgLoc = [this, &tree](int idx) {
const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
};
- if (tree.getNumArgs() != 1) {
- std::string ret;
- llvm::raw_string_ostream os(ret);
- os << "rewriter.getFusedLoc({";
- for (int i = 0, e = tree.getNumArgs(); i != e; ++i)
- os << (i ? ", " : "") << lookUpArgLoc(i);
- os << "})";
- return os.str();
- }
+ if (tree.getNumArgs() == 0)
+ llvm::PrintFatalError(
+ "At least one argument to location directive required");
if (!tree.getSymbol().empty())
PrintFatalError(loc, "cannot bind symbol to location");
- return lookUpArgLoc(0);
+ if (tree.getNumArgs() == 1) {
+ DagLeaf leaf = tree.getArgAsLeaf(0);
+ if (leaf.isStringAttr())
+ return formatv("mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
+ "rewriter.getContext())",
+ leaf.getStringAttr())
+ .str();
+ return lookUpArgLoc(0);
+ }
+
+ std::string ret;
+ llvm::raw_string_ostream os(ret);
+ std::string strAttr;
+ os << "rewriter.getFusedLoc({";
+ bool first = true;
+ for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+ DagLeaf leaf = tree.getArgAsLeaf(i);
+ // Handle the optional string value.
+ if (leaf.isStringAttr()) {
+ if (!strAttr.empty())
+ llvm::PrintFatalError("Only one string attribute may be specified");
+ strAttr = leaf.getStringAttr();
+ continue;
+ }
+ os << (first ? "" : ", ") << lookUpArgLoc(i);
+ first = false;
+ }
+ os << "}";
+ if (!strAttr.empty()) {
+ os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
+ }
+ os << ")";
+ return os.str();
}
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
@@ -789,7 +815,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
if (numPatArgs != 0) {
if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
if (lastArg.isLocationDirective())
- locToUse = handleUseLocationOf(lastArg);
+ locToUse = handleLocationDirective(lastArg);
}
auto inPattern = numPatArgs - !locToUse.empty();
More information about the Mlir-commits
mailing list