[Mlir-commits] [mlir] [mlir][vector] Add lower_multi_reduction_flattening (PR #181244)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Wed Feb 18 07:26:11 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/181244
>From 5ef8be600b9903df65fd8495a1c796ac7f1d42b6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 11 Feb 2026 11:05:56 -0500
Subject: [PATCH 1/6] [mlir][vector] Add
ApplyLowerMultiReductionFlatteningPatternsOp.
---
.../Dialect/Vector/TransformOps/VectorTransformOps.cpp | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 4e2b97aa07084..5247140a48e24 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -146,6 +146,14 @@ void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+void transform::ApplyLowerMultiReductionFlatteningPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorMultiReductionFlatteningPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorOuterProductLoweringPatterns(patterns);
>From e930a20206d2ded1251903a3d7bf63230e5f78b2 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 11 Feb 2026 11:06:49 -0500
Subject: [PATCH 2/6] Add test
---
...ti-reduction-flattening-innerparallel.mlir | 35 +++++++++++++++++++
.../vector-multi-reduction-flattening.mlir | 35 +++++++++++++++++++
.../vector-multi-reduction-lowering.mlir | 12 -------
3 files changed, 70 insertions(+), 12 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
new file mode 100644
index 0000000000000..ad8d0cc046d62
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x3xf32>
+// CHECK-SAME: %[[ACC:.+]]: f32
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
+ // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3xf32> to vector<6xf32>
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [0]
+ %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
+ // CHECK: return %[[RESULT]]
+ return %0 : f32
+}
+
+// Test with parallel dimension: reduction dims [0,1] are already outermost,
+// parallel dim [2] is innermost - this can be flattened.
+// CHECK-LABEL: func @vector_multi_reduction_parallel_dim
+// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x2xi32>
+// CHECK-SAME: %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim(%arg0: vector<3x4x2xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<3x4x2xi32> to vector<12x2xi32>
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<3x4x2xi32> to vector<2xi32>
+ // CHECK: return %[[RESULT]]
+ return %0 : vector<2xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerparallel"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
new file mode 100644
index 0000000000000..94830366c4f4d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// Patterns applied:
+// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
+func.func @vector_multi_reduction_flattening(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_flattening
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+// CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+// CHECK: return %[[RESULT]]
+
+// CHECK-LABEL: func @vector_multi_reduction_parallel_dim
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4xi32>
+// CHECK-SAME: %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim(%arg0: vector<2x3x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4xi32> to vector<2x12xi32>
+ // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [1]
+ %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4xi32> to vector<2xi32>
+ // CHECK: return %[[RESULT]]
+ return %0 : vector<2xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerreduction"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 3ce9f57edf9d6..6b79a78e6a42a 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -19,18 +19,6 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
// CHECK: %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
// CHECK: return %[[RESULT_VEC]]
-// Patterns applied:
-// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
- %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
- return %0 : f32
-}
-// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
-// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
-// CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
-// CHECK: return %[[REDUCED]]
-
// Patterns applied:
// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
>From 4394e734532361d4771ce6358b09138630463ec3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 11 Feb 2026 11:07:42 -0500
Subject: [PATCH 3/6] Add tests for python bindings
---
mlir/test/python/dialects/transform_vector_ext.py | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 0f9aab29ca5ca..4c64a4d08a57a 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -108,6 +108,14 @@ def enum_configurable_patterns():
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
)
+ # CHECK: transform.apply_patterns.vector.lower_multi_reduction_flattening
+ vector.ApplyLowerMultiReductionFlatteningPatternsOp()
+ # CHECK: transform.apply_patterns.vector.lower_multi_reduction_flattening
+ # CHECK-SAME: lowering_strategy = innerreduction
+ vector.ApplyLowerMultiReductionFlatteningPatternsOp(
+ lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
+ )
+
# CHECK: transform.apply_patterns.vector.lower_transpose
vector.ApplyLowerTransposePatternsOp()
# CHECK: transform.apply_patterns.vector.lower_transpose
>From 982509ba52821807a4d5d90f3ffb701201ae1bb6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 12 Feb 2026 16:31:03 -0500
Subject: [PATCH 4/6] Merge test file
---
.../Transforms/LowerVectorMultiReduction.cpp | 2 +-
...ti-reduction-flattening-innerparallel.mlir | 35 -----
.../vector-multi-reduction-flattening.mlir | 121 +++++++++++++++---
3 files changed, 105 insertions(+), 53 deletions(-)
delete mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2d6a49bad27bc..f9960da889788 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -287,7 +287,7 @@ class ReduceMultiDimReductionRank
return success();
}
- // 8. Creates shape cast for the output n-D -> 2-D.
+ // 8. Shape cast the flattened result back to the original n-D parallel shape.
VectorType outputCastedType = VectorType::get(
parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
parallelScalableDims);
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
deleted file mode 100644
index ad8d0cc046d62..0000000000000
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
-
-// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x3xf32>
-// CHECK-SAME: %[[ACC:.+]]: f32
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
- // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3xf32> to vector<6xf32>
- // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [0]
- %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
- // CHECK: return %[[RESULT]]
- return %0 : f32
-}
-
-// Test with parallel dimension: reduction dims [0,1] are already outermost,
-// parallel dim [2] is innermost - this can be flattened.
-// CHECK-LABEL: func @vector_multi_reduction_parallel_dim
-// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x2xi32>
-// CHECK-SAME: %[[ACC:.+]]: vector<2xi32>
-func.func @vector_multi_reduction_parallel_dim(%arg0: vector<3x4x2xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
- // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<3x4x2xi32> to vector<12x2xi32>
- // CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
- %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<3x4x2xi32> to vector<2xi32>
- // CHECK: return %[[RESULT]]
- return %0 : vector<2xi32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
- %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerparallel"
- } : !transform.op<"func.func">
- transform.yield
- }
-}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
index 94830366c4f4d..55914cff94966 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -1,35 +1,122 @@
-// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefix=INNER_REDUCTION,ALL
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefix=INNER_PARALLEL,ALL
-// Patterns applied:
-// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
+// ALL-LABEL: func @negative_flattening_cases
+func.func @negative_flattening_cases(
+ %v1d: vector<8xf32>,
+ %v2d: vector<4x8xf32>,
+ %v_scalable: vector<[2]x[4]x8xf32>,
+ %v_non_contig: vector<2x3x4x5xi32>,
+ %acc_scalar: f32,
+ %acc_1d: vector<8xf32>,
+ %acc_2d: vector<2x4xi32>) -> (f32, vector<8xf32>, vector<8xf32>, vector<2x4xi32>) {
+
+ // Test 1: Less than 2 dimensions
+ // ALL: %[[R1:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [0] : vector<8xf32> to f32
+ %r1 = vector.multi_reduction <add>, %v1d, %acc_scalar [0] : vector<8xf32> to f32
+
+ // Test 2: More than one scalable dimension
+ // ALL: %[[R2:.+]] = vector.multi_reduction <mul>, %{{.+}}, %{{.+}} [0, 1] : vector<[2]x[4]x8xf32> to vector<8xf32>
+ %r2 = vector.multi_reduction <mul>, %v_scalable, %acc_1d [0, 1] : vector<[2]x[4]x8xf32> to vector<8xf32>
+
+ // Test 3: Already 2D with reduction on single dim
+ // ALL: %[[R3:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [0] : vector<4x8xf32> to vector<8xf32>
+ %r3 = vector.multi_reduction <add>, %v2d, %acc_1d [0] : vector<4x8xf32> to vector<8xf32>
+
+ // Test 4: Non-contiguous parallel dimensions
+ // ALL: %[[R4:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [1, 3] : vector<2x3x4x5xi32> to vector<2x4xi32>
+ %r4 = vector.multi_reduction <add>, %v_non_contig, %acc_2d [1, 3] : vector<2x3x4x5xi32> to vector<2x4xi32>
+
+ // ALL: return %[[R1]], %[[R2]], %[[R3]], %[[R4]]
+ return %r1, %r2, %r3, %r4 : f32, vector<8xf32>, vector<8xf32>, vector<2x4xi32>
+}
+
+// ALL-LABEL: func @vector_multi_reduction_flattening
+// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
func.func @vector_multi_reduction_flattening(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+ // ALL: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+ // ALL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
%0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+ // ALL: return %[[RESULT]]
return %0 : f32
}
-// CHECK-LABEL: func @vector_multi_reduction_flattening
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
-// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
-// CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
-// CHECK: return %[[RESULT]]
-
-// CHECK-LABEL: func @vector_multi_reduction_parallel_dim
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4xi32>
-// CHECK-SAME: %[[ACC:.+]]: vector<2xi32>
-func.func @vector_multi_reduction_parallel_dim(%arg0: vector<2x3x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
- // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4xi32> to vector<2x12xi32>
- // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [1]
+// INNER_REDUCTION-LABEL: func @vector_multi_reduction_parallel_dim_innerreduction
+// INNER_REDUCTION-SAME: %[[INPUT:.+]]: vector<2x3x4xi32>
+// INNER_REDUCTION-SAME: %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim_innerreduction(%arg0: vector<2x3x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ // INNER_REDUCTION: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4xi32> to vector<2x12xi32>
+ // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [1]
%0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4xi32> to vector<2xi32>
- // CHECK: return %[[RESULT]]
+ // INNER_REDUCTION: return %[[RESULT]]
+ return %0 : vector<2xi32>
+}
+
+// INNER_REDUCTION-LABEL: func @output_shapecast_multiple_parallel
+// INNER_REDUCTION-SAME: %[[INPUT:.+]]: vector<2x3x4x5x6xi32>
+// INNER_REDUCTION-SAME: %[[ACC:.+]]: vector<2x3x4xi32>
+func.func @output_shapecast_multiple_parallel(%arg0: vector<2x3x4x5x6xi32>, %acc: vector<2x3x4xi32>) -> vector<2x3x4xi32> {
+ // INNER_REDUCTION: %[[INPUT_CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5x6xi32> to vector<24x30xi32>
+ // INNER_REDUCTION: %[[ACC_CAST:.+]] = vector.shape_cast %[[ACC]] : vector<2x3x4xi32> to vector<24xi32>
+ // INNER_REDUCTION: %[[RESULT_FLAT:.+]] = vector.multi_reduction <mul>, %[[INPUT_CAST]], %[[ACC_CAST]] [1]
+ // INNER_REDUCTION: %[[RESULT:.+]] = vector.shape_cast %[[RESULT_FLAT]] : vector<24xi32> to vector<2x3x4xi32>
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [3, 4] : vector<2x3x4x5x6xi32> to vector<2x3x4xi32>
+ // INNER_REDUCTION: return %[[RESULT]]
+ return %0 : vector<2x3x4xi32>
+}
+
+// INNER_PARALLEL-LABEL: func @vector_multi_reduction_parallel_dim_innerparallel
+// INNER_PARALLEL-SAME: %[[INPUT:.+]]: vector<3x4x2xi32>
+// INNER_PARALLEL-SAME: %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim_innerparallel(%arg0: vector<3x4x2xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ // INNER_PARALLEL: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<3x4x2xi32> to vector<12x2xi32>
+ // INNER_PARALLEL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<3x4x2xi32> to vector<2xi32>
+ // INNER_PARALLEL: return %[[RESULT]]
return %0 : vector<2xi32>
}
+// ALL-LABEL: func @single_scalable_dim
+// ALL-SAME: %[[INPUT:.+]]: vector<4x[8]xf32>
+// ALL-SAME: %[[ACC:.+]]: f32
+func.func @single_scalable_dim(%arg0: vector<4x[8]xf32>, %acc: f32) -> f32 {
+ // ALL: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<4x[8]xf32> to vector<[32]xf32>
+ // ALL: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [0]
+ %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<4x[8]xf32> to f32
+ // ALL: return %[[RESULT]]
+ return %0 : f32
+}
+
+// ALL-LABEL: func @masked_multi_reduction
+// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// ALL-SAME: %[[ACC:.+]]: f32
+// ALL-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: f32, %mask: vector<2x4xi1>) -> f32 {
+ // ALL: %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<2x4xi1> to vector<8xi1>
+ // ALL: %[[CASTED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+ // ALL: %[[RESULT:.+]] = vector.mask %[[CASTED_MASK]]
+ // ALL: vector.multi_reduction <mul>, %[[CASTED_INPUT]], %[[ACC]] [0]
+ %0 = vector.mask %mask {
+ vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+ } : vector<2x4xi1> -> f32
+ // ALL: return %[[RESULT]]
+ return %0 : f32
+}
+
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
}
+
+ transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerparallel"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
}
>From e943cec57e6683070146638627fc721c280fbef4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 18 Feb 2026 10:16:06 -0500
Subject: [PATCH 5/6] Fix rebase
---
.../Vector/TransformOps/VectorTransformOps.td | 17 +++++++++++++++++
.../Vector/TransformOps/VectorTransformOps.cpp | 2 +-
.../vector-multi-reduction-flattening.mlir | 4 ++--
.../python/dialects/transform_vector_ext.py | 8 ++++----
4 files changed, 24 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 462c61df72108..6eb96e2a8fdab 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -266,6 +266,23 @@ def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
}];
}
+def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
+ "apply_patterns.vector.multi_reduction_flattening",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector multi_reduction operations should be flattened from
+ more than 2-D to 2-D, or to 1-D when every dimension is reduced.
+ }];
+
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = [{
+ (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+ }];
+}
+
def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_outerproduct",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 5247140a48e24..f3529ac26523f 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -146,7 +146,7 @@ void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
-void transform::ApplyLowerMultiReductionFlatteningPatternsOp::populatePatterns(
+void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
index 55914cff94966..b8f970912909b 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -107,7 +107,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
}
@@ -115,7 +115,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
} : !transform.op<"func.func">
transform.yield
}
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 4c64a4d08a57a..29ce8ba63cd53 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -108,11 +108,11 @@ def enum_configurable_patterns():
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
)
- # CHECK: transform.apply_patterns.vector.lower_multi_reduction_flattening
- vector.ApplyLowerMultiReductionFlatteningPatternsOp()
- # CHECK: transform.apply_patterns.vector.lower_multi_reduction_flattening
+ # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
+ vector.ApplyMultiReductionFlatteningPatternsOp()
+ # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
# CHECK-SAME: lowering_strategy = innerreduction
- vector.ApplyLowerMultiReductionFlatteningPatternsOp(
+ vector.ApplyMultiReductionFlatteningPatternsOp(
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
)
>From 4f92de4fcd2900ebef20dc6c77dcfaf2149f89d9 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 18 Feb 2026 10:25:23 -0500
Subject: [PATCH 6/6] Style
---
.../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index f9960da889788..fec04c967c9e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -287,7 +287,8 @@ class ReduceMultiDimReductionRank
return success();
}
- // 8. Shape cast the flattened result back to the original n-D parallel shape.
+ // 8. Shape cast the flattened result back to the original n-D parallel
+ // shape.
VectorType outputCastedType = VectorType::get(
parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
parallelScalableDims);
More information about the Mlir-commits
mailing list