[Mlir-commits] [mlir] [mlir][vector] Replace OneDimMultiReductionToTwoDim with OneDimMultiReductionToReduction (PR #184241)

Erick Ochoa Lopez llvmlistbot at llvm.org
Wed Mar 4 07:46:13 PST 2026


================
@@ -1,15 +1,32 @@
 // RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefixes=INNER_REDUCTION,ALL
 // RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefixes=INNER_PARALLEL,ALL
 
-// ALL-LABEL: func @negative_rank1_and_rank3
-func.func @negative_rank1_and_rank3(
-    %rank1: vector<8xf32>, %rank1_acc: f32,
-    %rank3: vector<2x3x4xf32>, %rank3_acc: vector<2x3xf32>) -> (f32, vector<2x3xf32>) {
-  // ALL: vector.multi_reduction <add>, {{.+}} [0] : vector<8xf32> to f32
-  %0 = vector.multi_reduction <add>, %rank1, %rank1_acc [0] : vector<8xf32> to f32
+// ALL-LABEL: func @one_dim_reduction
+// ALL-SAME:    %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32
+func.func @one_dim_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 {
+  // ALL: %[[RESULT:.+]] = vector.reduction <add>, %[[INPUT]], %[[ACC]] : vector<8xf32> into f32
+  %0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
+  // ALL: return %[[RESULT]]
+  return %0 : f32
+}
+
+// ALL-LABEL: func @one_dim_reduction_masked
+// ALL-SAME:    %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32, %[[MASK:.+]]: vector<8xi1>
+func.func @one_dim_reduction_masked(%arg0: vector<8xf32>, %acc: f32, %mask: vector<8xi1>) -> f32 {
+  // ALL: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction <add>, %[[INPUT]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32
+  %0 = vector.mask %mask {
+    vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
+  } : vector<8xi1> -> f32
+  // ALL: return %[[RESULT]]
+  return %0 : f32
+}
+
+// ALL-LABEL: func @negative_rank3
+func.func @negative_rank3(
+    %rank3: vector<2x3x4xf32>, %rank3_acc: vector<2x3xf32>) -> vector<2x3xf32> {
   // ALL: vector.multi_reduction <add>, {{.+}} [2] : vector<2x3x4xf32> to vector<2x3xf32>
-  %1 = vector.multi_reduction <add>, %rank3, %rank3_acc [2] : vector<2x3x4xf32> to vector<2x3xf32>
-  return %0, %1 : f32, vector<2x3xf32>
+  %0 = vector.multi_reduction <add>, %rank3, %rank3_acc [2] : vector<2x3x4xf32> to vector<2x3xf32>
+  return %0 : vector<2x3xf32>
----------------
amd-eochoalo wrote:

Yes, this will actually be unrolled in the next PR.

https://github.com/llvm/llvm-project/pull/184241
