[Mlir-commits] [mlir] [mlir][vector] Group tests for re-order patterns (PR #102856)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Aug 13 09:45:58 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/102856

>From 9db316c581057b6fc4e0543d284a1bd01907effa Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 7 Aug 2024 09:40:02 +0100
Subject: [PATCH 1/4] [mlir][vector] Add tests for
 `populateSinkVectorBroadcastPatterns` (1/n)

Adds tests for scalable vectors in:
  * sink-vector-broadcast.mlir

This test file exercises patterns grouped under
`populateSinkVectorBroadcastPatterns`, which includes:
  * `ReorderElementwiseOpsOnBroadcast`,
  * `ReorderCastOpsOnBroadcast`.

Right now there are only tests for the former. However, I've noticed
that "vector-reduce-to-contract.mlir" contains tests for the latter and
I've left a few TODOs to group these tests back together in one file.

Additionally, added some helpful `notifyMatchFailure` messages in
`ReorderElementwiseOpsOnBroadcast`.
---
 .../Vector/Transforms/VectorTransforms.cpp    |  17 ++-
 .../Dialect/Vector/sink-vector-broadcast.mlir | 120 +++++++++++++++---
 .../Vector/vector-reduce-to-contract.mlir     |  10 ++
 3 files changed, 121 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f59a378e03512..ccbaa3e9759975 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -979,15 +979,18 @@ struct ReorderElementwiseOpsOnBroadcast final
     if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
       return failure();
     if (!OpTrait::hasElementwiseMappableTraits(op))
+      return rewriter.notifyMatchFailure(
+          op, "Op doesn't have ElementwiseMappableTraits");
+    if (op->getNumOperands() == 0)
       return failure();
-    if (op->getNumOperands() == 0 ||
-        op->getResults()[0].getType() != op->getOperand(0).getType()) {
-      return failure();
-    }
-    // Avoid operations that only accept vector types, since broadcast
-    // source might be scalar types.
+    if (op->getResults()[0].getType() != op->getOperand(0).getType())
+      return rewriter.notifyMatchFailure(op,
+                                         "result and operand type mismatch");
     if (isa<vector::FMAOp>(op)) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op,
+          "Op only accepts vector types - not supported as broadcast source "
+          "might be a scalar");
     }
 
     // Get the type of the lhs operand
diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
index e7863a9e8b7b78..ae55b6696c676c 100644
--- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
+++ b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
@@ -1,16 +1,32 @@
 // RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
 
+//-----------------------------------------------------------------------------
+// [Pattern: ReorderElementwiseOpsOnBroadcast]
+//-----------------------------------------------------------------------------
+
 // CHECK-LABEL:   func.func @broadcast_scalar_with_bcast(
 // CHECK-SAME:     %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
 // CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
 // CHECK:           return %[[BCAST]] : vector<1x4xindex>
 
-func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
   %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
   %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
-  %2 = arith.addi %0, %1 : vector<1x4xindex>
-  return %2 : vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex> return %2 : vector<1x4xindex>
+}
+
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast_scalable(
+// CHECK-SAME:     %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x[4]xindex> {
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
+// CHECK:           return %[[BCAST]] : vector<1x[4]xindex>
+
+func.func @broadcast_scalar_with_bcast_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
+  %0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
+  %2 = arith.addi %0, %1 : vector<1x[4]xindex>
+  return %2 : vector<1x[4]xindex>
 }
 
 // -----
@@ -21,13 +37,26 @@ func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x
 // CHECK:           %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
 // CHECK:           return %[[BCAST]] : vector<1x4xindex>
-func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) -> vector<1x4xindex> {
   %0 = vector.splat %arg1 : vector<1x4xindex>
   %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
   %2 = arith.addi %0, %1 : vector<1x4xindex>
   return %2 : vector<1x4xindex>
 }
 
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast_and_splat_scalable(
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: index) -> vector<1x[4]xindex> {
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
+// CHECK:           return %[[BCAST]] : vector<1x[4]xindex>
+func.func @broadcast_scalar_with_bcast_and_splat_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
+  %0 = vector.splat %arg1 : vector<1x[4]xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
+  %2 = arith.addi %0, %1 : vector<1x[4]xindex>
+  return %2 : vector<1x[4]xindex>
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @broadcast_vector(
@@ -37,13 +66,27 @@ func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) ->
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
 // CHECK:           return %[[BCAST]] : vector<3x4xf32>
 
-func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
+func.func @broadcast_vector(%arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
   %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
   %arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
   %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
   return %2 : vector<3x4xf32>
 }
 
+// CHECK-LABEL:   func.func @broadcast_vector_scalable(
+// CHECK-SAME:      %[[ARG_0:.*]]: vector<[4]xf32>,
+// CHECK-SAME:      %[[ARG_1:.*]]: vector<[4]xf32>) -> vector<3x[4]xf32> {
+// CHECK:           %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<[4]xf32> to vector<3x[4]xf32>
+// CHECK:           return %[[BCAST]] : vector<3x[4]xf32>
+
+func.func @broadcast_vector_scalable(%arg1: vector<[4]xf32>, %arg2: vector<[4]xf32>) -> vector<3x[4]xf32> {
+  %arg1_bcast = vector.broadcast %arg1 : vector<[4]xf32> to vector<3x[4]xf32>
+  %arg2_bcast = vector.broadcast %arg2 : vector<[4]xf32> to vector<3x[4]xf32>
+  %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x[4]xf32>
+  return %2 : vector<3x[4]xf32>
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @broadcast_scalar_and_vec(
@@ -53,13 +96,27 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
 // CHECK:            %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
 // CHECK:            %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
 // CHECK:            return %[[ADD]] : vector<1x4xindex>
-func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
+func.func @broadcast_scalar_and_vec(%arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
   %0 = vector.splat %arg1 : vector<1x4xindex>
   %1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
   %2 = arith.addi %0, %1 : vector<1x4xindex>
   return %2 : vector<1x4xindex>
 }
 
+// CHECK-LABEL:   func.func @broadcast_scalar_and_vec_scalable(
+// CHECK-SAME:       %[[ARG1:.*]]: index,
+// CHECK-SAME:       %[[ARG2:.*]]: vector<[4]xindex>) -> vector<1x[4]xindex> {
+// CHECK:            %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x[4]xindex>
+// CHECK:            %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<[4]xindex> to vector<1x[4]xindex>
+// CHECK:            %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x[4]xindex>
+// CHECK:            return %[[ADD]] : vector<1x[4]xindex>
+func.func @broadcast_scalar_and_vec_scalable(%arg1: index, %arg2: vector<[4]xindex>) -> vector<1x[4]xindex> {
+  %0 = vector.splat %arg1 : vector<1x[4]xindex>
+  %1 = vector.broadcast %arg2 : vector<[4]xindex> to vector<1x[4]xindex>
+  %2 = arith.addi %0, %1 : vector<1x[4]xindex>
+  return %2 : vector<1x[4]xindex>
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @broadcast_vector_and_scalar(
@@ -69,12 +126,25 @@ func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> ve
 // CHECK:           %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
 // CHECK:           return %[[ADD]] : vector<4xi32>
 
-func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
+func.func @broadcast_vector_and_scalar(%arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
   %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
   %2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
   return %2 : vector<4xi32>
 }
 
+// CHECK-LABEL:   func.func @broadcast_vector_and_scalar_scalable(
+// CHECK-SAME:      %[[ARG_0:.*]]: i32,
+// CHECK-SAME:      %[[ARG_1:.*]]: vector<[4]xi32>) -> vector<[4]xi32> {
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<[4]xi32>
+// CHECK:           %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<[4]xi32>
+// CHECK:           return %[[ADD]] : vector<[4]xi32>
+
+func.func @broadcast_vector_and_scalar_scalable(%arg1: i32, %arg2: vector<[4]xi32>) -> vector<[4]xi32> {
+  %arg1_bcast = vector.broadcast %arg1 : i32 to vector<[4]xi32>
+  %2 = arith.addi %arg1_bcast, %arg2 : vector<[4]xi32>
+  return %2 : vector<[4]xi32>
+}
+
 // -----
 
 #matmat_accesses = [
@@ -87,12 +157,12 @@ func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vec
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-// CHECK-LABEL:   func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
-// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
-// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
-// CHECK:           %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
+// CHECK-LABEL:   func.func @negative_not_elementwise
+// CHECK-DAG:       %[[F1:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
+// CHECK-DAG:       %[[F2:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
+// CHECK-DAG:       %[[F3:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
+// CHECK:           %[[RES:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[F1]], %[[F2]], %[[F3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+func.func @negative_not_elementwise() -> vector<2x2xf32> {
   %f1 = arith.constant 1.0: f32
   %f2 = arith.constant 2.0: f32
   %f3 = arith.constant 3.0: f32
@@ -100,27 +170,39 @@ func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
   %A = vector.broadcast %f1 : f32 to vector<2x2xf32>
   %B = vector.broadcast %f2 : f32 to vector<2x2xf32>
   %C = vector.broadcast %f3 : f32 to vector<2x2xf32>
-  %mm1 = vector.contract #matmat_trait %A, %B, %C
+  %res = vector.contract #matmat_trait %A, %B, %C
     : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
-  return %mm1 : vector<2x2xf32>
+  return %res : vector<2x2xf32>
 }
 
-// CHECK-LABEL: func.func @dont_sink_cmp(
+// -----
+
+// The source and the result for arith.cmp have different types - not supported
+
+// CHECK-LABEL: func.func @negative_source_and_result_mismatch
 //       CHECK:   %[[BROADCAST:.+]] = vector.broadcast
 //       CHECK:   %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
 //       CHECK:   return %[[RETURN]]
-func.func @dont_sink_cmp(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
+func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
   %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
   %1 = arith.cmpf uno, %0, %0 : vector<1xf32>
   return %1 : vector<1xi1>
 }
 
-// CHECK-LABEL: func.func @dont_sink_fma(
+// -----
+
+// vector.fma only supports vectors - currently it's not possible to replace this with e.g.:
+//    %scalar_res = vector.fma %scalar_1, %scalar2
+//    %vec_res = vector.broadcast %scalar_res
+//
+// TODO: It should be possible to support this case
+
+// CHECK-LABEL: func.func @negative_op_only_supports_vectors
   //     CHECK:   %[[BROADCAST:.+]] = vector.broadcast
   //     CHECK:   %[[RESULT:.+]] = vector.fma %[[BROADCAST]]
   //     CHECK:   return %[[RESULT]]
-func.func @dont_sink_fma(%arg0 : f32) -> vector<1xf32> {
+func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
   %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
   %1 = vector.fma %0, %0, %0 : vector<1xf32>
   return %1 : vector<1xf32>
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 23a44b7c03f8f4..c0dbea81df892a 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -246,8 +246,12 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
 
 
 //===----------------------------------------------------------------------===//
+// [Pattern: ReorderCastOpsOnBroadcast]
+//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.
+//
+// TODO: Potential duplication with sink-vector-broadcast.mlir
 //===----------------------------------------------------------------------===//
 
 // -----
@@ -272,6 +276,11 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// [Pattern: ReorderElementwiseOpsOnTranspose]
+//
+// TODO: Potential duplication with sink-vector-broadcast.mlir
+//===----------------------------------------------------------------------===//
 func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
   // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
   // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
@@ -282,6 +291,7 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
 
 //===----------------------------------------------------------------------===//
 // Reorder elementwise ops and vector ops.
+// TODO: Potential duplication with sink-vector-broadcast.mlir
 //===----------------------------------------------------------------------===//
 
 // -----

>From b65f36dfbdf32a0eebea3a73d791876a4778aefd Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 11 Aug 2024 17:38:33 +0100
Subject: [PATCH 2/4] [mlir][vector] Group tests for re-order patterns

Moves all tests for patterns that re-order vector.transpose and
vector.broadcast Ops (*) under one test-flag:
  * `test-vector-reorder-patterns`.

To facilitate this,
  * `-test-sink-vector-broadcast` is renamed as
    `test-vector-reorder-patterns`,
  * "sink-vector-broadcast.mlir" is renamed as "vector-reorder.mlir",
  * tests for `ReorderCastOpsOnBroadcast` and
    `ReorderElementwiseOpsOnTranspose` patterns are moved from
    "vector-reduce-to-contract.mlir" to "vector-reorder.mlir",
  * `ReorderElementwiseOpsOnTranspose` patterns are removed from
    `populateVectorReductionToContractPatterns` and added to (newly
    created) `populateReoderVectorTransposePatterns`.
  * `ReorderCastOpsOnBroadcast` patterns are removed from
    `populateVectorReductionToContractPatterns` - these are already
    present in `populateSinkVectorBroadcastPatterns`.

This should allow us better layering and more straightforward testing.
For the latter, the goal is to be able to easily identify which pattern
a particular test is exercising (especially when it's a specific
pattern).

Note for downstream users: in order to preserve the current
functionality, please make sure to add
  * `populateReoderVectorTransposePatterns` and
    `populateSinkVectorBroadcastPatterns`,

wherever you are using `populateVectorReductionToContractPatterns`.

(*) I didn't notice any other re-order patterns.
---
 .../Vector/Transforms/VectorRewritePatterns.h |   4 +
 .../TransformOps/LinalgTransformOps.cpp       |   1 +
 .../Vector/Transforms/VectorTransforms.cpp    |   9 +-
 .../Vector/vector-reduce-to-contract.mlir     | 122 ------------------
 ...tor-broadcast.mlir => vector-reorder.mlir} | 114 +++++++++++++++-
 .../Dialect/Vector/TestVectorTransforms.cpp   |   5 +-
 6 files changed, 128 insertions(+), 127 deletions(-)
 rename mlir/test/Dialect/Vector/{sink-vector-broadcast.mlir => vector-reorder.mlir} (65%)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 10970fd03e6eb2..2c17ffd49d1d41 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -148,6 +148,10 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);
 
+/// Patterns that re-order transpose Ops.
+void populateReoderVectorTransposePatterns(RewritePatternSet &patterns,
+                                           PatternBenefit benefit = 1);
+
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
 /// cheaper than vector reduction.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 48b3abbeee7010..e2f9ca1fc75027 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3452,6 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
   if (!getDisableMultiReductionToContractPatterns())
     vector::populateVectorReductionToContractPatterns(patterns);
 
+  vector::populateReoderVectorTransposePatterns(patterns);
   vector::populateSinkVectorBroadcastPatterns(patterns);
 
   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ccbaa3e9759975..922e3c61a77310 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2030,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
-               CombineContractABTranspose, CombineContractResultTranspose,
-               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
+               CombineContractABTranspose, CombineContractResultTranspose>(
       patterns.getContext(), benefit);
 }
 
@@ -2049,6 +2048,12 @@ void mlir::vector::populateSinkVectorBroadcastPatterns(
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateReoderVectorTransposePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
+                                                 benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index c0dbea81df892a..24070dbf017a58 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -245,128 +245,6 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
 }
 
 
-//===----------------------------------------------------------------------===//
-// [Pattern: ReorderCastOpsOnBroadcast]
-//
-// Reorder casting ops and vector ops. The casting ops have almost identical
-// pattern, so only arith.extsi op is tested.
-//
-// TODO: Potential duplication with sink-vector-broadcast.mlir
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
-  // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
-  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-// -----
-
-func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
-  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
-  %b = vector.broadcast %a : i8 to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-// -----
-
-//===----------------------------------------------------------------------===//
-// [Pattern: ReorderElementwiseOpsOnTranspose]
-//
-// TODO: Potential duplication with sink-vector-broadcast.mlir
-//===----------------------------------------------------------------------===//
-func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
-  // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
-  %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-//===----------------------------------------------------------------------===//
-// Reorder elementwise ops and vector ops.
-// TODO: Potential duplication with sink-vector-broadcast.mlir
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_same_type
-//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
-//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
-//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
-//       CHECK:   return %[[T]]
-
-func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
-  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %r = arith.addf %at, %bt : vector<2x4xf32>
-  return %r : vector<2x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
-//  CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
-//       CHECK:   %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
-//       CHECK:   %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
-//       CHECK:   return %[[T]]
-func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
-  %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
-  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
-  return %r : vector<2x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
-//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
-//       CHECK:   %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
-//       CHECK:   %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
-//       CHECK:   return %[[T]]
-func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
-  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
-  return %r : vector<2x4xi1>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_splat_constant
-//  CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
-//       CHECK:   %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
-//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
-//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
-//       CHECK:   return %[[T:.+]] : vector<6x4x2x3xf32>
-
-func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
-  %b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
-  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
-  %r = arith.addf %at, %b : vector<6x4x2x3xf32>
-  return %r : vector<6x4x2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_diff_map
-//       CHECK:   vector.transpose
-//       CHECK:   vector.transpose
-//       CHECK:   arith.addf
-func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
-  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
-  %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
-  %r = arith.addf %at, %bt : vector<6x4x2x3xf32>
-  return %r : vector<6x4x2x3xf32>
-}
-
 // -----
 
 // CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/vector-reorder.mlir
similarity index 65%
rename from mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
rename to mlir/test/Dialect/Vector/vector-reorder.mlir
index ae55b6696c676c..d7669ec2b54037 100644
--- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
+++ b/mlir/test/Dialect/Vector/vector-reorder.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-reorder-patterns -split-input-file | FileCheck %s
 
 //-----------------------------------------------------------------------------
 // [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -207,3 +207,115 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
   %1 = vector.fma %0, %0, %0 : vector<1xf32>
   return %1 : vector<1xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// [Pattern: ReorderCastOpsOnBroadcast]
+//
+// Reorder casting ops and vector ops. The casting ops have almost identical
+// pattern, so only arith.extsi op is tested.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
+  // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
+  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+
+func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
+  %b = vector.broadcast %a : i8 to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// [Pattern: ReorderElementwiseOpsOnTranspose]
+//===----------------------------------------------------------------------===//
+
+func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
+  // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
+  %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_same_type
+//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
+//       CHECK:   return %[[T]]
+
+func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
+  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %r = arith.addf %at, %bt : vector<2x4xf32>
+  return %r : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
+//  CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+//       CHECK:   %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+//       CHECK:   return %[[T]]
+func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
+  %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
+  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
+  return %r : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
+//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+//       CHECK:   %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
+//       CHECK:   return %[[T]]
+func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
+  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
+  return %r : vector<2x4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_splat_constant
+//  CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
+//       CHECK:   %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
+//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+//       CHECK:   return %[[T:.+]] : vector<6x4x2x3xf32>
+
+func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
+  %b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
+  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+  %r = arith.addf %at, %b : vector<6x4x2x3xf32>
+  return %r : vector<6x4x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_map
+//       CHECK:   vector.transpose
+//       CHECK:   vector.transpose
+//       CHECK:   arith.addf
+func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
+  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+  %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
+  %r = arith.addf %at, %bt : vector<6x4x2x3xf32>
+  return %r : vector<6x4x2x3xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 29c763b622e877..c4479bf77d6c5e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -385,16 +385,17 @@ struct TestSinkVectorBroadcast
     registry.insert<memref::MemRefDialect, affine::AffineDialect>();
   }
 
-  StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
+  StringRef getArgument() const final { return "test-vector-reorder-patterns"; }
 
   StringRef getDescription() const final {
     return "Test lowering patterns that eliminate redundant brodacast "
-           "operations.";
+           "and transpose operations.";
   }
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateSinkVectorBroadcastPatterns(patterns);
+    populateReoderVectorTransposePatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };

>From eeef857854167796ca43c34b058a8628a54bffb9 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 12 Aug 2024 10:45:56 +0100
Subject: [PATCH 3/4] fixup! [mlir][vector] Group tests for re-order patterns

Rename TestSinkVectorBroadcast as TestVectorReorderPatterns
---
 .../lib/Dialect/Vector/TestVectorTransforms.cpp     | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index c4479bf77d6c5e..f468873aff859d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -374,12 +374,13 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
   }
 };
 
-struct TestSinkVectorBroadcast
-    : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
+struct TestVectorReorderPatterns
+    : public PassWrapper<TestVectorReorderPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorReorderPatterns)
 
-  TestSinkVectorBroadcast() = default;
-  TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
+  TestVectorReorderPatterns() = default;
+  TestVectorReorderPatterns(const TestVectorReorderPatterns &pass) = default;
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect, affine::AffineDialect>();
@@ -920,7 +921,7 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
 
-  PassRegistration<TestSinkVectorBroadcast>();
+  PassRegistration<TestVectorReorderPatterns>();
 
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
 

>From e1ef3b703bf19b3819725a7566fedae4759fbe1a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 13 Aug 2024 17:38:50 +0100
Subject: [PATCH 4/4] fixup! fixup! [mlir][vector] Group tests for re-order
 patterns

Group all re-order patterns under `populateSinkVectorOpsPatterns`.
Rename test flag and test pass accordingly.
---
 .../Vector/Transforms/VectorRewritePatterns.h | 23 +++++++++++++------
 .../TransformOps/LinalgTransformOps.cpp       |  3 +--
 .../Vector/Transforms/VectorTransforms.cpp    | 13 ++++-------
 mlir/test/Dialect/Vector/vector-reorder.mlir  |  2 +-
 .../Dialect/Vector/TestVectorTransforms.cpp   | 18 +++++++--------
 5 files changed, 30 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 2c17ffd49d1d41..456d021bc79358 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -144,13 +144,22 @@ void populateVectorTransferFullPartialPatterns(
 void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Patterns that remove redundant vector broadcasts.
-void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
-                                         PatternBenefit benefit = 1);
-
-/// Patterns that re-order transpose Ops.
-void populateReoderVectorTransposePatterns(RewritePatternSet &patterns,
-                                           PatternBenefit benefit = 1);
+/// Patterns that remove redundant Vector Ops by re-ordering them with
+/// e.g. elementwise Ops:
+/// ```
+/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+/// %r = arith.addf %at, %bt : vector<2x4xf32>
+/// ```
+/// gets converted to:
+/// ```
+/// %0 = arith.addf %a, %b : vector<4x2xf32>
+/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+/// ```
+/// At the moment, these patterns are limited to vector.broadcast and
+/// vector.transpose.
+void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
+                                   PatternBenefit benefit = 1);
 
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e2f9ca1fc75027..248d255d8076d6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3452,8 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
   if (!getDisableMultiReductionToContractPatterns())
     vector::populateVectorReductionToContractPatterns(patterns);
 
-  vector::populateReoderVectorTransposePatterns(patterns);
-  vector::populateSinkVectorBroadcastPatterns(patterns);
+  vector::populateSinkVectorOpsPatterns(patterns);
 
   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
                linalg::LinalgCopyVTWForwardingPattern>(ctx,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 922e3c61a77310..ad4e42b31962e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2042,15 +2042,10 @@ void mlir::vector::
                                                    benefit);
 }
 
-void mlir::vector::populateSinkVectorBroadcastPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
-      patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateReoderVectorTransposePatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
+void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit) {
+  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
+               ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
                                                  benefit);
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-reorder.mlir b/mlir/test/Dialect/Vector/vector-reorder.mlir
index d7669ec2b54037..71b1a6ace75be2 100644
--- a/mlir/test/Dialect/Vector/vector-reorder.mlir
+++ b/mlir/test/Dialect/Vector/vector-reorder.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-reorder-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
 
 //-----------------------------------------------------------------------------
 // [Pattern: ReorderElementwiseOpsOnBroadcast]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f468873aff859d..72aaa7dc4f8973 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -374,19 +374,18 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
   }
 };
 
-struct TestVectorReorderPatterns
-    : public PassWrapper<TestVectorReorderPatterns,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorReorderPatterns)
+struct TestVectorSinkPatterns
+    : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
 
-  TestVectorReorderPatterns() = default;
-  TestVectorReorderPatterns(const TestVectorReorderPatterns &pass) = default;
+  TestVectorSinkPatterns() = default;
+  TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect, affine::AffineDialect>();
   }
 
-  StringRef getArgument() const final { return "test-vector-reorder-patterns"; }
+  StringRef getArgument() const final { return "test-vector-sink-patterns"; }
 
   StringRef getDescription() const final {
     return "Test lowering patterns that eliminate redundant brodacast "
@@ -395,8 +394,7 @@ struct TestVectorReorderPatterns
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateSinkVectorBroadcastPatterns(patterns);
-    populateReoderVectorTransposePatterns(patterns);
+    populateSinkVectorOpsPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -921,7 +919,7 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
 
-  PassRegistration<TestVectorReorderPatterns>();
+  PassRegistration<TestVectorSinkPatterns>();
 
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
 



More information about the Mlir-commits mailing list