[Mlir-commits] [mlir] 726d076 - [mlir][transform] ApplyPatternsOp: Add check to prevent modifying the transform IR

Matthias Springer llvmlistbot at llvm.org
Mon Jun 19 00:04:03 PDT 2023


Author: Matthias Springer
Date: 2023-06-19T08:58:40+02:00
New Revision: 726d076784a85a64f33f52d7714e6af8a47dae83

URL: https://github.com/llvm/llvm-project/commit/726d076784a85a64f33f52d7714e6af8a47dae83
DIFF: https://github.com/llvm/llvm-project/commit/726d076784a85a64f33f52d7714e6af8a47dae83.diff

LOG: [mlir][transform] ApplyPatternsOp: Add check to prevent modifying the transform IR

Add an extra check to make sure that the transform IR is not modified by this op while it is being interpreted. This is generally dangerous, and we may want to enforce this for all transform ops that modify the payload in the future.

Users should generally try to apply patterns only to the piece of IR where they are needed (e.g., a matched function) and not to the entire module (which may contain the transform IR).

This revision is in response to a crash in a downstream compiler that was caused by a dead `transform.structured.match` op that was removed by the GreedyPatternRewriteDriver's DCE while the enclosing sequence was being interpreted.

Differential Revision: https://reviews.llvm.org/D153113

Added: 
    

Modified: 
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Tensor/fold-empty-op.mlir
    mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
    mlir/test/Dialect/Transform/test-pattern-application.mlir
    mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
    mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
    mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
    mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir
    mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
    mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
    mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
    mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index fe07fb3160726..d4852eebf706e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -239,6 +239,22 @@ DiagnosedSilenceableFailure
 transform::ApplyPatternsOp::applyToOne(Operation *target,
                                        ApplyToEachResultList &results,
                                        transform::TransformState &state) {
+  // Make sure that this transform is not applied to itself. Modifying the
+  // transform IR while it is being interpreted is generally dangerous. Even
+  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
+  // performs many additional simplifications such as dead code elimination.
+  Operation *transformAncestor = getOperation();
+  while (transformAncestor) {
+    if (transformAncestor == target) {
+      DiagnosedDefiniteFailure diag =
+          emitDefiniteFailure()
+          << "cannot apply transform to itself (or one of its ancestors)";
+      diag.attachNote(target->getLoc()) << "target payload op";
+      return diag;
+    }
+    transformAncestor = transformAncestor->getParentOp();
+  }
+
   // Gather all specified patterns.
   MLIRContext *ctx = target->getContext();
   RewritePatternSet patterns(ctx);

diff  --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index c3461d250e8ba..f0a6635c0ed89 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -2,7 +2,9 @@
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.tensor.fold_tensor_empty
   } : !transform.any_op
 }
@@ -66,7 +68,9 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tens
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.tensor.fold_tensor_empty
         {fold_single_use_only = true}
   } : !transform.any_op

diff  --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
index 60930a5637ba7..066420bb93661 100644
--- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -2,7 +2,9 @@
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.tensor.rewrite_as_constant
   } : !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index ca277abcc9a8b..5d417f0ce6572 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -155,3 +155,22 @@ transform.sequence failures(propagate) {
   } : !transform.any_op
   transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op
 }
+
+// -----
+
+// expected-note @below{{target payload op}}
+module {
+  func.func @invalid_pattern_application_to_transform_ir() {
+    return
+  }
+
+  module {
+    transform.sequence failures(propagate) {
+    ^bb1(%arg1: !transform.any_op):
+      // expected-error @below {{cannot apply transform to itself (or one of its ancestors)}}
+      transform.apply_patterns to %arg1 {
+        transform.apply_patterns.canonicalization
+      } : !transform.any_op
+    }
+  }
+}

diff  --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index d8365f53b43a1..f93ac8a22453f 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -210,7 +210,9 @@ func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<v
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
   } : !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 5a67c266ee26b..899070971c255 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -19,8 +19,6 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
 //       CHECK:       return %[[RESULT_VEC]]
 
-// -----
-
 func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
     return %0 : f32
@@ -33,8 +31,6 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -
 //       CHECK:   %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
 //       CHECK:   return %[[RES]]
 
-// -----
-
 func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
@@ -76,8 +72,6 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi
 //       CHECK:       %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
 //       CHECK:       return %[[RESULT]]
 
-// -----
-
 func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
     return %0 : vector<2x5xf32>
@@ -90,8 +84,6 @@ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: v
 //       CHECK:     %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
 //       CHECK:       return %[[RESULT]]
 
-// -----
-
 func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
     return %0 : vector<2x4xf32>
@@ -143,8 +135,6 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vecto
 //       CHECK:       %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
 //       CHECK:       return %[[RESHAPED_VEC]]
 
-// -----
-
 func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
@@ -187,8 +177,6 @@ func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf
 // CHECK:           %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
 // CHECK:           %[[VAL_33:.*]] = vector.insertelement
 
-// -----
-
 func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
@@ -207,8 +195,6 @@ func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
 // CHECK:           %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1>
 // CHECK:           %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
 
-// -----
-
 func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
@@ -254,8 +240,6 @@ func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1
 // CHECK:           %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
 // CHECK:           %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
 
-// -----
-
 func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
     return %0 : vector<4xf32>
@@ -267,7 +251,9 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
   } : !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 9f0b3a18d67c9..9630ffc0e77d4 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -190,7 +190,9 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
   } : !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 77bbcb9bbbb84..695400de84890 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s
 
 func.func @transfer_read_rank_reducing(
       %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
@@ -8,44 +8,24 @@ func.func @transfer_read_rank_reducing(
       memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
     return %v : vector<3x2xi8>
 }
-
 // CHECK-LABEL: func @transfer_read_rank_reducing
 //  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]
 
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
-    transform.apply_patterns.vector.rank_reducing_subview_patterns
-  } : !transform.any_op
-}
-
-// -----
-
 func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
       vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
     return
 }
-
 // CHECK-LABEL: func @transfer_write_rank_reducing
 //  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
 
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
-    transform.apply_patterns.vector.rank_reducing_subview_patterns
-  } : !transform.any_op
-}
-
-// -----
-
 func.func @transfer_read_and_vector_rank_reducing(
       %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
     %c0 = arith.constant 0 : index
@@ -54,22 +34,12 @@ func.func @transfer_read_and_vector_rank_reducing(
       memref<1x1x3x2x1xf32>, vector<3x2x1xf32>
     return %v : vector<3x2x1xf32>
 }
-
 // CHECK-LABEL: func @transfer_read_and_vector_rank_reducing
 //  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2x1xf32>
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
 //  CHECK-SAME:     memref<1x1x3x2x1xf32> to memref<3x2xf32>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32>
 
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
-    transform.apply_patterns.vector.rank_reducing_subview_patterns
-  } : !transform.any_op
-}
-
-// -----
-
 func.func @transfer_write_and_vector_rank_reducing(
       %arg : memref<1x1x3x2x1xf32>,
       %vec : vector<3x2x1xf32>) {
@@ -78,22 +48,12 @@ func.func @transfer_write_and_vector_rank_reducing(
       vector<3x2x1xf32>, memref<1x1x3x2x1xf32>
     return
 }
-
 // CHECK-LABEL: func @transfer_write_and_vector_rank_reducing
 //  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2x1xf32>
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
 //  CHECK-SAME:     memref<1x1x3x2x1xf32> to memref<3x2xf32>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32>
 
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
-    transform.apply_patterns.vector.rank_reducing_subview_patterns
-  } : !transform.any_op
-}
-
-// -----
-
 func.func @transfer_read_and_vector_rank_reducing_to_0d(
       %arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> {
     %c0 = arith.constant 0 : index
@@ -102,22 +62,12 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
       memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
     return %v : vector<1x1x1xf32>
 }
-
 // CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d
 //  CHECK-SAME:     %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
 //       CHECK:   %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
 //       CHECK:   vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
 
-transform.sequence failures(propagate) {
-^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
-    transform.apply_patterns.vector.rank_reducing_subview_patterns
-  } : !transform.any_op
-}
-
-// -----
-
 func.func @transfer_write_and_vector_rank_reducing_to_0d(
       %arg : memref<1x1x1x1x1xf32>,
       %vec : vector<1x1x1xf32>) {
@@ -126,7 +76,6 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
       vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
     return
 }
-
 // CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d
 //  CHECK-SAME:     %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32>
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
@@ -135,7 +84,9 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.rank_reducing_subview_patterns
   } : !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir
index ecb29266246ed..11389bbe26e26 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir
@@ -108,7 +108,9 @@ func.func @split_vector_transfer_read_strided_2d(
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
   } : !transform.any_op
 }
@@ -169,7 +171,9 @@ func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf3
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
   } : !transform.any_op
 }
@@ -237,7 +241,9 @@ func.func @split_vector_transfer_write_strided_2d(
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
   } : !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
index a035c22d75eef..20fee63799de0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -103,7 +103,9 @@ func.func @split_vector_transfer_read_strided_2d(
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
   } : !transform.any_op
 }
@@ -161,7 +163,9 @@ func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf3
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
   } : !transform.any_op
 }
@@ -223,7 +227,9 @@ func.func @split_vector_transfer_write_strided_2d(
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
   } : !transform.any_op
 }
@@ -265,7 +271,9 @@ func.func @transfer_read_within_scf_for(%A : memref<?x?xf32>, %lb : index, %ub :
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
   } : !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
index acd401704eb5b..5e36223183117 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
@@ -2,7 +2,9 @@
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.fold_tensor_slice_into_transfer
   } : !transform.any_op
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 992f5127c8d69..5011f22cc80bd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -240,7 +240,9 @@ func.func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : i
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99
     transform.apply_patterns.vector.transfer_permutation_patterns
   } : !transform.any_op
@@ -362,7 +364,9 @@ func.func @transfer_write_broadcast_unit_dim(
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99
     transform.apply_patterns.vector.transfer_permutation_patterns
   } : !transform.any_op

diff  --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index f8b36055f1002..fdf2dfd2939bf 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -76,7 +76,9 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "eltwise"
   } : !transform.any_op
 }
@@ -99,7 +101,9 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
   } : !transform.any_op
 }
@@ -118,7 +122,9 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose"
   } : !transform.any_op
 }
@@ -605,7 +611,9 @@ func.func @transpose210_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32>
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose avx2_lowering_strategy = true
   } : !transform.any_op
 }
@@ -683,7 +691,9 @@ func.func @transpose_shuffle16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16x
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16"
   } : !transform.any_op
 }
@@ -762,7 +772,9 @@ func.func @transpose021_shuffle16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<1
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16"
   } : !transform.any_op
 }

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir
index 74d47de05b706..f9b584920d0d2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir
@@ -31,7 +31,9 @@ func.func @entry() {
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
-  transform.apply_patterns to %module_op {
+  %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16"
   } : !transform.any_op
 }


        


More information about the Mlir-commits mailing list