[Mlir-commits] [mlir] 3f7439b - [mlir][DRR] Add location directive

Jacques Pienaar llvmlistbot at llvm.org
Tue Apr 7 13:38:59 PDT 2020


Author: Jacques Pienaar
Date: 2020-04-07T13:38:25-07:00
New Revision: 3f7439b28063c284975b49ebdc9c5645cedae7a0

URL: https://github.com/llvm/llvm-project/commit/3f7439b28063c284975b49ebdc9c5645cedae7a0
DIFF: https://github.com/llvm/llvm-project/commit/3f7439b28063c284975b49ebdc9c5645cedae7a0.diff

LOG: [mlir][DRR] Add location directive

Summary:
Add a directive to indicate the location to give to the op being created.
This directive is optional; if unused, the location is still the fused
location of all source operations.

Currently this directive only works with other op locations: it either
reuses an existing op location or fuses several op locations. It doesn't yet
support supplying metadata for the FusedLoc.

Based on the initial revision by antiagainst@; effectively mirrors the
GlobalISel debug_locations directive.

Differential Revision: https://reviews.llvm.org/D77649
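
As a rough illustration of the behaviour described in the summary, the sketch
below mimics how a location expression could be chosen for a generated op: one
symbol reuses that op's location, several symbols are fused, and no directive
falls back to the default fused location (named `odsLoc` in this change). The
helpers `locOfSymbol`, `emitLocationExpr` and `pickLocation` are hypothetical
names for this sketch only. It also assumes the symbol name can be used
directly in the emitted expression, whereas the real emitter resolves it
through its symbol map.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for the emitter's symbol lookup: returns a C++
// expression that yields the location of the op bound to `symbol`.
std::string locOfSymbol(const std::string &symbol) {
  return "(*" + symbol + ".begin()).getLoc()";
}

// Builds the location expression for `(location $s0, $s1, ...)`.
// Simplification for this sketch: at least one symbol is required.
std::string emitLocationExpr(const std::vector<std::string> &symbols) {
  assert(!symbols.empty() && "location needs at least one symbol");

  // A single symbol reuses that op's location as-is.
  if (symbols.size() == 1)
    return locOfSymbol(symbols.front());

  // Several symbols are combined into one fused location.
  std::string expr = "rewriter.getFusedLoc({";
  for (std::size_t i = 0; i < symbols.size(); ++i) {
    if (i != 0)
      expr += ", ";
    expr += locOfSymbol(symbols[i]);
  }
  expr += "})";
  return expr;
}

// Without a location directive, the default fused location of all source
// ops is used.
std::string pickLocation(const std::vector<std::string> &symbols) {
  return symbols.empty() ? std::string("odsLoc") : emitLocationExpr(symbols);
}
```

For the test pattern in this commit, `pickLocation({"res1"})` would stand for
`(location $res1)` and `pickLocation({"res2", "res3"})` for
`(location $res2, $res3)`.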

Added: 
    

Modified: 
    mlir/docs/DeclarativeRewrites.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 82b97e04a8c0..ca759f44d42c 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -657,9 +657,53 @@ pattern. This is based on the heuristics and assumptions that:
 The fourth parameter to `Pattern` (and `Pat`) allows manually tweaking a
 pattern's benefit. Just supply `(addBenefit N)` to add `N` to the benefit value.
 
-## Special directives
+## Rewrite directives
 
-[TODO]
+### `location`
+
+By default the C++ pattern expanded from a DRR pattern uses the fused location
+of all source ops as the location for all generated ops. This is not always the
+best location mapping relationship. For such cases, DRR provides the `location`
+directive to provide finer control.
+
+`location` is of the following syntax:
+
+```tablegen
+(location $symbol0, $symbol1, ...)
+```
+
+where all `$symbol` should be bound previously in the pattern.
+
+`location` must be used as the last argument to an op creation. For example,
+
+```tablegen
+def : Pat<(LocSrc1Op:$src1 (LocSrc2Op:$src2 ...),
+          (LocDst1Op (LocDst2Op ..., (location $src2)))>;
+```
+
+In the above pattern, the generated `LocDst2Op` will use the matched location
+of `LocSrc2Op` while the root `LocDst1Op` node will still use the fused location
+of all source ops.
+
+### `replaceWithValue`
+
+The `replaceWithValue` directive is used to eliminate a matched op by replacing
+all of its uses with a captured value. It is of the following syntax:
+
+```tablegen
+(replaceWithValue $symbol)
+```
+
+where `$symbol` should be a symbol bound previously in the pattern.
+
+For example,
+
+```tablegen
+def : Pat<(Foo $input), (replaceWithValue $input)>;
+```
+
+The above pattern removes the `Foo` and replaces all uses of `Foo` with
+`$input`.
 
 ## Debugging Tips
 

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 25f062b02d15..09cea1b1f5ea 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2179,9 +2179,14 @@ class NativeCodeCall<string expr> {
 }
 
 //===----------------------------------------------------------------------===//
-// Common directives
+// Rewrite directives
 //===----------------------------------------------------------------------===//
 
+// Directive used in result pattern to specify the location of the generated
+// op. This directive must be used as the last argument to the op creation
+// DAG construct. The arguments to location must be previously captured symbols.
+def location;
+
 // Directive used in result pattern to indicate that no new op are generated,
 // so to replace the matched DAG with an existing SSA value.
 def replaceWithValue;

diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 0ed413368fa0..e7fa48dc8829 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -159,6 +159,9 @@ class DagNode {
   // value.
   bool isReplaceWithValue() const;
 
+  // Returns whether this DAG represents the location of an op creation.
+  bool isLocationDirective() const;
+
   // Returns true if this DAG node is wrapping native code call.
   bool isNativeCodeCall() const;
 

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 6d4b03bd0f43..5b547089efca 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -103,7 +103,7 @@ bool tblgen::DagNode::isNativeCodeCall() const {
 }
 
 bool tblgen::DagNode::isOperation() const {
-  return !(isNativeCodeCall() || isReplaceWithValue());
+  return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
 }
 
 llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
@@ -159,6 +159,11 @@ bool tblgen::DagNode::isReplaceWithValue() const {
   return dagOpDef->getName() == "replaceWithValue";
 }
 
+bool tblgen::DagNode::isLocationDirective() const {
+  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
+  return dagOpDef->getName() == "location";
+}
+
 void tblgen::DagNode::print(raw_ostream &os) const {
   if (node)
     node->print(os);
@@ -533,7 +538,14 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
   auto numOpArgs = op.getNumArgs();
   auto numTreeArgs = tree.getNumArgs();
 
-  if (numOpArgs != numTreeArgs) {
+  // The pattern might have the last argument specifying the location.
+  bool hasLocDirective = false;
+  if (numTreeArgs != 0) {
+    if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
+      hasLocDirective = lastArg.isLocationDirective();
+  }
+
+  if (numOpArgs != numTreeArgs - hasLocDirective) {
     auto err = formatv("op '{0}' argument number mismatch: "
                        "{1} in pattern vs. {2} in definition",
                        op.getOperationName(), numTreeArgs, numOpArgs);

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 061960959cd8..8859d50342af 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -501,6 +501,20 @@ def StringAttrPrettyNameOp
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Locations
+//===----------------------------------------------------------------------===//
+
+def TestLocationSrcOp : TEST_Op<"loc_src"> {
+  let arguments = (ins I32:$input);
+  let results = (outs I32:$output);
+}
+
+def TestLocationDstOp : TEST_Op<"loc_dst", [SameOperandsAndResultType]> {
+  let arguments = (ins I32:$input);
+  let results = (outs I32:$output);
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns
 //===----------------------------------------------------------------------===//
@@ -995,6 +1009,18 @@ def : Pat<(OneI32ResultOp),
               (replaceWithValue $results__2),
               ConstantAttr<I32Attr, "2">)>;
 
+//===----------------------------------------------------------------------===//
+// Test Patterns (Location)
+
+// Test that we can specify locations for generated ops.
+def : Pat<(TestLocationSrcOp:$res1
+           (TestLocationSrcOp:$res2
+            (TestLocationSrcOp:$res3 $input))),
+          (TestLocationDstOp
+            (TestLocationDstOp
+              (TestLocationDstOp $input, (location $res1))),
+	    (location $res2, $res3))>;
+
 //===----------------------------------------------------------------------===//
 // Test Legalization
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 67ea2fd91809..a96d90f4ed2c 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s
+// RUN: mlir-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: verifyFusedLocs
 func @verifyFusedLocs(%arg0 : i32) -> i32 {
@@ -10,6 +10,21 @@ func @verifyFusedLocs(%arg0 : i32) -> i32 {
   return %result : i32
 }
 
+// CHECK-LABEL: verifyDesignatedLoc
+func @verifyDesignatedLoc(%arg0 : i32) -> i32 {
+  %0 = "test.loc_src"(%arg0) : (i32) -> i32 loc("loc3")
+  %1 = "test.loc_src"(%0) : (i32) -> i32 loc("loc2")
+  %2 = "test.loc_src"(%1) : (i32) -> i32 loc("loc1")
+
+  // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("loc1")
+  // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused[
+  // CHECK-SAME: "loc1"
+  // CHECK-SAME: "loc3"
+  // CHECK-SAME: "loc2"
+  // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused["loc2", "loc3"])
+  return %1 : i32
+}
+
 // CHECK-LABEL: verifyZeroResult
 func @verifyZeroResult(%arg0 : i32) {
   // CHECK: "test.op_i"(%arg0) : (i32) -> ()

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index a4843167c83c..73a85259e9ed 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -109,9 +109,11 @@ class PatternEmitter {
   // calling native C++ code.
   std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
 
-  // Returns the C++ expression referencing the old value serving as the
-  // replacement.
-  std::string handleReplaceWithValue(DagNode tree);
+  // Returns the symbol of the old value serving as the replacement.
+  StringRef handleReplaceWithValue(DagNode tree);
+
+  // Returns the symbol of the value whose location to use.
+  std::string handleUseLocationOf(DagNode tree);
 
   // Emits the C++ statement to build a new op out of the given DAG `tree` and
   // returns the variable name that this op is assigned to. If the root op in
@@ -580,11 +582,11 @@ void PatternEmitter::emitRewriteLogic() {
     PrintFatalError(loc, error);
   }
 
-  os.indent(4) << "auto loc = rewriter.getFusedLoc({";
+  os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({";
   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
   }
-  os << "}); (void)loc;\n";
+  os << "}); (void)odsLoc;\n";
 
   // Process auxiliary result patterns.
   for (int i = 0; i < replStartIndex; ++i) {
@@ -640,15 +642,19 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << '\n');
 
+  if (resultTree.isLocationDirective()) {
+    PrintFatalError(loc,
+                    "location directive can only be used with op creation");
+  }
+
   if (resultTree.isNativeCodeCall()) {
     auto symbol = handleReplaceWithNativeCodeCall(resultTree);
     symbolInfoMap.bindValue(symbol);
     return symbol;
   }
 
-  if (resultTree.isReplaceWithValue()) {
-    return handleReplaceWithValue(resultTree);
-  }
+  if (resultTree.isReplaceWithValue())
+    return handleReplaceWithValue(resultTree).str();
 
   // Normal op creation.
   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
@@ -660,7 +666,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   return symbol;
 }
 
-std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
+StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
   assert(tree.isReplaceWithValue());
 
   if (tree.getNumArgs() != 1) {
@@ -672,7 +678,30 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
   }
 
-  return std::string(tree.getArgName(0));
+  return tree.getArgName(0);
+}
+
+std::string PatternEmitter::handleUseLocationOf(DagNode tree) {
+  assert(tree.isLocationDirective());
+  auto lookUpArgLoc = [this, &tree](int idx) {
+    const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
+    return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
+  };
+
+  if (tree.getNumArgs() != 1) {
+    std::string ret;
+    llvm::raw_string_ostream os(ret);
+    os << "rewriter.getFusedLoc({";
+    for (int i = 0, e = tree.getNumArgs(); i != e; ++i)
+      os << (i ? ", " : "") << lookUpArgLoc(i);
+    os << "})";
+    return os.str();
+  }
+
+  if (!tree.getSymbol().empty())
+    PrintFatalError(loc, "cannot bind symbol to location");
+
+  return lookUpArgLoc(0);
 }
 
 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
@@ -753,14 +782,28 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
 
   Operator &resultOp = tree.getDialectOp(opMap);
   auto numOpArgs = resultOp.getNumArgs();
+  auto numPatArgs = tree.getNumArgs();
+
+  // Get the location for this operation if explicitly provided.
+  std::string locToUse;
+  if (numPatArgs != 0) {
+    if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
+      if (lastArg.isLocationDirective())
+        locToUse = handleUseLocationOf(lastArg);
+  }
 
-  if (numOpArgs != tree.getNumArgs()) {
-    PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
-                                 "{1} in pattern vs. {2} in definition",
-                                 resultOp.getOperationName(), tree.getNumArgs(),
-                                 numOpArgs));
+  auto inPattern = numPatArgs - !locToUse.empty();
+  if (numOpArgs != inPattern) {
+    PrintFatalError(loc,
+                    formatv("resultant op '{0}' argument number mismatch: "
+                            "{1} in pattern vs. {2} in definition",
+                            resultOp.getOperationName(), inPattern, numOpArgs));
   }
 
+  // If no explicit location is given, use the default, all fused, location.
+  if (locToUse.empty())
+    locToUse = "odsLoc";
+
   // A map to collect all nested DAG child nodes' names, with operand index as
   // the key. This includes both bound and unbound child nodes.
   ChildNodeIndexNameMap childNodeNames;
@@ -769,9 +812,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   // create ops for them and remember the symbol names for them, so that we can
   // use the results in the current node. This happens in a recursive manner.
   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
-    if (auto child = tree.getArgAsNestedDag(i)) {
+    if (auto child = tree.getArgAsNestedDag(i))
       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
-    }
   }
 
   // The name of the local variable holding this op.
@@ -811,10 +853,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
 
     // First prepare local variables for op arguments used in builder call.
     createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+
     // Then create the op.
     os.indent(6) << formatv(
-        "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n",
-        valuePackName, resultOp.getQualCppClassName());
+        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n",
+        valuePackName, resultOp.getQualCppClassName(), locToUse);
     os.indent(4) << "}\n";
     return resultValue;
   }
@@ -831,8 +874,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     // here given that it's easier for developers to write compared to
     // aggregate-parameter builders.
     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
-    os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
-                            resultOp.getQualCppClassName());
+
+    os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
+                            resultOp.getQualCppClassName(), locToUse);
     supplyValuesForOpArgs(tree, childNodeNames);
     os << "\n      );\n";
     os.indent(4) << "}\n";
@@ -858,9 +902,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                               "tblgen_types.push_back(v.getType()); }\n",
                               resultIndex + i);
   }
-  os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, "
+  os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
                           "tblgen_values, tblgen_attrs);\n",
-                          valuePackName, resultOp.getQualCppClassName());
+                          valuePackName, resultOp.getQualCppClassName(),
+                          locToUse);
   os.indent(4) << "}\n";
   return resultValue;
 }


        

