[Mlir-commits] [mlir] 29429d1 - [drr] Add $_loc special directive for NativeCodeCall

Jacques Pienaar llvmlistbot at llvm.org
Tue Aug 11 14:06:42 PDT 2020


Author: Jacques Pienaar
Date: 2020-08-11T14:06:17-07:00
New Revision: 29429d1a443a51d0e1ac4ef4033a2bcc95909ba3

URL: https://github.com/llvm/llvm-project/commit/29429d1a443a51d0e1ac4ef4033a2bcc95909ba3
DIFF: https://github.com/llvm/llvm-project/commit/29429d1a443a51d0e1ac4ef4033a2bcc95909ba3.diff

LOG: [drr] Add $_loc special directive for NativeCodeCall

Allows propagating the location to ops created via NativeCodeCall.

Differential Revision: https://reviews.llvm.org/D85704

Added: 
    

Modified: 
    mlir/docs/DeclarativeRewrites.md
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 4781f9f2648e8..7e4f675166ab7 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -384,10 +384,12 @@ In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N`. The former
 is called _special placeholder_, while the latter is called _positional
 placeholder_.
 
-`NativeCodeCall` right now only supports two special placeholders: `$_builder`
-and `$_self`:
+`NativeCodeCall` right now only supports three special placeholders:
+`$_builder`, `$_loc`, and `$_self`:
 
 *   `$_builder` will be replaced by the current `mlir::PatternRewriter`.
+*   `$_loc` will be replaced by the fused location or, if a `location`
+    directive is given, by the custom location that directive specifies.
 *   `$_self` will be replaced with the entity `NativeCodeCall` is attached to.
 
 We have seen how `$_builder` can be used in the above; it allows us to pass a

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c1bc754da804e..9a94311776da2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -724,7 +724,8 @@ def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
 // Test that NativeCodeCall is not ignored if it is not used to directly
 // replace the matched root op.
 def : Pattern<(OpNativeCodeCall3 $input),
-              [(NativeCodeCall<"createOpI($_builder, $0)"> $input), (OpK)]>;
+              [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input),
+               (OpK)]>;
 
 // Test AllAttrConstraintsOf.
 def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f6607a5f55246..f2a17a9f3f5fa 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -21,8 +21,8 @@ static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
   return choice.getValue() ? input1 : input2;
 }
 
-static void createOpI(PatternRewriter &rewriter, Value input) {
-  rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
+static void createOpI(PatternRewriter &rewriter, Location loc, Value input) {
+  rewriter.create<OpI>(loc, input);
 }
 
 static void handleNoResultOp(PatternRewriter &rewriter,

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 5025ee5216dd6..9884d1ccb077d 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -112,6 +112,9 @@ class PatternEmitter {
   // Returns the symbol of the old value serving as the replacement.
   StringRef handleReplaceWithValue(DagNode tree);
 
+  // Returns {last arg is a location directive?, location value to use}.
+  std::pair<bool, std::string> getLocation(DagNode tree);
+
   // Returns the location value to use.
   std::string handleLocationDirective(DagNode tree);
 
@@ -779,13 +782,18 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
     PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
                              Twine(tree.getNumArgs()));
   }
-  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+  bool hasLocationDirective;
+  std::string locToUse;
+  std::tie(hasLocationDirective, locToUse) = getLocation(tree);
+
+  for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
     attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
                             << " replacement: " << attrs[i] << "\n");
   }
-  return std::string(tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3],
-                           attrs[4], attrs[5], attrs[6], attrs[7]));
+  return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
+                           attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
+                           attrs[6], attrs[7]));
 }
 
 int PatternEmitter::getNodeValueCount(DagNode node) {
@@ -804,6 +812,20 @@ int PatternEmitter::getNodeValueCount(DagNode node) {
   return 1;
 }
 
+std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
+  auto numPatArgs = tree.getNumArgs();
+  // Returns {true, loc} when the last argument is a location directive.
+  if (numPatArgs != 0) {
+    if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
+      if (lastArg.isLocationDirective()) {
+        return std::make_pair(true, handleLocationDirective(lastArg));
+      }
+  }
+
+  // Otherwise returns {false, "odsLoc"}: the default, fully fused location.
+  return std::make_pair(false, "odsLoc");
+}
+
 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                              int depth) {
   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
@@ -814,15 +836,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   auto numOpArgs = resultOp.getNumArgs();
   auto numPatArgs = tree.getNumArgs();
 
-  // Get the location for this operation if explicitly provided.
+  bool hasLocationDirective;
   std::string locToUse;
-  if (numPatArgs != 0) {
-    if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
-      if (lastArg.isLocationDirective())
-        locToUse = handleLocationDirective(lastArg);
-  }
+  std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
-  auto inPattern = numPatArgs - !locToUse.empty();
+  auto inPattern = numPatArgs - hasLocationDirective;
   if (numOpArgs != inPattern) {
     PrintFatalError(loc,
                     formatv("resultant op '{0}' argument number mismatch: "
@@ -830,10 +848,6 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                             resultOp.getOperationName(), inPattern, numOpArgs));
   }
 
-  // If no explicit location is given, use the default, all fused, location.
-  if (locToUse.empty())
-    locToUse = "odsLoc";
-
   // A map to collect all nested DAG child nodes' names, with operand index as
   // the key. This includes both bound and unbound child nodes.
   ChildNodeIndexNameMap childNodeNames;


        


More information about the Mlir-commits mailing list