[Mlir-commits] [mlir] [mlir][vector] Add lower_multi_reduction_flattening (PR #181244)

Erick Ochoa Lopez llvmlistbot at llvm.org
Wed Feb 18 07:26:11 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/181244

>From 5ef8be600b9903df65fd8495a1c796ac7f1d42b6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 11 Feb 2026 11:05:56 -0500
Subject: [PATCH 1/6] [mlir][vector] Add
 ApplyLowerMultiReductionFlatteningPatternsOp.

---
 .../Dialect/Vector/TransformOps/VectorTransformOps.cpp    | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 4e2b97aa07084..5247140a48e24 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -146,6 +146,14 @@ void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
+void transform::ApplyLowerMultiReductionFlatteningPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorMultiReductionFlatteningPatterns(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorOuterProductLoweringPatterns(patterns);

>From e930a20206d2ded1251903a3d7bf63230e5f78b2 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 11 Feb 2026 11:06:49 -0500
Subject: [PATCH 2/6] Add test

---
 ...ti-reduction-flattening-innerparallel.mlir | 35 +++++++++++++++++++
 .../vector-multi-reduction-flattening.mlir    | 35 +++++++++++++++++++
 .../vector-multi-reduction-lowering.mlir      | 12 -------
 3 files changed, 70 insertions(+), 12 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
 create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir

diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
new file mode 100644
index 0000000000000..ad8d0cc046d62
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+// CHECK-SAME:    %[[INPUT:.+]]: vector<2x3xf32>
+// CHECK-SAME:    %[[ACC:.+]]: f32
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
+    // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3xf32> to vector<6xf32>
+    // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [0]
+    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
+    // CHECK: return %[[RESULT]]
+    return %0 : f32
+}
+
+// Test with parallel dimension: reduction dims [0,1] are already outermost,
+// parallel dim [2] is innermost - this can be flattened.
+// CHECK-LABEL: func @vector_multi_reduction_parallel_dim
+// CHECK-SAME:    %[[INPUT:.+]]: vector<3x4x2xi32>
+// CHECK-SAME:    %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim(%arg0: vector<3x4x2xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+    // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<3x4x2xi32> to vector<12x2xi32>
+    // CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<3x4x2xi32> to vector<2xi32>
+    // CHECK: return %[[RESULT]]
+    return %0 : vector<2xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerparallel"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
new file mode 100644
index 0000000000000..94830366c4f4d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// Patterns applied:
+// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
+func.func @vector_multi_reduction_flattening(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+    return %0 : f32
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_flattening
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+//       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+//       CHECK:   %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+//       CHECK:   return %[[RESULT]]
+
+// CHECK-LABEL: func @vector_multi_reduction_parallel_dim
+// CHECK-SAME:    %[[INPUT:.+]]: vector<2x3x4xi32>
+// CHECK-SAME:    %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim(%arg0: vector<2x3x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+    // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4xi32> to vector<2x12xi32>
+    // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [1]
+    %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4xi32> to vector<2xi32>
+    // CHECK: return %[[RESULT]]
+    return %0 : vector<2xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerreduction"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 3ce9f57edf9d6..6b79a78e6a42a 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -19,18 +19,6 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
 //       CHECK:       return %[[RESULT_VEC]]
 
-// Patterns applied:
-// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
-    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
-    return %0 : f32
-}
-// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
-//       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
-//       CHECK:   %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
-//       CHECK:   return %[[REDUCED]]
-
 // Patterns applied:
 // * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
 // * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns

>From 4394e734532361d4771ce6358b09138630463ec3 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 11 Feb 2026 11:07:42 -0500
Subject: [PATCH 3/6] Add tests for python bindings

---
 mlir/test/python/dialects/transform_vector_ext.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 0f9aab29ca5ca..4c64a4d08a57a 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -108,6 +108,14 @@ def enum_configurable_patterns():
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
     )
 
+    # CHECK: transform.apply_patterns.vector.lower_multi_reduction_flattening
+    vector.ApplyLowerMultiReductionFlatteningPatternsOp()
+    # CHECK: transform.apply_patterns.vector.lower_multi_reduction_flattening
+    # CHECK-SAME: lowering_strategy = innerreduction
+    vector.ApplyLowerMultiReductionFlatteningPatternsOp(
+        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
+    )
+
     # CHECK: transform.apply_patterns.vector.lower_transpose
     vector.ApplyLowerTransposePatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_transpose

>From 982509ba52821807a4d5d90f3ffb701201ae1bb6 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 12 Feb 2026 16:31:03 -0500
Subject: [PATCH 4/6] Merge test file

---
 .../Transforms/LowerVectorMultiReduction.cpp  |   2 +-
 ...ti-reduction-flattening-innerparallel.mlir |  35 -----
 .../vector-multi-reduction-flattening.mlir    | 121 +++++++++++++++---
 3 files changed, 105 insertions(+), 53 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2d6a49bad27bc..f9960da889788 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -287,7 +287,7 @@ class ReduceMultiDimReductionRank
       return success();
     }
 
-    // 8. Creates shape cast for the output n-D -> 2-D.
+    // 8. Shape cast the flattened result back to the original n-D parallel shape.
     VectorType outputCastedType = VectorType::get(
         parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
         parallelScalableDims);
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
deleted file mode 100644
index ad8d0cc046d62..0000000000000
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening-innerparallel.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
-
-// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-// CHECK-SAME:    %[[INPUT:.+]]: vector<2x3xf32>
-// CHECK-SAME:    %[[ACC:.+]]: f32
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
-    // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3xf32> to vector<6xf32>
-    // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [0]
-    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
-    // CHECK: return %[[RESULT]]
-    return %0 : f32
-}
-
-// Test with parallel dimension: reduction dims [0,1] are already outermost,
-// parallel dim [2] is innermost - this can be flattened.
-// CHECK-LABEL: func @vector_multi_reduction_parallel_dim
-// CHECK-SAME:    %[[INPUT:.+]]: vector<3x4x2xi32>
-// CHECK-SAME:    %[[ACC:.+]]: vector<2xi32>
-func.func @vector_multi_reduction_parallel_dim(%arg0: vector<3x4x2xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
-    // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<3x4x2xi32> to vector<12x2xi32>
-    // CHECK: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
-    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<3x4x2xi32> to vector<2xi32>
-    // CHECK: return %[[RESULT]]
-    return %0 : vector<2xi32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
-    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
-    transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerparallel"
-    } : !transform.op<"func.func">
-    transform.yield
-  }
-}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
index 94830366c4f4d..55914cff94966 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -1,35 +1,122 @@
-// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefixes=INNER_REDUCTION,ALL
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefixes=INNER_PARALLEL,ALL
 
-// Patterns applied:
-// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
+// ALL-LABEL: func @negative_flattening_cases
+func.func @negative_flattening_cases(
+    %v1d: vector<8xf32>,
+    %v2d: vector<4x8xf32>,
+    %v_scalable: vector<[2]x[4]x8xf32>,
+    %v_non_contig: vector<2x3x4x5xi32>,
+    %acc_scalar: f32,
+    %acc_1d: vector<8xf32>,
+    %acc_2d: vector<2x4xi32>) -> (f32, vector<8xf32>, vector<8xf32>, vector<2x4xi32>) {
+
+    // Test 1: Fewer than 2 dimensions
+    // ALL: %[[R1:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [0] : vector<8xf32> to f32
+    %r1 = vector.multi_reduction <add>, %v1d, %acc_scalar [0] : vector<8xf32> to f32
+
+    // Test 2: More than one scalable dimension
+    // ALL: %[[R2:.+]] = vector.multi_reduction <mul>, %{{.+}}, %{{.+}} [0, 1] : vector<[2]x[4]x8xf32> to vector<8xf32>
+    %r2 = vector.multi_reduction <mul>, %v_scalable, %acc_1d [0, 1] : vector<[2]x[4]x8xf32> to vector<8xf32>
+
+    // Test 3: Already 2D with reduction on single dim
+    // ALL: %[[R3:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [0] : vector<4x8xf32> to vector<8xf32>
+    %r3 = vector.multi_reduction <add>, %v2d, %acc_1d [0] : vector<4x8xf32> to vector<8xf32>
+
+    // Test 4: Non-contiguous parallel dimensions
+    // ALL: %[[R4:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [1, 3] : vector<2x3x4x5xi32> to vector<2x4xi32>
+    %r4 = vector.multi_reduction <add>, %v_non_contig, %acc_2d [1, 3] : vector<2x3x4x5xi32> to vector<2x4xi32>
+
+    // ALL: return %[[R1]], %[[R2]], %[[R3]], %[[R4]]
+    return %r1, %r2, %r3, %r4 : f32, vector<8xf32>, vector<8xf32>, vector<2x4xi32>
+}
+
+// ALL-LABEL: func @vector_multi_reduction_flattening
+// ALL-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
 func.func @vector_multi_reduction_flattening(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+    // ALL: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+    // ALL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
     %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+    // ALL: return %[[RESULT]]
     return %0 : f32
 }
 
-// CHECK-LABEL: func @vector_multi_reduction_flattening
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
-//       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
-//       CHECK:   %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
-//       CHECK:   return %[[RESULT]]
-
-// CHECK-LABEL: func @vector_multi_reduction_parallel_dim
-// CHECK-SAME:    %[[INPUT:.+]]: vector<2x3x4xi32>
-// CHECK-SAME:    %[[ACC:.+]]: vector<2xi32>
-func.func @vector_multi_reduction_parallel_dim(%arg0: vector<2x3x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
-    // CHECK: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4xi32> to vector<2x12xi32>
-    // CHECK: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [1]
+// INNER_REDUCTION-LABEL: func @vector_multi_reduction_parallel_dim_innerreduction
+// INNER_REDUCTION-SAME:    %[[INPUT:.+]]: vector<2x3x4xi32>
+// INNER_REDUCTION-SAME:    %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim_innerreduction(%arg0: vector<2x3x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+    // INNER_REDUCTION: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4xi32> to vector<2x12xi32>
+    // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [1]
     %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4xi32> to vector<2xi32>
-    // CHECK: return %[[RESULT]]
+    // INNER_REDUCTION: return %[[RESULT]]
+    return %0 : vector<2xi32>
+}
+
+// INNER_REDUCTION-LABEL: func @output_shapecast_multiple_parallel
+// INNER_REDUCTION-SAME:    %[[INPUT:.+]]: vector<2x3x4x5x6xi32>
+// INNER_REDUCTION-SAME:    %[[ACC:.+]]: vector<2x3x4xi32>
+func.func @output_shapecast_multiple_parallel(%arg0: vector<2x3x4x5x6xi32>, %acc: vector<2x3x4xi32>) -> vector<2x3x4xi32> {
+    // INNER_REDUCTION: %[[INPUT_CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5x6xi32> to vector<24x30xi32>
+    // INNER_REDUCTION: %[[ACC_CAST:.+]] = vector.shape_cast %[[ACC]] : vector<2x3x4xi32> to vector<24xi32>
+    // INNER_REDUCTION: %[[RESULT_FLAT:.+]] = vector.multi_reduction <mul>, %[[INPUT_CAST]], %[[ACC_CAST]] [1]
+    // INNER_REDUCTION: %[[RESULT:.+]] = vector.shape_cast %[[RESULT_FLAT]] : vector<24xi32> to vector<2x3x4xi32>
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [3, 4] : vector<2x3x4x5x6xi32> to vector<2x3x4xi32>
+    // INNER_REDUCTION: return %[[RESULT]]
+    return %0 : vector<2x3x4xi32>
+}
+
+// INNER_PARALLEL-LABEL: func @vector_multi_reduction_parallel_dim_innerparallel
+// INNER_PARALLEL-SAME:    %[[INPUT:.+]]: vector<3x4x2xi32>
+// INNER_PARALLEL-SAME:    %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim_innerparallel(%arg0: vector<3x4x2xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+    // INNER_PARALLEL: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<3x4x2xi32> to vector<12x2xi32>
+    // INNER_PARALLEL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<3x4x2xi32> to vector<2xi32>
+    // INNER_PARALLEL: return %[[RESULT]]
     return %0 : vector<2xi32>
 }
 
+// ALL-LABEL: func @single_scalable_dim
+// ALL-SAME:    %[[INPUT:.+]]: vector<4x[8]xf32>
+// ALL-SAME:    %[[ACC:.+]]: f32
+func.func @single_scalable_dim(%arg0: vector<4x[8]xf32>, %acc: f32) -> f32 {
+    // ALL: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<4x[8]xf32> to vector<[32]xf32>
+    // ALL: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [0]
+    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<4x[8]xf32> to f32
+    // ALL: return %[[RESULT]]
+    return %0 : f32
+}
+
+// ALL-LABEL: func @masked_multi_reduction
+// ALL-SAME:    %[[INPUT:.+]]: vector<2x4xf32>
+// ALL-SAME:    %[[ACC:.+]]: f32
+// ALL-SAME:    %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: f32, %mask: vector<2x4xi1>) -> f32 {
+    // ALL: %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<2x4xi1> to vector<8xi1>
+    // ALL: %[[CASTED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+    // ALL: %[[RESULT:.+]] = vector.mask %[[CASTED_MASK]]
+    // ALL:   vector.multi_reduction <mul>, %[[CASTED_INPUT]], %[[ACC]] [0]
+    %0 = vector.mask %mask {
+      vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+    } : vector<2x4xi1> -> f32
+    // ALL: return %[[RESULT]]
+    return %0 : f32
+}
+
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+  transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
     transform.apply_patterns to %func_op {
       transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
     transform.yield
   }
+
+  transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerparallel"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
 }

>From e943cec57e6683070146638627fc721c280fbef4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 18 Feb 2026 10:16:06 -0500
Subject: [PATCH 5/6] Fix rebase

---
 .../Vector/TransformOps/VectorTransformOps.td   | 17 +++++++++++++++++
 .../Vector/TransformOps/VectorTransformOps.cpp  |  2 +-
 .../vector-multi-reduction-flattening.mlir      |  4 ++--
 .../python/dialects/transform_vector_ext.py     |  8 ++++----
 4 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 462c61df72108..6eb96e2a8fdab 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -266,6 +266,23 @@ def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
   }];
 }
 
+def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
+    "apply_patterns.vector.multi_reduction_flattening",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector multi_reduction operations should be flattened to
+    2-D, or to 1-D when every dimension is reduced.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+
+  let assemblyFormat = [{
+    (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+  }];
+}
+
 def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_outerproduct",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 5247140a48e24..f3529ac26523f 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -146,7 +146,7 @@ void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
-void transform::ApplyLowerMultiReductionFlatteningPatternsOp::populatePatterns(
+void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
index 55914cff94966..b8f970912909b 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -107,7 +107,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
     transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
     transform.yield
   }
@@ -115,7 +115,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
     transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.lower_multi_reduction_flattening lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
     } : !transform.op<"func.func">
     transform.yield
   }
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 4c64a4d08a57a..29ce8ba63cd53 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -108,11 +108,11 @@ def enum_configurable_patterns():
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
     )
 
-    # CHECK: transform.apply_patterns.vector.lower_multi_reduction_flattening
-    vector.ApplyLowerMultiReductionFlatteningPatternsOp()
-    # CHECK: transform.apply_patterns.vector.lower_multi_reduction_flattening
+    # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
+    vector.ApplyMultiReductionFlatteningPatternsOp()
+    # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
     # CHECK-SAME: lowering_strategy = innerreduction
-    vector.ApplyLowerMultiReductionFlatteningPatternsOp(
+    vector.ApplyMultiReductionFlatteningPatternsOp(
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
     )
 

>From 4f92de4fcd2900ebef20dc6c77dcfaf2149f89d9 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 18 Feb 2026 10:25:23 -0500
Subject: [PATCH 6/6] Style

---
 .../Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp    | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index f9960da889788..fec04c967c9e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -287,7 +287,8 @@ class ReduceMultiDimReductionRank
       return success();
     }
 
-    // 8. Shape cast the flattened result back to the original n-D parallel shape.
+    // 8. Shape cast the flattened result back to the original n-D parallel
+    // shape.
     VectorType outputCastedType = VectorType::get(
         parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
         parallelScalableDims);



More information about the Mlir-commits mailing list