[Mlir-commits] [mlir] [mlir][vector] remove lower_multi_reduction (PR #182332)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Thu Feb 19 10:06:17 PST 2026
https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/182332
* Removes `ApplyLowerMultiReductionPatternsOp` (`apply_patterns.vector.lower_multi_reduction`).
* Updates the tests that used `apply_patterns.vector.lower_multi_reduction` to use the following instead:
  * `reorder_and_expand_multi_reduction_dims`
  * `multi_reduction_flattening`
  * `multi_reduction_unrolling`
* Removes `populateVectorMultiReductionLoweringPatterns`, which is now unused.

Depends on https://github.com/llvm/llvm-project/pull/182113
>From 79202b30894b14891198c5d000b3a296cab44ff8 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 18 Feb 2026 11:55:25 -0500
Subject: [PATCH 1/2] Add tests
---
.../Vector/TransformOps/VectorTransformOps.td | 23 ++
.../TransformOps/VectorTransformOps.cpp | 8 +
.../vector-multi-reduction-lowering.mlir | 255 ------------------
...vector-multi-reduction-outer-lowering.mlir | 192 -------------
.../vector-multi-reduction-unrolling.mlir | 156 +++++++++++
.../python/dialects/transform_vector_ext.py | 8 +
6 files changed, 195 insertions(+), 447 deletions(-)
delete mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
delete mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 6eb96e2a8fdab..685c88c17e556 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -283,6 +283,29 @@ def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
}];
}
+def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
+ "apply_patterns.vector.multi_reduction_unrolling",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that 2-D vector multi_reduction operations should be unrolled
+ into either a sequence of vector.reduction ops (`innerreduction`) or a
+ sequence of element-wise arith ops (`innerparallel`). The strategy is
+ selected with `lowering_strategy` and defaults to `innerparallel`.
+ multi_reduction ops of other ranks (e.g. 1-D or 3-D) are left unchanged.
+
+ This populates the patterns from
+ `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+ * `TwoDimMultiReductionToReduction` (`innerreduction`)
+ * `TwoDimMultiReductionToElementWise` (`innerparallel`)
+ }];
+
+ // When `lowering_strategy` is omitted from the custom assembly (it is an
+ // optional group below), the InnerParallel strategy is used.
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = [{
+ (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+ }];
+}
+
def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_outerproduct",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index f3529ac26523f..6c97c6501a23e 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -154,6 +154,14 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+/// Populates `patterns` with the 2-D vector.multi_reduction unrolling patterns
+/// from `populateVectorMultiReductionUnrollingPatterns`, configured with the
+/// strategy taken from this op's `lowering_strategy` attribute.
+void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ // The options object is used here only to carry the lowering strategy,
+ // matching how ApplyMultiReductionFlatteningPatternsOp passes it along.
+ // NOTE(review): getLoweringStrategy() could likely be passed directly;
+ // confirm setVectorMultiReductionLowering has no other side effect.
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorMultiReductionUnrollingPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
deleted file mode 100644
index 6b79a78e6a42a..0000000000000
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ /dev/null
@@ -1,255 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
-
-// Patterns applied:
-// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
-func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
- return %0 : vector<2xf32>
-}
-// CHECK-LABEL: func @vector_multi_reduction
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
-// CHECK-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
-// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
-// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
-// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
-// CHECK: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0:.+]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
-// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
-// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
-// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
-// CHECK: %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
-// CHECK: return %[[RESULT_VEC]]
-
-// Patterns applied:
-// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
-// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
-func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
- %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
- return %0 : vector<2x3xi32>
-}
-// CHECK-LABEL: func @vector_reduction_inner
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
-// CHECK-DAG: %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32>
-// CHECK: %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
-// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<20xi32> from vector<6x20xi32>
-// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : i32 from vector<2x3xi32>
-// CHECK: %[[V0R:.+]] = vector.reduction <add>, %[[V0]], %[[ACC0]] : vector<20xi32> into i32
-// CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insert %[[V0R]], %[[FLAT_RESULT_VEC_0]] [0] : i32 into vector<6xi32>
-// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<20xi32> from vector<6x20xi32>
-// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : i32 from vector<2x3xi32>
-// CHECK: %[[V1R:.+]] = vector.reduction <add>, %[[V1]], %[[ACC1]] : vector<20xi32> into i32
-// CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insert %[[V1R]], %[[FLAT_RESULT_VEC_1]] [1] : i32 into vector<6xi32>
-// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<20xi32> from vector<6x20xi32>
-// CHECK: %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : i32 from vector<2x3xi32>
-// CHECK: %[[V2R:.+]] = vector.reduction <add>, %[[V2]], %[[ACC2]] : vector<20xi32> into i32
-// CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insert %[[V2R]], %[[FLAT_RESULT_VEC_2]] [2] : i32 into vector<6xi32>
-// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<20xi32> from vector<6x20xi32>
-// CHECK: %[[ACC3:.+]] = vector.extract %[[ACC]][1, 0] : i32 from vector<2x3xi32>
-// CHECK: %[[V3R:.+]] = vector.reduction <add>, %[[V3]], %[[ACC3]] : vector<20xi32> into i32
-// CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insert %[[V3R]], %[[FLAT_RESULT_VEC_3]] [3] : i32 into vector<6xi32>
-// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<20xi32> from vector<6x20xi32>
-// CHECK: %[[ACC4:.+]] = vector.extract %[[ACC]][1, 1] : i32 from vector<2x3xi32>
-// CHECK: %[[V4R:.+]] = vector.reduction <add>, %[[V4]], %[[ACC4]] : vector<20xi32> into i32
-// CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insert %[[V4R]], %[[FLAT_RESULT_VEC_4]] [4] : i32 into vector<6xi32>
-// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<20xi32> from vector<6x20xi32>
-// CHECK: %[[ACC5:.+]] = vector.extract %[[ACC]][1, 2] : i32 from vector<2x3xi32>
-// CHECK: %[[V5R:.+]] = vector.reduction <add>, %[[V5]], %[[ACC5]] : vector<20xi32> into i32
-// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insert %[[V5R]], %[[FLAT_RESULT_VEC_5]] [5] : i32 into vector<6xi32>
-// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
-// CHECK: return %[[RESULT]]
-
-// Patterns applied:
-// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
-func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
- %c0 = arith.constant 0 : index
- %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %c1 = arith.constant 1 : index
- %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %c0_1 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f32
- %0 = vector.create_mask %dim, %dim_0 : vector<4x8xi1>
- %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
- %cst_2 = arith.constant 0.000000e+00 : f32
- %2 = vector.create_mask %dim : vector<4xi1>
- %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
- %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
- %c0_3 = arith.constant 0 : index
- %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_3] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
- return %5 : tensor<?xf32>
-}
-
-// Verify that the original 2-D mask is sliced and propagated properly to the
-// vector.reduction instances.
-
-// CHECK-LABEL: func.func @vectorize_dynamic_reduction
-// CHECK: %[[VAL_8:.*]] = tensor.dim
-// CHECK: %[[VAL_9:.*]] = tensor.dim
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]] : vector<4x8xi1>
-
-// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_10]][0] : vector<8xi1> from vector<4x8xi1>
-// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
-// CHECK: %[[VAL_18:.*]] = vector.insert
-
-// CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_10]][1] : vector<8xi1> from vector<4x8xi1>
-// CHECK: %[[VAL_22:.*]] = vector.mask %[[VAL_21]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
-// CHECK: %[[VAL_23:.*]] = vector.insert
-
-// CHECK: %[[VAL_26:.*]] = vector.extract %[[VAL_10]][2] : vector<8xi1> from vector<4x8xi1>
-// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_26]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
-// CHECK: %[[VAL_28:.*]] = vector.insert
-
-// CHECK: %[[VAL_31:.*]] = vector.extract %[[VAL_10]][3] : vector<8xi1> from vector<4x8xi1>
-// CHECK: %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
-// CHECK: %[[VAL_33:.*]] = vector.insert
-
-// Patterns applied:
-// * OneDimMultiReductionToTwoDim from populateVectorMultiReductionTransformationPatterns
-// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
-func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
- %c0 = arith.constant 0 : index
- %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
- %c0_1 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f32
- %0 = vector.create_mask %dim : vector<8xi1>
- %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1], %cst {in_bounds = [true]} : tensor<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>
- %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %cst [0] : vector<8xf32> to f32 } : vector<8xi1> -> f32
- return %4 : f32
-}
-
-// Verify that a 1-D vector.multi_reduction is transformed into a vector.reduction.
-// This transform expands 1-D vectors into 2-D.
-
-// CHECK-LABEL: func.func @vectorize_1d_dynamic_reduction(
-// CHECK: %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1>
-// CHECK: %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
-
-
-// Patterns applied:
-// * InnerOuterDimReductionConversion from populateVectorMultiReductionTransformationPatterns
-// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
-// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
-func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0 : index
- %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
- %c1 = arith.constant 1 : index
- %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
- %c2 = arith.constant 2 : index
- %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
- %c0_2 = arith.constant 0 : index
- %cst = arith.constant 0.000000e+00 : f32
- %0 = vector.create_mask %dim, %dim_0, %dim_1 : vector<4x8x16xi1>
- %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %cst {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
- %cst_3 = arith.constant 0.000000e+00 : f32
- %2 = vector.create_mask %dim_1, %dim_0 : vector<16x8xi1>
- %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
- %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
- %c0_4 = arith.constant 0 : index
- %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_4, %c0_4] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
- return %5 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: func.func @vectorize_dynamic_transpose_reduction
-// CHECK: %[[VAL_6:.*]] = tensor.dim
-// CHECK: %[[VAL_7:.*]] = tensor.dim
-// CHECK: %[[VAL_8:.*]] = tensor.dim
-// CHECK: %[[VAL_135:.*]] = vector.create_mask %{{.*}}, %{{.*}}, %{{.*}} : vector<4x8x16xi1>
-// CHECK: %[[VAL_139:.*]] = vector.transpose %[[VAL_135]], [1, 2, 0] : vector<4x8x16xi1> to vector<8x16x4xi1>
-
-// Just checking a few instances to make sure the vector mask is properly propagated:
-
-// CHECK: %[[VAL_143:.*]] = vector.extract %[[VAL_139]][0, 0] : vector<4xi1> from vector<8x16x4xi1>
-// CHECK: %[[VAL_144:.*]] = vector.mask %[[VAL_143]] { vector.reduction <add>
-// CHECK: %[[VAL_145:.*]] = vector.insert %[[VAL_144]]
-
-// CHECK: %[[VAL_148:.*]] = vector.extract %[[VAL_139]][0, 1] : vector<4xi1> from vector<8x16x4xi1>
-// CHECK: %[[VAL_149:.*]] = vector.mask %[[VAL_148]] { vector.reduction <add>
-// CHECK: %[[VAL_150:.*]] = vector.insert %[[VAL_149]]
-
-// CHECK: %[[VAL_153:.*]] = vector.extract %[[VAL_139]][0, 2] : vector<4xi1> from vector<8x16x4xi1>
-// CHECK: %[[VAL_154:.*]] = vector.mask %[[VAL_153]] { vector.reduction <add>
-// CHECK: %[[VAL_155:.*]] = vector.insert %[[VAL_154]]
-
-// CHECK: %[[VAL_158:.*]] = vector.extract %[[VAL_139]][0, 3] : vector<4xi1> from vector<8x16x4xi1>
-// CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
-// CHECK: %[[VAL_160:.*]] = vector.insert %[[VAL_159]]
-
-// Patterns applied:
-// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
-// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
-func.func private @vector_multi_reduction_non_scalable_dim(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
- %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
- return %0 : vector<8x[4]xf32>
-}
-// CHECK-LABEL: func.func private @vector_multi_reduction_non_scalable_dim(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<[32]xf32>
-
-// CHECK: %[[VAL_35:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2xf32> from vector<8x[4]x2xf32>
-// CHECK: %[[VAL_36:.*]] = vector.extract %[[VAL_1]][0, 0] : f32 from vector<8x[4]xf32>
-// CHECK: %[[VAL_37:.*]] = vector.reduction <add>, %[[VAL_35]], %[[VAL_36]] : vector<2xf32> into f32
-// CHECK: %[[VAL_38:.*]] = vector.insert %[[VAL_37]], %[[VAL_2]] [0] : f32 into vector<[32]xf32>
-
-// CHECK: %[[VAL_39:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2xf32> from vector<8x[4]x2xf32>
-// CHECK: %[[VAL_40:.*]] = vector.extract %[[VAL_1]][0, 1] : f32 from vector<8x[4]xf32>
-// CHECK: %[[VAL_41:.*]] = vector.reduction <add>, %[[VAL_39]], %[[VAL_40]] : vector<2xf32> into f32
-// CHECK: %[[VAL_42:.*]] = vector.insert %[[VAL_41]], %[[VAL_38]] [1] : f32 into vector<[32]xf32>
-
-// (...)
-
-// CHECK: %[[VAL_159:.*]] = vector.extract %[[VAL_0]][7, 3] : vector<2xf32> from vector<8x[4]x2xf32>
-// CHECK: %[[VAL_160:.*]] = vector.extract %[[VAL_1]][7, 3] : f32 from vector<8x[4]xf32>
-// CHECK: %[[VAL_161:.*]] = vector.reduction <add>, %[[VAL_159]], %[[VAL_160]] : vector<2xf32> into f32
-// CHECK: %[[VAL_162:.*]] = vector.insert %[[VAL_161]], %{{.*}} [31] : f32 into vector<[32]xf32>
-
-// CHECK: %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
-// CHECK: return %[[VAL_163]] : vector<8x[4]xf32>
-
-// Check that OneDimMultiReductionToTwoDim handles scalable dim
-// Patterns applied:
-// * OneDimMultiReductionToTwoDim from populateVectorMultiReductionTransformationPatterns
-// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
-func.func @vector_multi_reduction_scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
- %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
- return %0 : f32
-}
-
-// CHECK-LABEL: func.func @vector_multi_reduction_scalable_dim_1d(
-// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
-// CHECK-SAME: %[[ARG_1:.*]]: f32,
-// CHECK-SAME: %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
-// CHECK: %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
-// CHECK: return %[[VAL_2]] : f32
-
-// Patterns applied:
-// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
-func.func @vector_multi_reduction_scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
- %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func.func @vector_multi_reduction_scalable_dim_2d(
-// CHECK-SAME: %[[ARG_0:.*]]: vector<2x[4]xf32>,
-// CHECK-SAME: %[[ARG_1:.*]]: vector<2xf32>,
-// CHECK-SAME: %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {
-// CHECK-DAG: %[[C0_2xf32:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-// CHECK: %[[ARG0_0:.*]] = vector.extract %[[ARG_0]][0] : vector<[4]xf32> from vector<2x[4]xf32>
-// CHECK: %[[ARG1_0:.*]] = vector.extract %[[ARG_1]][0] : f32 from vector<2xf32>
-// CHECK: %[[ARG2_0:.*]] = vector.extract %[[ARG_2]][0] : vector<[4]xi1> from vector<2x[4]xi1>
-// CHECK: %[[REDUCE_0:.*]] = vector.mask %[[ARG2_0]] { vector.reduction <add>, %[[ARG0_0]], %[[ARG1_0]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
-// CHECK: %[[INSERT_0:.*]] = vector.insert %[[REDUCE_0]], %[[C0_2xf32]] [0] : f32 into vector<2xf32>
-// CHECK: %[[ARG0_1:.*]] = vector.extract %[[ARG_0]][1] : vector<[4]xf32> from vector<2x[4]xf32>
-// CHECK: %[[ARG1_1:.*]] = vector.extract %[[ARG_1]][1] : f32 from vector<2xf32>
-// CHECK: %[[ARG2_1:.*]] = vector.extract %[[ARG_2]][1] : vector<[4]xi1> from vector<2x[4]xi1>
-// CHECK: %[[REDUCE_1:.*]] = vector.mask %[[ARG2_1]] { vector.reduction <add>, %[[ARG0_1]], %[[ARG1_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
-// CHECK: %[[INSERT_1:.*]] = vector.insert %[[REDUCE_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
-// CHECK: return %[[INSERT_1]] : vector<2xf32>
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
- %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
- } : !transform.op<"func.func">
- transform.yield
- }
-}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
deleted file mode 100644
index d0ab71e3f400f..0000000000000
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ /dev/null
@@ -1,192 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
-
-func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @vector_multi_reduction
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
-// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
-// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
-// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
-// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
-
-func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <minnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @vector_multi_reduction_min
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
-// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
-// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
-// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
-// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
-
-func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <maxnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @vector_multi_reduction_max
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
-// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
-// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
-// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
-// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
-
-func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction <and>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
- return %0 : vector<2xi32>
-}
-
-// CHECK-LABEL: func @vector_multi_reduction_and
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
-// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
-// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RV0:.+]] = arith.andi %[[V0]], %[[ACC]] : vector<2xi32>
-// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RV01:.+]] = arith.andi %[[V1]], %[[RV0]] : vector<2xi32>
-// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RV012:.+]] = arith.andi %[[V2]], %[[RV01]] : vector<2xi32>
-// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.andi %[[V3]], %[[RV012]] : vector<2xi32>
-// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
-
-func.func @vector_multi_reduction_or(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction <or>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
- return %0 : vector<2xi32>
-}
-
-// CHECK-LABEL: func @vector_multi_reduction_or
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
-// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
-// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RV0:.+]] = arith.ori %[[V0]], %[[ACC]] : vector<2xi32>
-// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RV01:.+]] = arith.ori %[[V1]], %[[RV0]] : vector<2xi32>
-// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RV012:.+]] = arith.ori %[[V2]], %[[RV01]] : vector<2xi32>
-// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.ori %[[V3]], %[[RV012]] : vector<2xi32>
-// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
-
-func.func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
- %0 = vector.multi_reduction <xor>, %arg0, %acc [1] : vector<2x4xi32> to vector<2xi32>
- return %0 : vector<2xi32>
-}
-
-// CHECK-LABEL: func @vector_multi_reduction_xor
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32>, %[[ACC:.*]]: vector<2xi32>
-// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32>
-// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RV0:.+]] = arith.xori %[[V0]], %[[ACC]] : vector<2xi32>
-// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RV01:.+]] = arith.xori %[[V1]], %[[RV0]] : vector<2xi32>
-// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RV012:.+]] = arith.xori %[[V2]], %[[RV01]] : vector<2xi32>
-// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xi32> from vector<4x2xi32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.xori %[[V3]], %[[RV012]] : vector<2xi32>
-// CHECK: return %[[RESULT_VEC]] : vector<2xi32>
-
-
-func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
- %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
- return %0 : vector<2x3xi32>
-}
-
-// CHECK-LABEL: func @vector_reduction_outer
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
-// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32>
-// CHECK: %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32>
-// CHECK: %[[FACC:.+]] = vector.shape_cast %[[ACC]] : vector<2x3xi32> to vector<6xi32>
-// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R:.+]] = arith.addi %[[V0]], %[[FACC]] : vector<6xi32>
-// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R0:.+]] = arith.addi %[[V1]], %[[R]] : vector<6xi32>
-// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED]][2] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R1:.+]] = arith.addi %[[V2]], %[[R0]] : vector<6xi32>
-// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R2:.+]] = arith.addi %[[V3]], %[[R1]] : vector<6xi32>
-// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED]][4] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R3:.+]] = arith.addi %[[V4]], %[[R2]] : vector<6xi32>
-// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED]][5] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R4:.+]] = arith.addi %[[V5]], %[[R3]] : vector<6xi32>
-// CHECK: %[[V6:.+]] = vector.extract %[[RESHAPED]][6] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R5:.+]] = arith.addi %[[V6]], %[[R4]] : vector<6xi32>
-// CHECK: %[[V7:.+]] = vector.extract %[[RESHAPED]][7] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R6:.+]] = arith.addi %[[V7]], %[[R5]] : vector<6xi32>
-// CHECK: %[[V8:.+]] = vector.extract %[[RESHAPED]][8] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R7:.+]] = arith.addi %[[V8]], %[[R6]] : vector<6xi32>
-// CHECK: %[[V9:.+]] = vector.extract %[[RESHAPED]][9] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R8:.+]] = arith.addi %[[V9]], %[[R7]] : vector<6xi32>
-// CHECK: %[[V10:.+]] = vector.extract %[[RESHAPED]][10] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R9:.+]] = arith.addi %[[V10]], %[[R8]] : vector<6xi32>
-// CHECK: %[[V11:.+]] = vector.extract %[[RESHAPED]][11] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R10:.+]] = arith.addi %[[V11]], %[[R9]] : vector<6xi32>
-// CHECK: %[[V12:.+]] = vector.extract %[[RESHAPED]][12] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R11:.+]] = arith.addi %[[V12]], %[[R10]] : vector<6xi32>
-// CHECK: %[[V13:.+]] = vector.extract %[[RESHAPED]][13] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R12:.+]] = arith.addi %[[V13]], %[[R11]] : vector<6xi32>
-// CHECK: %[[V14:.+]] = vector.extract %[[RESHAPED]][14] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R13:.+]] = arith.addi %[[V14]], %[[R12]] : vector<6xi32>
-// CHECK: %[[V15:.+]] = vector.extract %[[RESHAPED]][15] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R14:.+]] = arith.addi %[[V15]], %[[R13]] : vector<6xi32>
-// CHECK: %[[V16:.+]] = vector.extract %[[RESHAPED]][16] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R15:.+]] = arith.addi %[[V16]], %[[R14]] : vector<6xi32>
-// CHECK: %[[V17:.+]] = vector.extract %[[RESHAPED]][17] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R16:.+]] = arith.addi %[[V17]], %[[R15]] : vector<6xi32>
-// CHECK: %[[V18:.+]] = vector.extract %[[RESHAPED]][18] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R17:.+]] = arith.addi %[[V18]], %[[R16]] : vector<6xi32>
-// CHECK: %[[V19:.+]] = vector.extract %[[RESHAPED]][19] : vector<6xi32> from vector<20x6xi32>
-// CHECK: %[[R18:.+]] = arith.addi %[[V19]], %[[R17]] : vector<6xi32>
-// CHECK: %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32>
-// CHECK: return %[[RESULT_VEC]] : vector<2x3xi32>
-
-func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
- %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
- return %0 : vector<4xf32>
-}
-
-// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
-// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
-// CHECK: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
-
-// This test is mainly to catch a bug that running
-// `InnerOuterDimReductionConversion` on this function results in an
-// infinite loop. So just check that some value is returned.
-func.func @vector_reduction_1D(%arg0 : vector<2xf32>, %acc: f32) -> f32 {
- %0 = vector.multi_reduction #vector.kind<maxnumf>, %arg0, %acc [0] : vector<2xf32> to f32
- return %0 : f32
-}
-// CHECK-LABEL: func @vector_reduction_1D
-// CHECK: return %{{.+}}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
- %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
- } : !transform.op<"func.func">
- transform.yield
- }
-}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
new file mode 100644
index 0000000000000..d4fb79a1d4668
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -0,0 +1,156 @@
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefixes=INNER_REDUCTION,ALL
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefixes=INNER_PARALLEL,ALL
+
+// ALL-LABEL: func @negative_rank1_and_rank3
+func.func @negative_rank1_and_rank3(
+ %rank1: vector<8xf32>, %rank1_acc: f32,
+ %rank3: vector<2x3x4xf32>, %rank3_acc: vector<2x3xf32>) -> (f32, vector<2x3xf32>) {
+ // ALL: vector.multi_reduction <add>, {{.+}} [0] : vector<8xf32> to f32
+ %0 = vector.multi_reduction <add>, %rank1, %rank1_acc [0] : vector<8xf32> to f32
+ // ALL: vector.multi_reduction <add>, {{.+}} [2] : vector<2x3x4xf32> to vector<2x3xf32>
+ %1 = vector.multi_reduction <add>, %rank3, %rank3_acc [2] : vector<2x3x4xf32> to vector<2x3xf32>
+ return %0, %1 : f32, vector<2x3xf32>
+}
+
+// ALL-LABEL: func @inner_reduction_2d
+// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>
+func.func @inner_reduction_2d(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ // INNER_REDUCTION: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.+}}> : vector<2xf32>
+ // INNER_REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0]
+ // INNER_REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+ // INNER_REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
+ // INNER_REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
+ // INNER_REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+ // INNER_REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+ // INNER_REDUCTION: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
+ // INNER_REDUCTION: %[[RESULT:.+]] = vector.insert %[[RV1]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
+
+ // INNER_PARALLEL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[INPUT]], %[[ACC]] [1]
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ // ALL: return %[[RESULT]]
+ return %0 : vector<2xf32>
+}
+
+func.func @inner_reduction_2d_masked_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %c1 = arith.constant 1 : index
+ %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %c0_1 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.create_mask %dim, %dim_0 : vector<4x8xi1>
+ %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+ %cst_2 = arith.constant 0.000000e+00 : f32
+ %2 = vector.create_mask %dim : vector<4xi1>
+ %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+ %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
+ %c0_3 = arith.constant 0 : index
+ %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_3] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+ return %5 : tensor<?xf32>
+}
+
+// ALL-LABEL: func @inner_reduction_2d_masked_dynamic
+// INNER_REDUCTION: %[[DIM_0:.+]] = tensor.dim
+// INNER_REDUCTION: %[[DIM_1:.+]] = tensor.dim
+// INNER_REDUCTION: %[[MASK_2D:.+]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<4x8xi1>
+//
+// INNER_REDUCTION: %[[MASK_SLICE_0:.+]] = vector.extract %[[MASK_2D]][0] : vector<8xi1> from vector<4x8xi1>
+// INNER_REDUCTION: %[[REDUCE_0:.+]] = vector.mask %[[MASK_SLICE_0]] { vector.reduction <add>, %{{.+}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// INNER_REDUCTION: %[[INSERT_0:.+]] = vector.insert
+//
+// INNER_REDUCTION: %[[MASK_SLICE_1:.+]] = vector.extract %[[MASK_2D]][1] : vector<8xi1> from vector<4x8xi1>
+// INNER_REDUCTION: %[[REDUCE_1:.+]] = vector.mask %[[MASK_SLICE_1]] { vector.reduction <add>, %{{.+}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// INNER_REDUCTION: %[[INSERT_1:.+]] = vector.insert
+//
+// INNER_REDUCTION: %[[MASK_SLICE_2:.+]] = vector.extract %[[MASK_2D]][2] : vector<8xi1> from vector<4x8xi1>
+// INNER_REDUCTION: %[[REDUCE_2:.+]] = vector.mask %[[MASK_SLICE_2]] { vector.reduction <add>, %{{.+}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// INNER_REDUCTION: %[[INSERT_2:.+]] = vector.insert
+//
+// INNER_REDUCTION: %[[MASK_SLICE_3:.+]] = vector.extract %[[MASK_2D]][3] : vector<8xi1> from vector<4x8xi1>
+// INNER_REDUCTION: %[[REDUCE_3:.+]] = vector.mask %[[MASK_SLICE_3]] { vector.reduction <add>, %{{.+}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// INNER_REDUCTION: %[[INSERT_3:.+]] = vector.insert
+//
+// INNER_PARALLEL: vector.multi_reduction <add>
+
+// ALL-LABEL: func @inner_reduction_2d_scalable
+// ALL-SAME: %[[INPUT:.+]]: vector<2x[4]xf32>
+// ALL-SAME: %[[ACC:.+]]: vector<2xf32>
+// ALL-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @inner_reduction_2d_scalable(%input: vector<2x[4]xf32>, %acc: vector<2xf32>, %mask: vector<2x[4]xi1>) -> vector<2xf32> {
+ // INNER_REDUCTION: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+ // INNER_REDUCTION: %[[INPUT_0:.+]] = vector.extract %[[INPUT]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+ // INNER_REDUCTION: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+ // INNER_REDUCTION: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<[4]xi1> from vector<2x[4]xi1>
+ // INNER_REDUCTION: %[[REDUCE_0:.+]] = vector.mask %[[MASK_0]] { vector.reduction <add>, %[[INPUT_0]], %[[ACC_0]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+ // INNER_REDUCTION: %[[INSERT_0:.+]] = vector.insert %[[REDUCE_0]], %[[INIT]] [0] : f32 into vector<2xf32>
+ // INNER_REDUCTION: %[[INPUT_1:.+]] = vector.extract %[[INPUT]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+ // INNER_REDUCTION: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+ // INNER_REDUCTION: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<[4]xi1> from vector<2x[4]xi1>
+ // INNER_REDUCTION: %[[REDUCE_1:.+]] = vector.mask %[[MASK_1]] { vector.reduction <add>, %[[INPUT_1]], %[[ACC_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+ // INNER_REDUCTION: %[[RESULT:.+]] = vector.insert %[[REDUCE_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
+
+ // INNER_PARALLEL: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.multi_reduction <add>, %[[INPUT]], %[[ACC]] [1] {{.+}} } : vector<2x[4]xi1> -> vector<2xf32>
+ // ALL: return %[[RESULT]] : vector<2xf32>
+ %0 = vector.mask %mask { vector.multi_reduction <add>, %input, %acc [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ALL-LABEL: func @inner_parallel_2d
+// ALL-SAME: %[[INPUT:.+]]: vector<4x2xf32>, %[[ACC:.+]]: vector<2xf32>
+func.func @inner_parallel_2d(%arg0: vector<4x2xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2xf32> from vector<4x2xf32>
+ // INNER_PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+ // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2xf32> from vector<4x2xf32>
+ // INNER_PARALLEL: %[[RV1:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+ // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2xf32> from vector<4x2xf32>
+ // INNER_PARALLEL: %[[RV2:.+]] = arith.mulf %[[V2]], %[[RV1]] : vector<2xf32>
+ // INNER_PARALLEL: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2xf32> from vector<4x2xf32>
+ // INNER_PARALLEL: %[[RESULT:.+]] = arith.mulf %[[V3]], %[[RV2]] : vector<2xf32>
+ // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[INPUT]], %[[ACC]] [0]
+ // ALL: return %[[RESULT]] : vector<2xf32>
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<4x2xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ALL-LABEL: func @inner_parallel_2d_masked
+// ALL-SAME: %[[INPUT:.+]]: vector<4x2xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<4x2xi1>
+func.func @inner_parallel_2d_masked(%arg0: vector<4x2xf32>, %acc: vector<2xf32>, %mask: vector<4x2xi1>) -> vector<2xf32> {
+ // INNER_PARALLEL: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2xf32> from vector<4x2xf32>
+ // INNER_PARALLEL: %[[M0:.+]] = vector.extract %[[MASK]][0] : vector<2xi1> from vector<4x2xi1>
+ // INNER_PARALLEL: %[[RED0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+ // INNER_PARALLEL: %[[RV0:.+]] = arith.select %[[M0]], %[[RED0]], %[[ACC]] : vector<2xi1>, vector<2xf32>
+ // INNER_PARALLEL: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2xf32> from vector<4x2xf32>
+ // INNER_PARALLEL: %[[M1:.+]] = vector.extract %[[MASK]][1] : vector<2xi1> from vector<4x2xi1>
+ // INNER_PARALLEL: %[[RED1:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+ // INNER_PARALLEL: %[[RV1:.+]] = arith.select %[[M1]], %[[RED1]], %[[RV0]] : vector<2xi1>, vector<2xf32>
+ // INNER_PARALLEL: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2xf32> from vector<4x2xf32>
+ // INNER_PARALLEL: %[[M2:.+]] = vector.extract %[[MASK]][2] : vector<2xi1> from vector<4x2xi1>
+ // INNER_PARALLEL: %[[RED2:.+]] = arith.mulf %[[V2]], %[[RV1]] : vector<2xf32>
+ // INNER_PARALLEL: %[[RV2:.+]] = arith.select %[[M2]], %[[RED2]], %[[RV1]] : vector<2xi1>, vector<2xf32>
+ // INNER_PARALLEL: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2xf32> from vector<4x2xf32>
+ // INNER_PARALLEL: %[[M3:.+]] = vector.extract %[[MASK]][3] : vector<2xi1> from vector<4x2xi1>
+ // INNER_PARALLEL: %[[RED3:.+]] = arith.mulf %[[V3]], %[[RV2]] : vector<2xf32>
+ // INNER_PARALLEL: %[[RESULT:.+]] = arith.select %[[M3]], %[[RED3]], %[[RV2]] : vector<2xi1>, vector<2xf32>
+ // INNER_REDUCTION: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.multi_reduction <mul>, %[[INPUT]], %[[ACC]] [0] {{.+}} } : vector<4x2xi1> -> vector<2xf32>
+ // ALL: return %[[RESULT]] : vector<2xf32>
+ %0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [0] : vector<4x2xf32> to vector<2xf32> } : vector<4x2xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+
+ transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 29ce8ba63cd53..76e39864fea9f 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -116,6 +116,14 @@ def enum_configurable_patterns():
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
)
+ # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
+ vector.ApplyMultiReductionUnrollingPatternsOp()
+ # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling
+ # CHECK-SAME: lowering_strategy = innerreduction
+ vector.ApplyMultiReductionUnrollingPatternsOp(
+ lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
+ )
+
# CHECK: transform.apply_patterns.vector.lower_transpose
vector.ApplyLowerTransposePatternsOp()
# CHECK: transform.apply_patterns.vector.lower_transpose
>From 758591ae60e31aa133dac23088537721e613e2d2 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 19 Feb 2026 11:59:22 -0500
Subject: [PATCH 2/2] [mlir][vector] Remove lower_multi_reduction
---
.../Vector/TransformOps/VectorTransformOps.td | 20 -------------------
.../Vector/Transforms/LoweringPatterns.h | 15 --------------
.../TransformOps/VectorTransformOps.cpp | 12 -----------
.../Transforms/LowerVectorMultiReduction.cpp | 9 ---------
mlir/test/Dialect/LLVM/transform-e2e.mlir | 4 +++-
.../test/Dialect/Vector/transform-vector.mlir | 4 +++-
.../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir | 4 +++-
.../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir | 4 +++-
.../Linalg/CPU/test-matmul-masked-vec.mlir | 4 +++-
.../python/dialects/transform_vector_ext.py | 13 ------------
10 files changed, 15 insertions(+), 74 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 685c88c17e556..9fec5804d0b3b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -223,26 +223,6 @@ def ApplyMaterializeMasksPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
-def ApplyLowerMultiReductionPatternsOp : Op<Transform_Dialect,
- "apply_patterns.vector.lower_multi_reduction",
- [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
- let description = [{
- Indicates that vector multi_reduction-like operations should be lowered to
- finer-grained vector primitives.
-
- This is usually a late step that is run after bufferization as part of the
- process of lowering to e.g. LLVM or NVVM.
- }];
-
- let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
- "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
- );
-
- let assemblyFormat = [{
- (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
- }];
-}
-
def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
"apply_patterns.vector.reorder_and_expand_multi_reduction_dims",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 33487a9d8d6e0..a933f68732a4d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -103,21 +103,6 @@ void populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
-/// Collect a set of patterns to convert vector.multi_reduction op into
-/// a sequence of vector.reduction ops. These patterns are the ones
-/// populated by:
-///
-/// * populateVectorMultiReductionReorderAndExpandPatterns
-/// * populateVectorMultiReductionFlatteningPatterns
-/// * populateVectorMultiReductionUnrollingPatterns
-///
-/// This is just a convenience wrapper that we use in testing and is effectively
-/// deprecated.
-/// TODO: Delete.
-void populateVectorMultiReductionLoweringPatterns(
- RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
-
/// Populate the pattern set with the following patterns:
///
/// [TransferReadToVectorLoadLowering]
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 6c97c6501a23e..23118bf3e726a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -126,18 +126,6 @@ void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
/*force32BitVectorIndices=*/false);
}
-void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
- RewritePatternSet &patterns) {
- vector::VectorTransformsOptions vectorTransformOptions;
- vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
- vector::populateVectorMultiReductionReorderAndExpandPatterns(
- patterns, vectorTransformOptions.vectorMultiReductionLowering);
- vector::populateVectorMultiReductionFlatteningPatterns(
- patterns, vectorTransformOptions.vectorMultiReductionLowering);
- vector::populateVectorMultiReductionUnrollingPatterns(
- patterns, vectorTransformOptions.vectorMultiReductionLowering);
-}
-
void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index fec04c967c9e1..663c43a44781b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -549,15 +549,6 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
benefit);
}
-void mlir::vector::populateVectorMultiReductionLoweringPatterns(
- RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit) {
- populateVectorMultiReductionReorderAndExpandPatterns(patterns, options,
- benefit);
- populateVectorMultiReductionFlatteningPatterns(patterns, options, benefit);
- populateVectorMultiReductionUnrollingPatterns(patterns, options, benefit);
-}
-
std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
vector::VectorMultiReductionLowering option) {
return std::make_unique<LowerVectorMultiReductionPass>(option);
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index c00b47fb936e9..ab58dda91a914 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -30,7 +30,9 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
transform.apply_patterns.vector.transfer_permutation_patterns
- transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 524a4f429211b..a37105d573219 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -39,7 +39,9 @@ module attributes {transform.with_named_sequence} {
} : !transform.any_op
transform.apply_patterns to %f {
- transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
} : !transform.any_op
transform.apply_patterns to %f {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
index 3090e921553c6..25b65080339d5 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir
@@ -150,7 +150,9 @@ module attributes {transform.with_named_sequence} {
// Step 3: Lower vector.multi_reduction
transform.apply_patterns to %func {
transform.apply_patterns.vector.lower_masked_transfers
- transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
index fc0fd40b4d265..6072b44adf4fa 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir
@@ -155,7 +155,9 @@ module attributes {transform.with_named_sequence} {
// Step 3: Lower vector.multi_reduction
transform.apply_patterns to %func {
transform.apply_patterns.vector.lower_masked_transfers
- transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index bbda8d4e99d04..3c4f10316d0f3 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -53,7 +53,9 @@ module attributes {transform.with_named_sequence} {
%func_op = transform.get_parent_op %0 : (!transform.any_op) -> !transform.op<"func.func">
transform.structured.vectorize %0 vector_sizes [4, 4, 2] : !transform.any_op
transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
}
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 76e39864fea9f..8a3091d0b1b02 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -87,19 +87,6 @@ def enum_configurable_patterns():
lowering_strategy=vector.VectorContractLowering.ParallelArith
)
- # CHECK: transform.apply_patterns.vector.lower_multi_reduction
- vector.ApplyLowerMultiReductionPatternsOp()
- # CHECK: transform.apply_patterns.vector.lower_multi_reduction
- # This is the default mode, not printed.
- vector.ApplyLowerMultiReductionPatternsOp(
- lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
- )
- # CHECK: transform.apply_patterns.vector.lower_multi_reduction
- # CHECK-SAME: lowering_strategy = innerreduction
- vector.ApplyLowerMultiReductionPatternsOp(
- lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
- )
-
# CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims
vector.ApplyReorderAndExpandMultiReductionPatternsOp()
# CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims
More information about the Mlir-commits
mailing list