[Mlir-commits] [mlir] 0581ab6 - [mlir][linalg][transform] Support matching of attributes (and their values)

Matthias Springer llvmlistbot at llvm.org
Fri Aug 12 05:55:09 PDT 2022


Author: Matthias Springer
Date: 2022-08-12T14:55:00+02:00
New Revision: 0581ab65ea049069faaa103226948f033fb3fda6

URL: https://github.com/llvm/llvm-project/commit/0581ab65ea049069faaa103226948f033fb3fda6
DIFF: https://github.com/llvm/llvm-project/commit/0581ab65ea049069faaa103226948f033fb3fda6.diff

LOG: [mlir][linalg][transform] Support matching of attributes (and their values)

Do not just check if an attribute exists on the payload op. Also check its value.

Differential Revision: https://reviews.llvm.org/D131760

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-match.mlir
    mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 0ad37ada457d7..11c9b898ef58d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -197,21 +197,21 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
       - ops: an optional StrArrayAttr specifying the concrete name of an op.
         Multiple names can be specified. Matched ops must have one of specified
         names.
-      - attribute: an optional Str specifying the name of an attribute that
-        matched ops must have.
+      - attribute: the matched op must have all specified attributes (with their
+        specified values).
       
     Note: Only ops that satisfy all specified constraints are matched.
 
     TODO: Extend with regions to allow a limited form of constraints.
 
     #### Return modes
-    
+
     This op traverses the ops nested under `target` and returns the handles to
     all the operations that match the requirements.
 
     This op fails if the target is not a handle to exactly one operation. 
     Otherwise it succeeds.
-  
+
     This operation does not consume the target handle and produces new handles:
     it is a navigation op.
   }];
@@ -219,14 +219,14 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
   let arguments = (ins PDL_Operation:$target,
                        OptionalAttr<StrArrayAttr>:$ops,
                        OptionalAttr<MatchInterfaceEnum>:$interface,
-                       OptionalAttr<StrAttr>:$attribute);
+                       OptionalAttr<DictionaryAttr>:$op_attrs);
   // TODO: variadic results when needed.
   let results = (outs PDL_Operation:$results);
 
   let assemblyFormat = [{
-    (`ops` `{` $ops^ `}`)? 
-    (`interface` `{` $interface^ `}`)? 
-    (`attribute` `{` $attribute^ `}`)? 
+    (`ops` `{` $ops^ `}`)?
+    (`interface` `{` $interface^ `}`)?
+    (`attributes` $op_attrs^)?
     `in` $target attr-dict
   }];
 }

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8f2fb306030bd..328c4141e50fa 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -467,8 +467,19 @@ transform::MatchOp::apply(transform::TransformResults &results,
         return WalkResult::advance();
     }
 
-    if (getAttribute().has_value() && !op->hasAttr(getAttribute().value()))
-      return WalkResult::advance();
+    // Check if all specified attributes match.
+    if (getOpAttrs().has_value()) {
+      DictionaryAttr opAttrs = getOpAttrs().value();
+      for (NamedAttribute attr : opAttrs) {
+        if (attr.getName() == getInterfaceAttrName() ||
+            attr.getName() == getOpsAttrName())
+          continue;
+        if (!op->hasAttr(attr.getName()))
+          return WalkResult::advance();
+        if (op->getAttr(attr.getName()) != attr.getValue())
+          return WalkResult::advance();
+      }
+    }
 
     // All constraints are satisfied.
     res.push_back(op);

diff  --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 06af6c8dad96a..002f142f55d1b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -17,8 +17,45 @@ transform.with_pdl_patterns {
     transform.test_print_remark_at_operand %match_name, "matched op name"
     transform.test_consume_operand %match_name
 
-    %match_attr = transform.structured.match ops{["arith.constant"]} attribute{"my_attr"} in %arg1
+    %match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1
     transform.test_print_remark_at_operand %match_attr, "matched attr name"
     transform.test_consume_operand %match_attr
   }
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+func.func @match_complex_attribute(%arg0: tensor<12x128x32xf32>)
+    -> tensor<128x12x32xf32> {
+  %0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32>
+  // expected-remark @below {{matched complex attr}}
+  %1 = linalg.generic {indexing_maps = [#map0, #map1],
+                       iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<12x128x32xf32>)
+    outs(%0 : tensor<128x12x32xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<128x12x32xf32>
+  return %1 : tensor<128x12x32xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %match_attr = transform.structured.match
+        ops{["linalg.generic"]}
+        attributes{iterator_types = ["parallel", "parallel", "parallel"]}
+        in %arg1
+    transform.test_print_remark_at_operand %match_attr, "matched complex attr"
+    transform.test_consume_operand %match_attr
+
+    %no_match = transform.structured.match
+        attributes{iterator_types = ["parallel", "parallel", "reduction"]}
+        in %arg1
+  // expected-remark @below {{0}}
+    transform.test_print_number_of_associated_payload_ir_ops %no_match
+  }
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
index 1109950916ed8..1f6185f664661 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
@@ -45,8 +45,8 @@ module {
     transform.sequence %arg0 {
     ^bb1(%arg1: !pdl.operation):
       // Find the root and all producers.
-      %root = transform.structured.match attribute{"__root__"} in %arg1
-      %producers = transform.structured.match attribute{"__producer__"} in %arg1
+      %root = transform.structured.match attributes{"__root__"} in %arg1
+      %producers = transform.structured.match attributes{"__producer__"} in %arg1
 
       // Tile the root.
       %foreach_thread_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %root num_threads [10, 20]
@@ -105,8 +105,8 @@ module {
     transform.sequence %arg0 {
     ^bb1(%arg1: !pdl.operation):
       // Find the root and all producers.
-      %root = transform.structured.match attribute{"__root__"} in %arg1
-      %producers = transform.structured.match attribute{"__producer__"} in %arg1
+      %root = transform.structured.match attributes{"__root__"} in %arg1
+      %producers = transform.structured.match attributes{"__producer__"} in %arg1
       %reversed_producers = transform.test_reverse_payload_ops %producers
 
       // Tile the root.


        


More information about the Mlir-commits mailing list