[Mlir-commits] [mlir] 5bc9cc1 - [drr] Enable specifying range in NativeCodeCall replacement.

Jacques Pienaar llvmlistbot at llvm.org
Mon Jun 28 13:42:31 PDT 2021


Author: Jacques Pienaar
Date: 2021-06-28T13:42:16-07:00
New Revision: 5bc9cc1332aa042b68fb5efa9fb50eaaf2d54f79

URL: https://github.com/llvm/llvm-project/commit/5bc9cc1332aa042b68fb5efa9fb50eaaf2d54f79
DIFF: https://github.com/llvm/llvm-project/commit/5bc9cc1332aa042b68fb5efa9fb50eaaf2d54f79.diff

LOG: [drr] Enable specifying range in NativeCodeCall replacement.

This enables creating a replacement rule where the range of positional
replacements need not be spelled out, or is not known in advance (e.g., a
rewrite that forwards all operands to a call generically).

Differential Revision: https://reviews.llvm.org/D104955

Added: 
    

Modified: 
    mlir/docs/DeclarativeRewrites.md
    mlir/include/mlir/TableGen/Format.h
    mlir/lib/TableGen/Format.cpp
    mlir/test/mlir-tblgen/rewriter-indexing.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 27ae161978c8a..5815035ca77e8 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -11,8 +11,8 @@ compiler build time.
 This manual explains in detail all of the available mechanisms for defining
 rewrite rules in such a declarative manner. It aims to be a specification
 instead of a tutorial. Please refer to
-[Quickstart tutorial to adding MLIR graph
-rewrite](Tutorials/QuickstartRewrites.md) for the latter.
+[Quickstart tutorial to adding MLIR graph rewrite](Tutorials/QuickstartRewrites.md)
+for the latter.
 
 Given that declarative rewrite rules depend on op definition specification, this
 manual assumes knowledge of the [ODS](OpDefinitions.md) doc.
@@ -51,8 +51,8 @@ features:
 *   Matching multi-result ops in nested patterns.
 *   Matching and generating variadic operand/result ops in nested patterns.
 *   Packing and unpacking variadic operands/results during generation.
-*   [`NativeCodeCall`](#nativecodecall-transforming-the-generated-op)
-    returning more than one results.
+*   [`NativeCodeCall`](#nativecodecall-transforming-the-generated-op) returning
+    more than one results.
 
 ## Rule Definition
 
@@ -93,9 +93,9 @@ Each pattern is specified as a TableGen `dag` object with the syntax of
 [directives](#rewrite-directives). `argN` is for matching (if used in source
 pattern) or generating (if used in result pattern) the `N`-th argument for
 `operator`. If the `operator` is some MLIR operation, it means the `N`-th
-argument as specified in the `arguments` list of the op's definition.
-Therefore, we say op argument specification in pattern is **position-based**:
-the position where they appear matters.
+argument as specified in the `arguments` list of the op's definition. Therefore,
+we say op argument specification in pattern is **position-based**: the position
+where they appear matters.
 
 `argN` can be a `dag` object itself, thus we can have nested `dag` tree to model
 the def-use relationship between ops.
@@ -245,15 +245,15 @@ the pattern by following the exact same order as the ODS `arguments` definition.
 Otherwise, a custom `build()` method that matches the argument list is required.
 
 Right now all ODS-generated `build()` methods require specifying the result
-type(s), unless the op has known traits like `SameOperandsAndResultType` that
-we can use to auto-generate a `build()` method with result type deduction.
-When generating an op to replace the result of the matched root op, we can use
-the matched root op's result type when calling the ODS-generated builder.
-Otherwise (e.g., generating an [auxiliary op](#supporting-auxiliary-ops) or
-generating an op with a nested result pattern), DRR will not be able to deduce
-the result type(s). The pattern author will need to define a custom builder
-that has result type deduction ability via `OpBuilder` in ODS. For example,
-in the following pattern
+type(s), unless the op has known traits like `SameOperandsAndResultType` that we
+can use to auto-generate a `build()` method with result type deduction. When
+generating an op to replace the result of the matched root op, we can use the
+matched root op's result type when calling the ODS-generated builder. Otherwise
+(e.g., generating an [auxiliary op](#supporting-auxiliary-ops) or generating an
+op with a nested result pattern), DRR will not be able to deduce the result
+type(s). The pattern author will need to define a custom builder that has result
+type deduction ability via `OpBuilder` in ODS. For example, in the following
+pattern
 
 ```tablegen
 def : Pat<(AOp $input, $attr), (COp (AOp $input, $attr) $attr)>;
@@ -295,8 +295,8 @@ to replace the matched `AOp`.
 
 In the result pattern, we can bind to the result(s) of a newly built op by
 attaching symbols to the op. (But we **cannot** bind to op arguments given that
-they are referencing previously bound symbols.) This is useful for reusing
-newly created results where suitable. For example,
+they are referencing previously bound symbols.) This is useful for reusing newly
+created results where suitable. For example,
 
 ```tablegen
 def DOp : Op<"d_op"> {
@@ -373,18 +373,18 @@ And make sure the generated C++ code from the above pattern has access to the
 definition of the C++ helper function.
 
 In the above example, we are using a string to specialize the `NativeCodeCall`
-template. The string can be an arbitrary C++ expression that evaluates into
-some C++ object expected at the `NativeCodeCall` site (here it would be
-expecting an array attribute). Typically the string should be a function call.
+template. The string can be an arbitrary C++ expression that evaluates into some
+C++ object expected at the `NativeCodeCall` site (here it would be expecting an
+array attribute). Typically the string should be a function call.
 
 Note that currently `NativeCodeCall` must return no more than one value or
 attribute. This might change in the future.
 
 ##### `NativeCodeCall` placeholders
 
-In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N`. The former
-is called _special placeholder_, while the latter is called _positional
-placeholder_.
+In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`.
+The former is called _special placeholder_, while the latter is called
+_positional placeholder_ and _positional range placeholder_.
 
 `NativeCodeCall` right now only supports three special placeholders:
 `$_builder`, `$_loc`, and `$_self`:
@@ -423,6 +423,11 @@ the `NativeCodeCall` use site. For example, if we define `SomeCall :
 NativeCodeCall<"someFn($1, $2, $0)">` and use it like `(SomeCall $in0, $in1,
 $in2)`, then this will be translated into C++ call `someFn($in1, $in2, $in0)`.
 
+Positional range placeholders will be substituted by multiple `dag` object
+parameters at the `NativeCodeCall` use site. For example, if we define
+`SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall $in0,
+$in1, $in2)`, then this will be translated into C++ call `someFn($in1, $in2)`.
+
 ##### Customizing entire op building
 
 `NativeCodeCall` is not only limited to transforming arguments for building an
@@ -490,8 +495,8 @@ matched op.
 
 Multi-result ops bring extra complexity to declarative rewrite rules. We use
 TableGen `dag` objects to represent ops in patterns; there is no native way to
-indicate that an op generates multiple results. The approach adopted is based
-on **naming convention**: a `__N` suffix is added to a symbol to indicate the
+indicate that an op generates multiple results. The approach adopted is based on
+**naming convention**: a `__N` suffix is added to a symbol to indicate the
 `N`-th result.
 
 #### `__N` suffix
@@ -541,12 +546,12 @@ The above example also shows how to replace a matched multi-result op.
 
 To replace an `N`-result op, the result patterns must generate at least `N`
 declared values (see [Declared vs. actual value](#declared-vs-actual-value) for
-definition). If there are more than `N` declared values generated, only the
-last `N` declared values will be used to replace the matched op. Note that
-because of the existence of multi-result op, one result pattern **may** generate
-multiple declared values. So it means we do not necessarily need `N` result
-patterns to replace an `N`-result op. For example, to replace an op with three
-results, you can have
+definition). If there are more than `N` declared values generated, only the last
+`N` declared values will be used to replace the matched op. Note that because of
+the existence of multi-result op, one result pattern **may** generate multiple
+declared values. So it means we do not necessarily need `N` result patterns to
+replace an `N`-result op. For example, to replace an op with three results, you
+can have
 
 ```tablegen
 // ThreeResultOp/TwoResultOp/OneResultOp generates three/two/one result(s),
@@ -590,8 +595,8 @@ regarding an op's values.
 *   _Actual operand/result/value_: an operand/result/value of an op instance at
     runtime
 
-The above terms are needed because ops can have multiple results, and some of the
-results can also be variadic. For example,
+The above terms are needed because ops can have multiple results, and some of
+the results can also be variadic. For example,
 
 ```tablegen
 def MultiVariadicOp : Op<"multi_variadic_op"> {
@@ -611,8 +616,8 @@ def MultiVariadicOp : Op<"multi_variadic_op"> {
 
 We say the above op has 3 declared operands and 3 declared results. But at
 runtime, an instance can have 3 values corresponding to `$input2` and 2 values
-correspond to `$output2`; we say it has 5 actual operands and 4 actual
-results. A variadic operand/result is a considered as a declared value that can
+correspond to `$output2`; we say it has 5 actual operands and 4 actual results.
+A variadic operand/result is a considered as a declared value that can
 correspond to multiple actual values.
 
 [TODO]
@@ -651,10 +656,10 @@ You can
 
 ### Adjusting benefits
 
-The benefit of a `Pattern` is an integer value indicating the benefit of matching
-the pattern. It determines the priorities of patterns inside the pattern rewrite
-driver. A pattern with a higher benefit is applied before one with a lower
-benefit.
+The benefit of a `Pattern` is an integer value indicating the benefit of
+matching the pattern. It determines the priorities of patterns inside the
+pattern rewrite driver. A pattern with a higher benefit is applied before one
+with a lower benefit.
 
 In DRR, a rule is set to have a benefit of the number of ops in the source
 pattern. This is based on the heuristics and assumptions that:
@@ -662,7 +667,6 @@ pattern. This is based on the heuristics and assumptions that:
 *   Larger matches are more beneficial than smaller ones.
 *   If a smaller one is applied first the larger one may not apply anymore.
 
-
 The fourth parameter to `Pattern` (and `Pat`) allows to manually tweak a
 pattern's benefit. Just supply `(addBenefit N)` to add `N` to the benefit value.
 
@@ -696,8 +700,8 @@ def : Pat<(LocSrc1Op:$src1 (LocSrc2Op:$src2 ...),
           (LocDst1Op (LocDst2Op ..., (location $src2)), (location "outer"))>;
 ```
 
-In the above pattern, the generated `LocDst2Op` will use the matched location
-of `LocSrc2Op` while the root `LocDst1Op` node will used the named location
+In the above pattern, the generated `LocDst2Op` will use the matched location of
+`LocSrc2Op` while the root `LocDst1Op` node will used the named location
 `outer`.
 
 ### `replaceWithValue`
@@ -724,8 +728,8 @@ The above pattern removes the `Foo` and replaces all uses of `Foo` with
 
 ### Run `mlir-tblgen` to see the generated content
 
-TableGen syntax sometimes can be obscure; reading the generated content can be
-a very helpful way to understand and debug issues. To build `mlir-tblgen`, run
+TableGen syntax sometimes can be obscure; reading the generated content can be a
+very helpful way to understand and debug issues. To build `mlir-tblgen`, run
 `cmake --build . --target mlir-tblgen` in your build directory and find the
 `mlir-tblgen` binary in the `bin/` subdirectory. All the supported generators
 can be found via `mlir-tblgen --help`.

diff  --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h
index 441e05c29f264..3120f6ef5766c 100644
--- a/mlir/include/mlir/TableGen/Format.h
+++ b/mlir/include/mlir/TableGen/Format.h
@@ -88,22 +88,33 @@ class FmtContext {
 
 /// Struct representing a replacement segment for the formatted string. It can
 /// be a segment of the formatting template (for `Literal`) or a replacement
-/// parameter (for `PositionalPH` and `SpecialPH`).
+/// parameter (for `PositionalPH`, `PositionalRangePH` and `SpecialPH`).
 struct FmtReplacement {
-  enum class Type { Empty, Literal, PositionalPH, SpecialPH };
+  enum class Type {
+    Empty,
+    Literal,
+    PositionalPH,
+    PositionalRangePH,
+    SpecialPH
+  };
 
   FmtReplacement() = default;
   explicit FmtReplacement(StringRef literal)
       : type(Type::Literal), spec(literal) {}
   FmtReplacement(StringRef spec, size_t index)
       : type(Type::PositionalPH), spec(spec), index(index) {}
+  FmtReplacement(StringRef spec, size_t index, size_t end)
+      : type(Type::PositionalRangePH), spec(spec), index(index), end(end) {}
   FmtReplacement(StringRef spec, FmtContext::PHKind placeholder)
       : type(Type::SpecialPH), spec(spec), placeholder(placeholder) {}
 
   Type type = Type::Empty;
   StringRef spec;
   size_t index = 0;
+  size_t end = kUnset;
   FmtContext::PHKind placeholder = FmtContext::PHKind::None;
+
+  static constexpr size_t kUnset = -1;
 };
 
 class FmtObjectBase {
@@ -121,7 +132,7 @@ class FmtObjectBase {
   // std::vector<Base*>.
   struct CreateAdapters {
     template <typename... Ts>
-    std::vector<llvm::detail::format_adapter *> operator()(Ts &... items) {
+    std::vector<llvm::detail::format_adapter *> operator()(Ts &...items) {
       return std::vector<llvm::detail::format_adapter *>{&items...};
     }
   };
@@ -205,7 +216,8 @@ class FmtStrVecObject : public FmtObjectBase {
 ///
 /// There are two categories of placeholders accepted, both led by a '$' sign:
 ///
-/// 1. Positional placeholder: $[0-9]+
+/// 1.a Positional placeholder: $[0-9]+
+/// 1.b Positional range placeholder: $[0-9]+...
 /// 2. Special placeholder:    $[a-zA-Z_][a-zA-Z0-9_]*
 ///
 /// Replacement parameters for positional placeholders are supplied as the
@@ -214,6 +226,9 @@ class FmtStrVecObject : public FmtObjectBase {
 /// can use the positional placeholders in any order and repeat any times, for
 /// example, "$2 $1 $1 $0" is accepted.
 ///
+/// Replace parameters for positional range placeholders are supplied as if
+/// positional placeholders were specified with commas separating them.
+///
 /// Replacement parameters for special placeholders are supplied using the `ctx`
 /// format context.
 ///
@@ -237,7 +252,7 @@ class FmtStrVecObject : public FmtObjectBase {
 /// 2. This utility does not support format layout because it is rarely needed
 ///    in C++ code generation.
 template <typename... Ts>
-inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
+inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals)
     -> FmtObject<decltype(std::make_tuple(
         llvm::detail::build_format_adapter(std::forward<Ts>(vals))...))> {
   using ParamTuple = decltype(std::make_tuple(

diff  --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp
index 10834510b7674..4a0bbdf7f346c 100644
--- a/mlir/lib/TableGen/Format.cpp
+++ b/mlir/lib/TableGen/Format.cpp
@@ -97,7 +97,8 @@ FmtObjectBase::splitFmtSegment(StringRef fmt) {
   // First try to see if it's a positional placeholder, and then handle special
   // placeholders.
 
-  size_t end = fmt.find_if_not([](char c) { return std::isdigit(c); }, 1);
+  size_t end =
+      fmt.find_if_not([](char c) { return std::isdigit(c); }, /*From=*/1);
   if (end != 1) {
     // We have a positional placeholder. Parse the index.
     size_t index = 0;
@@ -105,6 +106,14 @@ FmtObjectBase::splitFmtSegment(StringRef fmt) {
       llvm_unreachable("invalid replacement sequence index");
     }
 
+    // Check if this is the part of a range specification.
+    if (fmt.substr(end, 3) == "...") {
+      // Currently only ranges without upper bound are supported.
+      return {
+          FmtReplacement{fmt.substr(0, end + 3), index, FmtReplacement::kUnset},
+          fmt.substr(end + 3)};
+    }
+
     if (end == StringRef::npos) {
       // All the remaining characters are part of the positional placeholder.
       return {FmtReplacement{fmt, index}, StringRef()};
@@ -164,6 +173,20 @@ void FmtObjectBase::format(raw_ostream &s) const {
       continue;
     }
 
+    if (repl.type == FmtReplacement::Type::PositionalRangePH) {
+      if (repl.index >= adapters.size()) {
+        s << repl.spec << kMarkerForNoSubst;
+        continue;
+      }
+      auto range = llvm::makeArrayRef(adapters);
+      range = range.drop_front(repl.index);
+      if (repl.end != FmtReplacement::kUnset)
+        range = range.drop_back(adapters.size() - repl.end);
+      llvm::interleaveComma(range, s,
+                            [&](auto &x) { x->format(s, /*Options=*/""); });
+      continue;
+    }
+
     assert(repl.type == FmtReplacement::Type::PositionalPH);
 
     if (repl.index >= adapters.size()) {

diff  --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
index cbdeff9c743da..f4f055e1c0c40 100644
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -85,3 +85,8 @@ def NativeBuilder :
 // CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
 def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
                 (NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
+
+// CHECK: struct test5 : public ::mlir::RewritePattern {
+// CHECK: foo(rewriter, (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
+def test5 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
+                (NativeCodeCall<[{ foo($_builder, $3...) }]> $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
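
The substitution rule added by this patch can be sketched outside of MLIR as follows. This is a minimal standalone illustration, not the code from `Format.cpp`: the function name `expandPlaceholders` and the `<no-subst>` marker text are hypothetical, arguments are plain strings rather than format adapters, and special placeholders such as `$_builder` are simply passed through instead of being resolved from a `FmtContext`. As in the patch, only open-ended ranges (`$N...`) are handled.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not part of MLIR): expands `$N` and `$N...` in `fmt`
// using `args`. Anything else starting with '$' (e.g. `$_builder`) is copied
// through unchanged.
std::string expandPlaceholders(const std::string &fmt,
                               const std::vector<std::string> &args) {
  std::string out;
  std::size_t i = 0;
  while (i < fmt.size()) {
    // Copy literal characters until the next '$'.
    if (fmt[i] != '$') {
      out += fmt[i++];
      continue;
    }
    // Scan the run of digits that follows the '$'.
    std::size_t end = i + 1;
    while (end < fmt.size() &&
           std::isdigit(static_cast<unsigned char>(fmt[end])))
      ++end;
    if (end == i + 1) {
      // No digits: not a positional placeholder, keep the '$' as literal text.
      out += fmt[i++];
      continue;
    }
    assert(end <= fmt.size());
    std::size_t index = std::stoul(fmt.substr(i + 1, end - i - 1));
    // A trailing "..." turns the placeholder into an open-ended range.
    bool isRange = fmt.compare(end, 3, "...") == 0;
    if (isRange)
      end += 3;
    if (index >= args.size()) {
      // Out-of-range index: keep the placeholder text and flag it.
      // The marker string here is an assumption made for this sketch.
      out += fmt.substr(i, end - i) + "<no-subst>";
    } else if (isRange) {
      // Emit args[index], args[index + 1], ... separated by commas.
      for (std::size_t k = index; k < args.size(); ++k) {
        if (k != index)
          out += ", ";
        out += args[k];
      }
    } else {
      out += args[index];
    }
    i = end;
  }
  return out;
}
```

Under these assumptions, calling `expandPlaceholders("someFn($1...)", {"$in0", "$in1", "$in2"})` yields `someFn($in1, $in2)`, matching the example added to `DeclarativeRewrites.md`, and a `$3...` range over ten arguments yields the fourth through tenth arguments, matching the shape of the `test5` check line.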


        

