[Mlir-commits] [mlir] [mlir][vector] (PR #176036)

Erick Ochoa Lopez llvmlistbot at llvm.org
Wed Jan 14 13:12:08 PST 2026


https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/176036

* Adds MultiDimMultiReductionToElementWise, which matches vector.multi_reduction ops whose reduction dimensions are the outermost ones. It rewrites them into a series of extractions of the innermost (parallel) dimensions followed by elementwise operations. (This differs from the current unrolling of vector.multi_reduction, which flattens the vector in the pattern ReduceMultiDimReductionRank.)

>From 2b08d3259b3f8010210499aecb7aede7e326f6e0 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 11:36:25 -0500
Subject: [PATCH 1/5] [mlir][vector] Add
 populateVectorInnerOuterDimReductionConversionPatterns

This method will be used for testing and for unrolling
vector.multi_reduction in a rank-reducing way that does not involve
flattening.
---
 .../mlir/Dialect/Vector/Transforms/LoweringPatterns.h      | 3 +++
 .../Vector/Transforms/LowerVectorMultiReduction.cpp        | 7 +++++++
 2 files changed, 10 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..658365b97d721 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -328,6 +328,9 @@ void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
 void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 100);
 
+void populateVectorInnerOuterDimReductionConversionPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e86e2a97038db..b5660efc8f4cf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -511,6 +511,13 @@ struct LowerVectorMultiReductionPass
 
 } // namespace
 
+void mlir::vector::populateVectorInnerOuterDimReductionConversionPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
+                                                 benefit);
+}
+
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {

>From d9035915bd6102d6fe2fd21897c37d21d30f31c8 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 11:46:50 -0500
Subject: [PATCH 2/5] [mlir][vector] use TD Op for testing multi_reduction

* Test vector.multi_reduction's patterns using the transform dialect.
  This will be useful as we will be creating new patterns to unroll
  vector.multi_reduction without flattening.
---
 .../Vector/TransformOps/VectorTransformOps.td | 15 +++++++++++++++
 .../TransformOps/VectorTransformOps.cpp       |  8 ++++++++
 .../Vector/td/inner-outer-dim-conversion.mlir | 14 ++++++++++++++
 ...-inner-outer-dim-reduction-conversion.mlir | 19 +++++++++++++++++++
 4 files changed, 56 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
 create mode 100644 mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 03d25505dc65c..e2423593075f6 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -539,4 +539,19 @@ def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyInnerOuterDimReductionConversionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.inner_outer_dim_reduction_conversion",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Converts vector.multi_reduction into inner-most/outer-most reduction form
+    by using vector.transpose.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 7faa222a9e574..53284c5873841 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -227,6 +227,14 @@ void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
   vector::populateSinkVectorMemOpsPatterns(patterns);
 }
 
+void transform::ApplyInnerOuterDimReductionConversionPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorInnerOuterDimReductionConversionPatterns(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir b/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
new file mode 100644
index 0000000000000..e965e414ddeae
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
@@ -0,0 +1,14 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @inner_outer_dim_reduction_conversion(%module_op: !transform.any_op {transform.readonly}) {
+
+    %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func_op {
+      // Test patterns
+      transform.apply_patterns.vector.inner_outer_dim_reduction_conversion
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
new file mode 100644
index 0000000000000..bc99e64e2f909
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/inner-outer-dim-conversion.mlir' \
+// RUN: -transform-interpreter=entry-point=inner_outer_dim_reduction_conversion | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test InnerOuterDimReductionConversion
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @inner_outer_dim_reduction_conversion(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x3x5x7xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @inner_outer_dim_reduction_conversion(%arg0: vector<2x3x5x7xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+  // CHECK: %[[TRANSPOSE:.+]] = vector.transpose %[[ARG0]], [0, 3, 1, 2] : vector<2x3x5x7xf32> to vector<2x7x3x5xf32>
+  // CHECK: %[[RES:.+]] = vector.multi_reduction <add>, %[[TRANSPOSE]], %[[ACC]] [0, 1] : vector<2x7x3x5xf32> to vector<3x5xf32>
+  %1 = vector.multi_reduction <add>, %arg0, %acc [0, 3] : vector<2x3x5x7xf32> to vector<3x5xf32>
+
+  // CHECK: return %[[RES]]
+  return %1 : vector<3x5xf32>
+}
+

>From 46744358a4a746e4dd3e6ce2371d8c27b8f16f47 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 14:53:25 -0500
Subject: [PATCH 3/5] [mlir][vector] Add unrolling pattern for
 vector.multi_reduction

---
 .../Vector/Transforms/LoweringPatterns.h      |  4 +
 .../Transforms/LowerVectorMultiReduction.cpp  | 92 +++++++++++++++++++
 ...-inner-outer-dim-reduction-conversion.mlir |  1 -
 3 files changed, 96 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 658365b97d721..f1a57cbd43360 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -331,6 +331,10 @@ void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
 void populateVectorInnerOuterDimReductionConversionPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
+
+void populateVectorMultiDimMultiReductionToElementWisePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b5660efc8f4cf..4644f8973c481 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -300,6 +300,93 @@ class ReduceMultiDimReductionRank
   const bool useInnerDimsForReduction;
 };
 
+/// Unrolls a vector.multi_reduction whose reduced dims are the outermost ones
+/// into vector.extract ops whose results are folded into the accumulator.
+struct MultiDimMultiReductionToElementWise
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult match(const ArrayRef<int64_t> reductionDims,
+                      uint64_t rank) const {
+    // ["reduce"+, "parallel"*] or bail.
+    // That means we should match the following template:
+    //
+    // ```mlir
+    // %0 = vector.multi_reduction <add>, %source, %acc [0, ..., N] :
+    //      vector<z... x o x n x...x b x a xT> to vector<m...xbxaxT>
+    // ```
+    //
+    // This means that we can compare the set of integers from 0 to N
+    // with the reduction dimensions to see if the operation should match.
+    std::set<int64_t> reductionDimsSet(std::begin(reductionDims),
+                                       std::end(reductionDims));
+
+    SmallVector<int64_t> expectedReductionDims(reductionDims.size());
+    std::iota(std::begin(expectedReductionDims),
+              std::end(expectedReductionDims), 0);
+
+    std::set<int64_t> expectedReductionDimsSet(
+        std::begin(expectedReductionDims), std::end(expectedReductionDims));
+    bool equal = reductionDimsSet == expectedReductionDimsSet;
+    return equal ? success() : failure();
+  }
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected element type to be integer or float.");
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    uint64_t srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (failed(match(reductionDims, srcRank)))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected outermost dimensions to be reduced.");
+
+    // ```mlir
+    // %0 = vector.multi_reduction <add>, %source, %acc [0, ..., N] :
+    //      vector<z... x o x n x...x b x a xT> to vector<m...xbxaxT>
+    // ```
+    // to
+    // ```mlir
+    // %tmp0 = vector.extract %source[...] : vector<m...xbxaxT> from
+    // vector<zx...axT>
+    // ```
+    Value source = multiReductionOp.getSource();
+
+    // srcShape = zx...xa
+    // targetShape = mx...xa
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    ArrayRef<int64_t> targetShape = srcShape.drop_front(reductionDims.size());
+
+    Location loc = multiReductionOp.getLoc();
+    SmallVector<Value> innermostVectors;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(srcShape, targetShape)) {
+      // Offsets are valid for z...n.
+      // But trailing zeros that correspond to dimensions m..a must be stripped.
+      int64_t toDrop = srcRank - reductionDims.size();
+      ArrayRef<int64_t> validOffsets =
+          ArrayRef<int64_t>(offsets).drop_back(toDrop);
+      innermostVectors.push_back(
+          vector::ExtractOp::create(rewriter, loc, source, validOffsets));
+    }
+
+    Value result = multiReductionOp.getAcc();
+    for (Value innermostVector : innermostVectors) {
+      Value extractMask = nullptr;
+      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+                                  innermostVector, result, /*fastmath=*/nullptr,
+                                  extractMask);
+    }
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
 /// Unrolls vector.multi_reduction with outermost reductions
 /// and combines results
 struct TwoDimMultiReductionToElementWise
@@ -518,6 +605,14 @@ void mlir::vector::populateVectorInnerOuterDimReductionConversionPatterns(
                                                  benefit);
 }
 
+/// Adds MultiDimMultiReductionToElementWise to `patterns`, forwarding
+/// `benefit` so that callers can control the pattern's priority.
+void mlir::vector::populateVectorMultiDimMultiReductionToElementWisePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<MultiDimMultiReductionToElementWise>(patterns.getContext(),
+                                                    benefit);
+}
+
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
index bc99e64e2f909..9c508eb3d53c0 100644
--- a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
+++ b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
@@ -16,4 +16,3 @@ func.func @inner_outer_dim_reduction_conversion(%arg0: vector<2x3x5x7xf32>, %acc
   // CHECK: return %[[RES]]
   return %1 : vector<3x5xf32>
 }
-

>From 652a397dd0be61c8a693fbf2c4a036b52bee4e79 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 15:14:16 -0500
Subject: [PATCH 4/5] [mlir][vector] Test for multi-dim multi_reduction.

---
 .../Vector/TransformOps/VectorTransformOps.td | 11 ++++++++
 .../TransformOps/VectorTransformOps.cpp       |  5 ++++
 ...multi-dim-multi-reduction-elementwise.mlir | 14 ++++++++++
 ...ti-dim-multi-reduction-to-elementwise.mlir | 27 +++++++++++++++++++
 4 files changed, 57 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/td/multi-dim-multi-reduction-elementwise.mlir
 create mode 100644 mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index e2423593075f6..946638804b154 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -554,4 +554,15 @@ def ApplyInnerOuterDimReductionConversionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyMultiDimMultiReductionToElementWisePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.multi_dim_multi_reduction_to_elementwise",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Unrolls vector.multi_reduction with outer-most reduction to elementwise
+    operations.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 53284c5873841..b8856fc199bbd 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -235,6 +235,11 @@ void transform::ApplyInnerOuterDimReductionConversionPatternsOp::
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
+void transform::ApplyMultiDimMultiReductionToElementWisePatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  vector::populateVectorMultiDimMultiReductionToElementWisePatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/td/multi-dim-multi-reduction-elementwise.mlir b/mlir/test/Dialect/Vector/td/multi-dim-multi-reduction-elementwise.mlir
new file mode 100644
index 0000000000000..1d6c9768ce4ee
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/multi-dim-multi-reduction-elementwise.mlir
@@ -0,0 +1,14 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @multi_dim_multi_reduction_to_elementwise(%module_op: !transform.any_op {transform.readonly}) {
+
+    %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func_op {
+      // Test patterns
+      transform.apply_patterns.vector.multi_dim_multi_reduction_to_elementwise
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir b/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
new file mode 100644
index 0000000000000..923cb7b0e4abe
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/multi-dim-multi-reduction-elementwise.mlir' \
+// RUN: -transform-interpreter=entry-point=multi_dim_multi_reduction_to_elementwise | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test MultiDimMultiReductionToElementWise
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @multi_reduction_multi_dimension_elementwise_unroll(
+// CHECK: %[[ARG:.+]]: vector<2x3x5xf32>,
+// CHECK: %[[ACC:.+]]: vector<5xf32>
+func.func @multi_reduction_multi_dimension_elementwise_unroll(%arg0: vector<2x3x5xf32>, %acc: vector<5xf32>) -> (vector<5xf32>) {
+  // CHECK: %[[VEC_0_0:.+]] = vector.extract %[[ARG]][0, 0] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK: %[[VEC_0_1:.+]] = vector.extract %[[ARG]][0, 1] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK: %[[VEC_0_2:.+]] = vector.extract %[[ARG]][0, 2] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK: %[[VEC_1_0:.+]] = vector.extract %[[ARG]][1, 0] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK: %[[VEC_1_1:.+]] = vector.extract %[[ARG]][1, 1] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK: %[[VEC_1_2:.+]] = vector.extract %[[ARG]][1, 2] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0_0]], %[[ACC]] : vector<5xf32>
+  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_0_1]], %[[RES_0]] : vector<5xf32>
+  // CHECK: %[[RES_2:.+]] = arith.addf %[[VEC_0_2]], %[[RES_1]] : vector<5xf32>
+  // CHECK: %[[RES_3:.+]] = arith.addf %[[VEC_1_0]], %[[RES_2]] : vector<5xf32>
+  // CHECK: %[[RES_4:.+]] = arith.addf %[[VEC_1_1]], %[[RES_3]] : vector<5xf32>
+  // CHECK: %[[RES_5:.+]] = arith.addf %[[VEC_1_2]], %[[RES_4]] : vector<5xf32>
+  %1 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3x5xf32> to vector<5xf32>
+  // CHECK: return %[[RES_5]]
+  return %1 : vector<5xf32>
+}

>From 3666c5a598d312a79dc316cab17b75680913df8d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 16:04:03 -0500
Subject: [PATCH 5/5] Add support for masks

---
 .../Transforms/LowerVectorMultiReduction.cpp  | 26 ++++++++--
 ...ti-dim-multi-reduction-to-elementwise.mlir | 49 +++++++++++++++++--
 2 files changed, 69 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 4644f8973c481..61f1b5bd7f35f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -363,7 +363,23 @@ struct MultiDimMultiReductionToElementWise
     ArrayRef<int64_t> targetShape = srcShape.drop_front(reductionDims.size());
 
     Location loc = multiReductionOp.getLoc();
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+
+    Value mask;
+    Operation *rootOp;
+    bool isMasked = maskableOp.isMasked();
+    if (isMasked) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     SmallVector<Value> innermostVectors;
+    SmallVector<Value> masks;
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(srcShape, targetShape)) {
       // Offsets are valid for z...n.
@@ -373,16 +389,20 @@ struct MultiDimMultiReductionToElementWise
           ArrayRef<int64_t>(offsets).drop_back(toDrop);
       innermostVectors.push_back(
           vector::ExtractOp::create(rewriter, loc, source, validOffsets));
+
+      if (isMasked)
+        masks.push_back(
+            vector::ExtractOp::create(rewriter, loc, mask, validOffsets));
     }
 
     Value result = multiReductionOp.getAcc();
-    for (Value innermostVector : innermostVectors) {
-      Value extractMask = nullptr;
+    for (auto [idx, innermostVector] : llvm::enumerate(innermostVectors)) {
+      Value extractMask = isMasked ? masks[idx] : nullptr;
       result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                   innermostVector, result, /*fastmath=*/nullptr,
                                   extractMask);
     }
-    rewriter.replaceOp(multiReductionOp, result);
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir b/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
index 923cb7b0e4abe..33598c8709c50 100644
--- a/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/multi-dim-multi-reduction-elementwise.mlir' \
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/multi-dim-multi-reduction-elementwise.mlir' \
 // RUN: -transform-interpreter=entry-point=multi_dim_multi_reduction_to_elementwise | FileCheck %s
 
 //===----------------------------------------------------------------------===//
@@ -6,8 +6,8 @@
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func @multi_reduction_multi_dimension_elementwise_unroll(
-// CHECK: %[[ARG:.+]]: vector<2x3x5xf32>,
-// CHECK: %[[ACC:.+]]: vector<5xf32>
+// CHECK-SAME: %[[ARG:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<5xf32>
 func.func @multi_reduction_multi_dimension_elementwise_unroll(%arg0: vector<2x3x5xf32>, %acc: vector<5xf32>) -> (vector<5xf32>) {
   // CHECK: %[[VEC_0_0:.+]] = vector.extract %[[ARG]][0, 0] : vector<5xf32> from vector<2x3x5xf32>
   // CHECK: %[[VEC_0_1:.+]] = vector.extract %[[ARG]][0, 1] : vector<5xf32> from vector<2x3x5xf32>
@@ -25,3 +25,46 @@ func.func @multi_reduction_multi_dimension_elementwise_unroll(%arg0: vector<2x3x
   // CHECK: return %[[RES_5]]
   return %1 : vector<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @multi_reduction_masked_multi_dim(
+// CHECK-SAME: %[[ARG:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<5xf32>
+func.func @multi_reduction_masked_multi_dim(%arg0: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<5xf32>) -> (vector<5xf32>) {
+
+  // CHECK-DAG: %[[VEC_0_0:.+]] = vector.extract %[[ARG]][0, 0] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_0_1:.+]] = vector.extract %[[ARG]][0, 1] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_0_2:.+]] = vector.extract %[[ARG]][0, 2] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1_0:.+]] = vector.extract %[[ARG]][1, 0] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1_1:.+]] = vector.extract %[[ARG]][1, 1] : vector<5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1_2:.+]] = vector.extract %[[ARG]][1, 2] : vector<5xf32> from vector<2x3x5xf32>
+
+  // CHECK-DAG: %[[MASK_0_0:.+]] = vector.extract %[[MASK]][0, 0] : vector<5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_0_1:.+]] = vector.extract %[[MASK]][0, 1] : vector<5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_0_2:.+]] = vector.extract %[[MASK]][0, 2] : vector<5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1_0:.+]] = vector.extract %[[MASK]][1, 0] : vector<5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1_1:.+]] = vector.extract %[[MASK]][1, 1] : vector<5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1_2:.+]] = vector.extract %[[MASK]][1, 2] : vector<5xi1> from vector<2x3x5xi1>
+
+  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0_0]], %[[ACC]] : vector<5xf32>
+  // CHECK: %[[MASKED_RES_0:.+]] = arith.select %[[MASK_0_0]], %[[RES_0]], %[[ACC]]
+  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_0_1]], %[[MASKED_RES_0]] : vector<5xf32>
+  // CHECK: %[[MASKED_RES_1:.+]] = arith.select %[[MASK_0_1]], %[[RES_1]], %[[MASKED_RES_0]]
+  // CHECK: %[[RES_2:.+]] = arith.addf %[[VEC_0_2]], %[[MASKED_RES_1]] : vector<5xf32>
+  // CHECK: %[[MASKED_RES_2:.+]] = arith.select %[[MASK_0_2]], %[[RES_2]], %[[MASKED_RES_1]]
+  // CHECK: %[[RES_3:.+]] = arith.addf %[[VEC_1_0]], %[[MASKED_RES_2]] : vector<5xf32>
+  // CHECK: %[[MASKED_RES_3:.+]] = arith.select %[[MASK_1_0]], %[[RES_3]], %[[MASKED_RES_2]]
+  // CHECK: %[[RES_4:.+]] = arith.addf %[[VEC_1_1]], %[[MASKED_RES_3]] : vector<5xf32>
+  // CHECK: %[[MASKED_RES_4:.+]] = arith.select %[[MASK_1_1]], %[[RES_4]], %[[MASKED_RES_3]]
+  // CHECK: %[[RES_5:.+]] = arith.addf %[[VEC_1_2]], %[[MASKED_RES_4]] : vector<5xf32>
+  // CHECK: %[[MASKED_RES_5:.+]] = arith.select %[[MASK_1_2]], %[[RES_5]], %[[MASKED_RES_4]]
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3x5xf32> to vector<5xf32>
+  } : vector<2x3x5xi1> -> vector<5xf32>
+
+  // CHECK: return %[[MASKED_RES_5]]
+  return %0 : vector<5xf32>
+}



More information about the Mlir-commits mailing list