[Mlir-commits] [mlir] 29429d1 - [drr] Add $_loc special directive for NativeCodeCall
Jacques Pienaar
llvmlistbot at llvm.org
Tue Aug 11 14:06:42 PDT 2020
Author: Jacques Pienaar
Date: 2020-08-11T14:06:17-07:00
New Revision: 29429d1a443a51d0e1ac4ef4033a2bcc95909ba3
URL: https://github.com/llvm/llvm-project/commit/29429d1a443a51d0e1ac4ef4033a2bcc95909ba3
DIFF: https://github.com/llvm/llvm-project/commit/29429d1a443a51d0e1ac4ef4033a2bcc95909ba3.diff
LOG: [drr] Add $_loc special directive for NativeCodeCall
Allows propagating the location to ops created via NativeCodeCall.
Differential Revision: https://reviews.llvm.org/D85704
Added:
Modified:
mlir/docs/DeclarativeRewrites.md
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 4781f9f2648e8..7e4f675166ab7 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -384,10 +384,12 @@ In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N`. The former
is called _special placeholder_, while the latter is called _positional
placeholder_.
-`NativeCodeCall` right now only supports two special placeholders: `$_builder`
-and `$_self`:
+`NativeCodeCall` right now only supports three special placeholders:
+`$_builder`, `$_loc`, and `$_self`:
* `$_builder` will be replaced by the current `mlir::PatternRewriter`.
+* `$_loc` will be replaced by the location of the op being created: the
+ default fused location, or the custom location given by a trailing location directive.
* `$_self` will be replaced with the entity `NativeCodeCall` is attached to.
We have seen how `$_builder` can be used in the above; it allows us to pass a
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c1bc754da804e..9a94311776da2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -724,7 +724,8 @@ def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
// Test that NativeCodeCall is not ignored if it is not used to directly
// replace the matched root op.
def : Pattern<(OpNativeCodeCall3 $input),
- [(NativeCodeCall<"createOpI($_builder, $0)"> $input), (OpK)]>;
+ [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input),
+ (OpK)]>;
// Test AllAttrConstraintsOf.
def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f6607a5f55246..f2a17a9f3f5fa 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -21,8 +21,8 @@ static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
return choice.getValue() ? input1 : input2;
}
-static void createOpI(PatternRewriter &rewriter, Value input) {
- rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
+static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
+ rewriter.create<OpI>(loc, input);
}
static void handleNoResultOp(PatternRewriter &rewriter,
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 5025ee5216dd6..9884d1ccb077d 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -112,6 +112,9 @@ class PatternEmitter {
// Returns the symbol of the old value serving as the replacement.
StringRef handleReplaceWithValue(DagNode tree);
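+ // Returns whether the last argument of `tree` is a location directive, and the location value to use ("odsLoc" if there is no location directive).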
+ std::pair<bool, std::string> getLocation(DagNode tree);
+
// Returns the location value to use.
std::string handleLocationDirective(DagNode tree);
@@ -779,13 +782,18 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
Twine(tree.getNumArgs()));
}
- for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+ bool hasLocationDirective;
+ std::string locToUse;
+ std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+
+ for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
<< " replacement: " << attrs[i] << "\n");
}
- return std::string(tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3],
- attrs[4], attrs[5], attrs[6], attrs[7]));
+ return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
+ attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
+ attrs[6], attrs[7]));
}
int PatternEmitter::getNodeValueCount(DagNode node) {
@@ -804,6 +812,20 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
return 1;
}
+std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
+ auto numPatArgs = tree.getNumArgs();
+
+ if (numPatArgs != 0) {
+ if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
+ if (lastArg.isLocationDirective()) {
+ return std::make_pair(true, handleLocationDirective(lastArg));
+ }
+ }
+
+ // If no explicit location is given, use the default (fully fused) location.
+ return std::make_pair(false, "odsLoc");
+}
+
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
int depth) {
LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
@@ -814,15 +836,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
auto numOpArgs = resultOp.getNumArgs();
auto numPatArgs = tree.getNumArgs();
- // Get the location for this operation if explicitly provided.
+ bool hasLocationDirective;
std::string locToUse;
- if (numPatArgs != 0) {
- if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
- if (lastArg.isLocationDirective())
- locToUse = handleLocationDirective(lastArg);
- }
+ std::tie(hasLocationDirective, locToUse) = getLocation(tree);
- auto inPattern = numPatArgs - !locToUse.empty();
+ auto inPattern = numPatArgs - hasLocationDirective;
if (numOpArgs != inPattern) {
PrintFatalError(loc,
formatv("resultant op '{0}' argument number mismatch: "
@@ -830,10 +848,6 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
resultOp.getOperationName(), inPattern, numOpArgs));
}
- // If no explicit location is given, use the default, all fused, location.
- if (locToUse.empty())
- locToUse = "odsLoc";
-
// A map to collect all nested DAG child nodes' names, with operand index as
// the key. This includes both bound and unbound child nodes.
ChildNodeIndexNameMap childNodeNames;
More information about the Mlir-commits
mailing list