[Mlir-commits] [mlir] 616c86a - [mlir][drr] Set operand segment in rewrite

Jacques Pienaar llvmlistbot at llvm.org
Thu Oct 19 13:06:25 PDT 2023


Author: Jacques Pienaar
Date: 2023-10-19T13:06:17-07:00
New Revision: 616c86accbf4c9ada37da6fb6b04554dec0fffee

URL: https://github.com/llvm/llvm-project/commit/616c86accbf4c9ada37da6fb6b04554dec0fffee
DIFF: https://github.com/llvm/llvm-project/commit/616c86accbf4c9ada37da6fb6b04554dec0fffee.diff

LOG: [mlir][drr] Set operand segment in rewrite

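This allows basic use of variadic operands in DRR rewrites. The generated rewriter now sets the `operandSegmentSizes` attribute on the result op when three conditions hold: the op has more than one variadic operand, it does not have the `SameVariadicOperandSize` trait, and the pattern does not already set `operandSegmentSizes` itself. Previously, workarounds were needed (such as "aliasing" the attribute). I could not find a way to do this directly with properties.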

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index edb63924b3553f2..1add9bd3c329438 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2495,6 +2495,29 @@ def TestDefaultStrAttrHasValueOp : TEST_Op<"has_str_value"> {
 def : Pat<(TestDefaultStrAttrNoValueOp $value),
           (TestDefaultStrAttrHasValueOp ConstantStrAttr<StrAttr, "foo">)>;
 
+//===----------------------------------------------------------------------===//
+// Test Ops with variadics (DRR must set operandSegmentSizes on the result)
+//===----------------------------------------------------------------------===//
+
+def TestVariadicRewriteSrcOp : TEST_Op<"variadic_rewrite_src_op", [AttrSizedOperandSegments]> {
+  let arguments = (ins
+    Variadic<AnyType>:$arg,
+    AnyType:$brg,
+    Variadic<AnyType>:$crg
+  );
+}
+
+def TestVariadicRewriteDstOp : TEST_Op<"variadic_rewrite_dst_op", [AttrSizedOperandSegments]> {
+  let arguments = (ins
+    AnyType:$brg,
+    Variadic<AnyType>:$crg,
+    Variadic<AnyType>:$arg
+  );
+}
+
+def : Pat<(TestVariadicRewriteSrcOp $arg, $brg, $crg),
+          (TestVariadicRewriteDstOp $brg, $crg, $arg)>;
+
 //===----------------------------------------------------------------------===//
 // Test Ops with Default-Valued Attributes and Differing Print Settings
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 5f776338bd40be8..7f9c450f15b21e8 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -683,3 +683,16 @@ func.func @testConstantStrAttr() -> () {
   test.no_str_value {value = "bar"}
   return
 }
+
+//===----------------------------------------------------------------------===//
+// Test that patterns with variadics propagate operandSegmentSizes
+//===----------------------------------------------------------------------===//
+
+func.func @testVariadic(%arg_0: i32, %arg_1: i32, %brg: i64,
+    %crg_0: f32, %crg_1: f32, %crg_2: f32, %crg_3: f32) -> () {
+  // CHECK: "test.variadic_rewrite_dst_op"(%arg2, %arg3, %arg4, %arg5, %arg6, %arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 4, 2>}> : (i64, f32, f32, f32, f32, i32, i32) -> ()
+  "test.variadic_rewrite_src_op"(%arg_0, %arg_1, %brg,
+    %crg_0, %crg_1, %crg_2, %crg_3) {operandSegmentSizes = array<i32: 2, 1, 4>} :
+    (i32, i32, i64, f32, f32, f32, f32) -> ()
+  return
+}

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 9f36a3b430274a3..77c34cb03e987ea 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1743,10 +1743,15 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
       "if (auto tmpAttr = {1}) {\n"
       "  tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
       "tmpAttr);\n}\n";
+  int numVariadic = 0;
+  bool hasOperandSegmentSizes = false;
+  std::vector<std::string> sizes;
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
       // The argument in the op definition.
       auto opArgName = resultOp.getArgName(argIndex);
+      hasOperandSegmentSizes =
+          hasOperandSegmentSizes || opArgName == "operandSegmentSizes";
       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
         if (!subTree.isNativeCodeCall())
           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
@@ -1766,6 +1771,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
     std::string varName;
     if (operand->isVariadic()) {
+      ++numVariadic;
       std::string range;
       if (node.isNestedDagArg(argIndex)) {
         range = childNodeNames.lookup(argIndex);
@@ -1777,7 +1783,9 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
       range = symbolInfoMap.getValueAndRangeUse(range);
       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
                     range);
+      sizes.push_back(formatv("static_cast<int32_t>({0}.size())", range));
     } else {
+      sizes.push_back("1");
       os << formatv("tblgen_values.push_back(");
       if (node.isNestedDagArg(argIndex)) {
         os << symbolInfoMap.getValueAndRangeUse(
@@ -1804,6 +1812,19 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
       os << ");\n";
     }
   }
+
+  if (numVariadic > 1 && !hasOperandSegmentSizes) {
+    // Only set size if it can't be computed.
+    const auto *sameVariadicSize =
+        resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
+    if (!sameVariadicSize) {
+      const char *setSizes = R"(
+        tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
+          rewriter.getDenseI32ArrayAttr({{ {0} }));
+          )";
+      os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+    }
+  }
 }
 
 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,


        


More information about the Mlir-commits mailing list