[Mlir-commits] [mlir] 4d339ec - [mlir][Vector] Add pattern to reorder elementwise and broadcast ops

Andrzej Warzynski llvmlistbot at llvm.org
Thu Jun 15 02:13:48 PDT 2023


Author: Andrzej Warzynski
Date: 2023-06-15T10:13:41+01:00
New Revision: 4d339ec91e81ae33b0f3ea0f8a3596d99645a0e9

URL: https://github.com/llvm/llvm-project/commit/4d339ec91e81ae33b0f3ea0f8a3596d99645a0e9
DIFF: https://github.com/llvm/llvm-project/commit/4d339ec91e81ae33b0f3ea0f8a3596d99645a0e9.diff

LOG: [mlir][Vector] Add pattern to reorder elementwise and broadcast ops

The new pattern will replace elementwise(broadcast) with
broadcast(elementwise) when safe.
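
For example (this mirrors the documentation added to the new pattern
below), the rewrite turns:

  %a = vector.broadcast %arg1 : index to vector<1x4xindex>
  %b = vector.broadcast %arg2 : index to vector<1x4xindex>
  %r = arith.addi %a, %b : vector<1x4xindex>

into:

  %r = arith.addi %arg1, %arg2 : index
  %b = vector.broadcast %r : index to vector<1x4xindex>

The pattern only applies when all operands are broadcast from identical
(scalar or vector) source types.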

This change affects the tests for vectorising nD-extract (the new
pattern is applied as part of transform::VectorizeOp). In one case
("vectorize_nd_tensor_extract_with_tensor_extract") I trimmed the test
and preserved only the key parts (the scalar load and the contiguous
load from the original Op). We could do the same with some other tests
if that helps maintainability.
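
The new entry point is vector::populateSinkVectorBroadcastPatterns. A
minimal sketch of how to apply it (mirroring the test pass added in
this patch; `ctx` and `op` stand for the caller's context and root op):

  RewritePatternSet patterns(ctx);
  vector::populateSinkVectorBroadcastPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(op, std::move(patterns));

The pattern can also be exercised in isolation via the new test pass:

  mlir-opt -test-sink-vector-broadcast -split-input-file \
    mlir/test/Dialect/Vector/sink-vector-broadcast.mlir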

Differential Revision: https://reviews.llvm.org/D152812

Added: 
    mlir/test/Dialect/Vector/sink-vector-broadcast.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 10e70b96b835f..55fd2fcd34b68 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -137,6 +137,10 @@ void populateVectorTransferFullPartialPatterns(
 void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Patterns that sink a vector.broadcast past elementwise ops.
+void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// [DecomposeDifferentRankInsertStridedSlice]

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 58a1fa864d430..4a9fc8e51954c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3066,6 +3066,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
   if (!getDisableMultiReductionToContractPatterns())
     vector::populateVectorReductionToContractPatterns(patterns);
 
+  vector::populateSinkVectorBroadcastPatterns(patterns);
+
   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
                linalg::LinalgCopyVTWForwardingPattern>(ctx,
                                                        /*benefit=*/2);

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f9e778d331ef0..ea42d57d2fb0a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -885,6 +885,66 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
   std::function<bool(BitCastOp)> controlFn;
 };
 
+/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
+/// ```
+/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
+/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
+/// %r = arith.addi %a, %b : vector<1x4xindex>
+/// ```
+/// Gets converted to:
+/// ```
+/// %r = arith.addi %arg1, %arg2 : index
+/// %b = vector.broadcast %r : index to vector<1x4xindex>
+/// ```
+struct ReorderElementwiseOpsOnBroadcast final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return failure();
+    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+      return failure();
+    if (!OpTrait::hasElementwiseMappableTraits(op))
+      return failure();
+
+    // Get the source type of the broadcast feeding the first operand.
+    auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
+    if (!firstBcast)
+      return failure();
+    auto firstOpType = firstBcast.getOperand().getType();
+
+    // Make sure that all operands are broadcast from identical (scalar or
+    // vector) types. That guarantees that it is safe to apply the
+    // elementwise op to the sources and broadcast the result instead.
+    if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
+          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
+          return (bcast && (bcast.getOperand().getType() == firstOpType));
+        })) {
+      return failure();
+    }
+
+    // Collect the pre-broadcast source values.
+    SmallVector<Value> srcValues;
+    srcValues.reserve(op->getNumOperands());
+
+    for (Value operand : op->getOperands()) {
+      srcValues.push_back(
+          operand.getDefiningOp<vector::BroadcastOp>().getOperand());
+    }
+
+    Operation *elementwiseOp =
+        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
+                        firstOpType, op->getAttrs());
+
+    auto vectorType = op->getResultTypes()[0];
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        op, vectorType, elementwiseOp->getResults());
+
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -1311,6 +1371,12 @@ void mlir::vector::
   patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateSinkVectorBroadcastPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
+                                                 benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 5047545f0b2cb..baf0894fd4853 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -130,27 +130,29 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
   return %25 : tensor<1x4xf32>
 }
 
-// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_complex
+
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_complex(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<45x80x16xf32>,
-// CHECK-SAME:      {{.*}}: index,
+// CHECK-SAME:      %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
 // CHECK-SAME:      %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
 // CHECK:           %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
 // CHECK:           %[[VAL_7:.*]] = arith.constant 0 : i32
 // CHECK:           %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK:           %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK:           %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK:           %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
-// CHECK:           %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK:           %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
-// CHECK:           %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
-// CHECK:           %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
+// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
+// CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
+// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
+// CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
+// CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
+// CHECK:           %[[VAL_17:.*]] = vector.shape_cast %[[VAL_12]] : vector<1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_18:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_19:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_18]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:           return %[[VAL_21]] : tensor<1x4xf32>
+// CHECK:         }
 
 transform.sequence failures(propagate) {
  ^bb1(%arg1: !transform.any_op):
@@ -317,43 +319,16 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
 }
 
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_tensor_extract(
-// CHECK-SAME:    %[[VAL_0:.*]]: tensor<1x20xi32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: tensor<257x24xf32>,
-// CHECK-SAME:       %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> {
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
-// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_12:.*]] = tensor.empty() : tensor<1x1x4xf32>
-// CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x1x4xindex>
-// CHECK:           %[[VAL_14:.*]] = vector.broadcast %[[VAL_4]] : index to vector<1x1x4xindex>
-// CHECK:           %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_16:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex>
-// CHECK:           %[[VAL_17:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
-// CHECK:           %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_19:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex>
-// CHECK:           %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_21:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_22:.*]] = vector.extractelement %[[VAL_21]][%[[VAL_8]] : i32] : vector<4xindex>
+// CHECK-SAME:      %[[INPUT_1:.*]]: tensor<1x20xi32>,
+// CHECK-SAME:      %[[INPUT_2:.*]]: tensor<257x24xf32>,
+// CHECK:           %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex>
// First `tensor.extract` from the generic Op - a loop-invariant scalar load.
-// CHECK:           %[[VAL_23:.*]] = tensor.extract %[[VAL_0]][%[[VAL_11]], %[[VAL_22]]] : tensor<1x20xi32>
-// CHECK:           %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index
-// CHECK:           %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1x1x4xindex>
-// CHECK:           %[[VAL_26:.*]] = arith.maxsi %[[VAL_25]], %[[VAL_6]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_27:.*]] = arith.minsi %[[VAL_26]], %[[VAL_9]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_28:.*]] = vector.shape_cast %[[VAL_27]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_29:.*]] = vector.extractelement %[[VAL_28]][%[[VAL_8]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_30:.*]] = vector.shape_cast %[[VAL_20]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_31:.*]] = vector.extractelement %[[VAL_30]][%[[VAL_8]] : i32] : vector<4xindex>
+// CHECK:           tensor.extract %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] : tensor<1x20xi32>
// The following `tensor.extract` from the generic Op is a contiguous load (all Ops used
// for address calculation also satisfy the required conditions).
-// CHECK:           %[[VAL_32:.*]] = vector.transfer_read %[[VAL_1]][%[[VAL_29]], %[[VAL_31]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_33:.*]] = vector.broadcast %[[VAL_32]] : vector<1x4xf32> to vector<1x1x4xf32>
-// CHECK:           %[[VAL_34:.*]] = vector.transfer_write %[[VAL_33]], %[[VAL_12]][%[[VAL_11]], %[[VAL_11]], %[[VAL_11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
-// CHECK:           return %[[VAL_34]] : tensor<1x1x4xf32>
-// CHECK:         }
+// CHECK:           vector.transfer_read %[[INPUT_2]][%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
+
 
 transform.sequence failures(propagate) {
  ^bb1(%arg1: !transform.any_op):

diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
new file mode 100644
index 0000000000000..fcf9815f6f6f1
--- /dev/null
+++ b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
+
+// CHECK-LABEL:   func.func @broadcast_scalar(
+// CHECK-SAME:     %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK:           return %[[BCAST]] : vector<1x4xindex>
+// CHECK:         }
+
+func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vector(
+// CHECK-SAME:      %[[ARG_0:.*]]: vector<4xf32>,
+// CHECK-SAME:      %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK:           %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK:           return %[[BCAST]] : vector<3x4xf32>
+// CHECK:         }
+
+func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
+  %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
+  %arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
+  %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
+  return %2 : vector<3x4xf32>
+}
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vector_and_scalar(
+// CHECK-SAME:      %[[ARG_0:.*]]: i32,
+// CHECK-SAME:      %[[ARG_1:.*]]: vector<4xi32>) -> vector<4xi32> {
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
+// CHECK:           %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
+// CHECK:           return %[[ADD]] : vector<4xi32>
+// CHECK:         }
+
+func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
+  %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
+  %2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
+  return %2 : vector<4xi32>
+}
+
+// -----
+
+#matmat_accesses = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL:   func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
+// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
+  %f1 = arith.constant 1.0: f32
+  %f2 = arith.constant 2.0: f32
+  %f3 = arith.constant 3.0: f32
+
+  %A = vector.broadcast %f1 : f32 to vector<2x2xf32>
+  %B = vector.broadcast %f2 : f32 to vector<2x2xf32>
+  %C = vector.broadcast %f3 : f32 to vector<2x2xf32>
+  %mm1 = vector.contract #matmat_trait %A, %B, %C
+    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+  return %mm1 : vector<2x2xf32>
+}

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a5de1fd4de431..554a7b6db4729 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -374,6 +374,31 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
   }
 };
 
+struct TestSinkVectorBroadcast
+    : public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
+
+  TestSinkVectorBroadcast() = default;
+  TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect, affine::AffineDialect>();
+  }
+
+  StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
+
+  StringRef getDescription() const final {
+    return "Test lowering patterns that eliminate redundant broadcast "
+           "operations.";
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateSinkVectorBroadcastPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestVectorReduceToContractPatternsPatterns
     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                          OperationPass<func::FuncOp>> {
@@ -735,6 +760,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
 
+  PassRegistration<TestSinkVectorBroadcast>();
+
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
 
   PassRegistration<TestFlattenVectorTransferPatterns>();
