[Mlir-commits] [mlir] [mlir][vector] Add elementwise unrolling for vector.multi_reduction (PR #176036)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Thu Jan 15 13:51:17 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/176036
>From 55ea1bf618d9d57b83b8afc1760a94776d54305b Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 11:36:25 -0500
Subject: [PATCH 1/6] [mlir][vector] Add
populateVectorInnerOuterDimReductionConversionPatterns
This function will be used for testing and for unrolling
vector.multi_reduction in a rank-reducing way that does not involve
flattening.
---
.../mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 3 +++
.../Vector/Transforms/LowerVectorMultiReduction.cpp | 7 +++++++
2 files changed, 10 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..658365b97d721 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -328,6 +328,9 @@ void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
PatternBenefit benefit = 100);
+/// Populates `patterns` with the InnerOuterDimReductionConversion pattern,
+/// which uses vector.transpose to move the reduced dimensions of a
+/// vector.multi_reduction to the inner-most or outer-most positions, as
+/// selected by `options`.
+void populateVectorInnerOuterDimReductionConversionPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e86e2a97038db..b5660efc8f4cf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -511,6 +511,13 @@ struct LowerVectorMultiReductionPass
} // namespace
+void mlir::vector::populateVectorInnerOuterDimReductionConversionPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
+ // Only the transpose-based conversion is registered here (no flattening or
+ // final lowering), so it can be tested and reused on its own.
+ patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
+ benefit);
+}
+
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
>From 933c1a2f130f038b2c7f912eef7369cb99a9b304 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 11:46:50 -0500
Subject: [PATCH 2/6] [mlir][vector] use TD Op for testing multi_reduction
* Test vector.multi_reduction's patterns using the transform dialect.
This will be useful because new patterns will be added to unroll
vector.multi_reduction without flattening.
---
.../Vector/TransformOps/VectorTransformOps.td | 15 +++++++++++++++
.../TransformOps/VectorTransformOps.cpp | 8 ++++++++
.../Vector/td/inner-outer-dim-conversion.mlir | 14 ++++++++++++++
...-inner-outer-dim-reduction-conversion.mlir | 19 +++++++++++++++++++
4 files changed, 56 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
create mode 100644 mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 03d25505dc65c..e2423593075f6 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -539,4 +539,19 @@ def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyInnerOuterDimReductionConversionPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.inner_outer_dim_reduction_conversion",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collects patterns that rewrite vector.multi_reduction ops with a
+ vector.transpose so that all reduced dimensions become contiguous at the
+ inner-most or outer-most end of the source vector.
+
+ The `lowering_strategy` attribute selects the target layout. With the
+ default, `InnerParallel`, the reduced dimensions are moved to the
+ outer-most positions and the parallel dimensions are left inner-most.
+ }];
+
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 7faa222a9e574..53284c5873841 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -227,6 +227,14 @@ void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
vector::populateSinkVectorMemOpsPatterns(patterns);
}
+void transform::ApplyInnerOuterDimReductionConversionPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ // Forward the op's `lowering_strategy` attribute to the pattern through
+ // VectorTransformsOptions.
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorInnerOuterDimReductionConversionPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir b/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
new file mode 100644
index 0000000000000..e965e414ddeae
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
@@ -0,0 +1,14 @@
+module attributes {transform.with_named_sequence} {
+ // Entry point selected by the RUN line of
+ // vector-inner-outer-dim-reduction-conversion.mlir
+ // (-transform-interpreter=entry-point=inner_outer_dim_reduction_conversion).
+ transform.named_sequence @inner_outer_dim_reduction_conversion(%module_op: !transform.any_op {transform.readonly}) {
+
+ %func_op = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ // Apply only the inner/outer-dim reduction conversion (default lowering
+ // strategy) so the test observes that rewrite in isolation.
+ transform.apply_patterns.vector.inner_outer_dim_reduction_conversion
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
new file mode 100644
index 0000000000000..bc99e64e2f909
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/inner-outer-dim-conversion.mlir' \
+// RUN: -transform-interpreter=entry-point=inner_outer_dim_reduction_conversion | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test InnerOuterDimReductionConversion
+//===----------------------------------------------------------------------===//
+
+// Reducing dims [0, 3] is rewritten as a transpose that makes the reduced
+// dims the outer-most ones, followed by a reduction over dims [0, 1].
+// CHECK-LABEL: func @inner_outer_dim_reduction_conversion(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x3x5x7xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @inner_outer_dim_reduction_conversion(%arg0: vector<2x3x5x7xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+ // CHECK: %[[TRANSPOSE:.+]] = vector.transpose %[[ARG0]], [0, 3, 1, 2] : vector<2x3x5x7xf32> to vector<2x7x3x5xf32>
+ // CHECK: %[[RES:.+]] = vector.multi_reduction <add>, %[[TRANSPOSE]], %[[ACC]] [0, 1] : vector<2x7x3x5xf32> to vector<3x5xf32>
+ %1 = vector.multi_reduction <add>, %arg0, %acc [0, 3] : vector<2x3x5x7xf32> to vector<3x5xf32>
+
+ // CHECK: return %[[RES]]
+ return %1 : vector<3x5xf32>
+}
+
>From 29e4f621c127f87400cba3d1f7076eeb63dd56a4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 14:53:25 -0500
Subject: [PATCH 3/6] [mlir][vector] Add unrolling pattern for
vector.multi_reduction
---
.../Vector/Transforms/LoweringPatterns.h | 4 +
.../Transforms/LowerVectorMultiReduction.cpp | 92 +++++++++++++++++++
...-inner-outer-dim-reduction-conversion.mlir | 1 -
3 files changed, 96 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 658365b97d721..f1a57cbd43360 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -331,6 +331,10 @@ void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
void populateVectorInnerOuterDimReductionConversionPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
+
+/// Populates `patterns` with a pattern that unrolls vector.multi_reduction ops
+/// whose reduced dimensions are exactly the leading (outer-most) dimensions of
+/// the source. Each trailing slice is extracted with vector.extract and
+/// combined into the accumulator with elementwise ops, without flattening.
+void populateVectorMultiDimMultiReductionToElementWisePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b5660efc8f4cf..4644f8973c481 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -300,6 +300,93 @@ class ReduceMultiDimReductionRank
const bool useInnerDimsForReduction;
};
+/// Unrolls a vector.multi_reduction whose reduced dimensions are exactly the
+/// leading (outer-most) dimensions of the source vector. Each trailing slice
+/// (spanning the parallel dimensions) is extracted with vector.extract and
+/// folded into the accumulator with one elementwise combining op per slice,
+/// so neither flattening nor transposition is needed.
+struct MultiDimMultiReductionToElementWise
+ : public OpRewritePattern<vector::MultiDimReductionOp> {
+ using Base::Base;
+
+ /// Returns success iff `reductionDims`, viewed as a set, is exactly
+ /// {0, 1, ..., k-1} with k = reductionDims.size(), i.e. the op has the
+ /// dimension layout ["reduce"*, "parallel"*]:
+ ///
+ /// ```mlir
+ /// %0 = vector.multi_reduction <add>, %source, %acc [0, ..., k-1] :
+ /// vector<r0 x ... x r(k-1) x p0 x ... x pm x T> to vector<p0 x ... x pm x T>
+ /// ```
+ ///
+ /// The order in which dimensions are listed is irrelevant; a repeated or
+ /// non-leading dimension makes the check fail. An empty list matches.
+ static LogicalResult
+ matchLeadingReductionDims(ArrayRef<int64_t> reductionDims) {
+ SmallVector<int64_t> sortedDims(reductionDims.begin(), reductionDims.end());
+ llvm::sort(sortedDims);
+ // After sorting, the i-th dimension must be exactly i.
+ for (auto [idx, dim] : llvm::enumerate(sortedDims))
+ if (dim != static_cast<int64_t>(idx))
+ return failure();
+ return success();
+ }
+
+ LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+ PatternRewriter &rewriter) const override {
+
+ Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+ if (!elementType.isIntOrIndexOrFloat())
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected element type to be integer or float.");
+
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+ uint64_t srcRank = multiReductionOp.getSourceVectorType().getRank();
+ if (failed(matchLeadingReductionDims(reductionDims)))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected outermost dimensions to be reduced.");
+
+ // ```mlir
+ // %0 = vector.multi_reduction <add>, %source, %acc [0, ..., N] :
+ // vector<z... x o x n x...x b x a xT> to vector<m...xbxaxT>
+ // ```
+ // to
+ // ```mlir
+ // %tmp0 = vector.extract %source[...] : vector<m...xbxaxT> from
+ // vector<zx...axT>
+ // ```
+ Value source = multiReductionOp.getSource();
+
+ // srcShape = zx...xa
+ // targetShape = mx...xa
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ ArrayRef<int64_t> targetShape = srcShape.drop_front(reductionDims.size());
+
+ Location loc = multiReductionOp.getLoc();
+ SmallVector<Value> innermostVectors;
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(srcShape, targetShape)) {
+ // Offsets are valid for z...n.
+ // But trailing zeros that correspond to dimensions m..a must be stripped.
+ int64_t toDrop = srcRank - reductionDims.size();
+ ArrayRef<int64_t> validOffsets =
+ ArrayRef<int64_t>(offsets).drop_back(toDrop);
+ innermostVectors.push_back(
+ vector::ExtractOp::create(rewriter, loc, source, validOffsets));
+ }
+
+ Value result = multiReductionOp.getAcc();
+ for (Value innermostVector : innermostVectors) {
+ Value extractMask = nullptr;
+ result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+ innermostVector, result, /*fastmath=*/nullptr,
+ extractMask);
+ }
+ rewriter.replaceOp(multiReductionOp, result);
+ return success();
+ }
+};
+
/// Unrolls vector.multi_reduction with outermost reductions
/// and combines results
struct TwoDimMultiReductionToElementWise
@@ -518,6 +605,11 @@ void mlir::vector::populateVectorInnerOuterDimReductionConversionPatterns(
benefit);
}
+void mlir::vector::populateVectorMultiDimMultiReductionToElementWisePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ // Forward `benefit`; previously it was accepted but silently dropped, so the
+ // pattern always got the default benefit regardless of the caller's request.
+ patterns.add<MultiDimMultiReductionToElementWise>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
index bc99e64e2f909..9c508eb3d53c0 100644
--- a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
+++ b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
@@ -16,4 +16,3 @@ func.func @inner_outer_dim_reduction_conversion(%arg0: vector<2x3x5x7xf32>, %acc
// CHECK: return %[[RES]]
return %1 : vector<3x5xf32>
}
-
>From 20bfad28524cf9701a3685a0e7a16f4c7c17ebef Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 15:14:16 -0500
Subject: [PATCH 4/6] [mlir][vector] Test for multi-dim multi_reduction.
---
.../Vector/TransformOps/VectorTransformOps.td | 11 ++++++++
.../TransformOps/VectorTransformOps.cpp | 5 ++++
...multi-dim-multi-reduction-elementwise.mlir | 14 ++++++++++
...ti-dim-multi-reduction-to-elementwise.mlir | 27 +++++++++++++++++++
4 files changed, 57 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/td/multi-dim-multi-reduction-elementwise.mlir
create mode 100644 mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index e2423593075f6..946638804b154 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -554,4 +554,15 @@ def ApplyInnerOuterDimReductionConversionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyMultiDimMultiReductionToElementWisePatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.multi_dim_multi_reduction_to_elementwise",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Unrolls vector.multi_reduction ops whose reduced dimensions are exactly
+ the outer-most (leading) dimensions of the source vector into
+ vector.extract ops followed by a chain of elementwise combining ops, one
+ per extracted slice.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 53284c5873841..b8856fc199bbd 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -235,6 +235,11 @@ void transform::ApplyInnerOuterDimReductionConversionPatternsOp::
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+void transform::ApplyMultiDimMultiReductionToElementWisePatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ // The transform op carries no attributes, so the pattern is registered
+ // with the default benefit.
+ vector::populateVectorMultiDimMultiReductionToElementWisePatterns(
+ patterns, /*benefit=*/1);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/td/multi-dim-multi-reduction-elementwise.mlir b/mlir/test/Dialect/Vector/td/multi-dim-multi-reduction-elementwise.mlir
new file mode 100644
index 0000000000000..1d6c9768ce4ee
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/multi-dim-multi-reduction-elementwise.mlir
@@ -0,0 +1,14 @@
+module attributes {transform.with_named_sequence} {
+ // Entry point selected by the RUN line of
+ // vector-multi-dim-multi-reduction-to-elementwise.mlir
+ // (-transform-interpreter=entry-point=multi_dim_multi_reduction_to_elementwise).
+ transform.named_sequence @multi_dim_multi_reduction_to_elementwise(%module_op: !transform.any_op {transform.readonly}) {
+
+ %func_op = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ // Apply only the multi-dim unrolling pattern so the test observes that
+ // rewrite in isolation.
+ transform.apply_patterns.vector.multi_dim_multi_reduction_to_elementwise
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir b/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
new file mode 100644
index 0000000000000..923cb7b0e4abe
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/multi-dim-multi-reduction-elementwise.mlir' \
+// RUN: -transform-interpreter=entry-point=multi_dim_multi_reduction_to_elementwise | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test MultiDimMultiReductionToElementWise
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @multi_reduction_multi_dimension_elementwise_unroll(
+// CHECK: %[[ARG:.+]]: vector<2x3x5xf32>,
+// CHECK: %[[ACC:.+]]: vector<5xf32>
+func.func @multi_reduction_multi_dimension_elementwise_unroll(%arg0: vector<2x3x5xf32>, %acc: vector<5xf32>) -> (vector<5xf32>) {
+ // CHECK: %[[VEC_0_0:.+]] = vector.extract %[[ARG]][0, 0] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK: %[[VEC_0_1:.+]] = vector.extract %[[ARG]][0, 1] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK: %[[VEC_0_2:.+]] = vector.extract %[[ARG]][0, 2] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK: %[[VEC_1_0:.+]] = vector.extract %[[ARG]][1, 0] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK: %[[VEC_1_1:.+]] = vector.extract %[[ARG]][1, 1] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK: %[[VEC_1_2:.+]] = vector.extract %[[ARG]][1, 2] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0_0]], %[[ACC]] : vector<5xf32>
+ // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_0_1]], %[[RES_0]] : vector<5xf32>
+ // CHECK: %[[RES_2:.+]] = arith.addf %[[VEC_0_2]], %[[RES_1]] : vector<5xf32>
+ // CHECK: %[[RES_3:.+]] = arith.addf %[[VEC_1_0]], %[[RES_2]] : vector<5xf32>
+ // CHECK: %[[RES_4:.+]] = arith.addf %[[VEC_1_1]], %[[RES_3]] : vector<5xf32>
+ // CHECK: %[[RES_5:.+]] = arith.addf %[[VEC_1_2]], %[[RES_4]] : vector<5xf32>
+ // The two leading (reduced) dims are unrolled in row-major order: every
+ // vector<5xf32> slice is added into the running accumulator, starting at %acc.
+ %1 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3x5xf32> to vector<5xf32>
+ // CHECK: return %[[RES_5]]
+ return %1 : vector<5xf32>
+}
>From 5948682205ba71e34da42a90e47acc65247235f5 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 16:04:03 -0500
Subject: [PATCH 5/6] Add support for masks
---
.../Transforms/LowerVectorMultiReduction.cpp | 26 ++++++++--
...ti-dim-multi-reduction-to-elementwise.mlir | 49 +++++++++++++++++--
2 files changed, 69 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 4644f8973c481..61f1b5bd7f35f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -363,7 +363,23 @@ struct MultiDimMultiReductionToElementWise
ArrayRef<int64_t> targetShape = srcShape.drop_front(reductionDims.size());
Location loc = multiReductionOp.getLoc();
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+
+ // If the reduction is wrapped in vector.mask, build the replacement in
+ // front of the masking op and replace that op instead, so the mask region
+ // is removed together with the reduction.
+ Value mask;
+ Operation *rootOp;
+ bool isMasked = maskableOp.isMasked();
+ if (isMasked) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
SmallVector<Value> innermostVectors;
+ SmallVector<Value> masks;
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(srcShape, targetShape)) {
// Offsets are valid for z...n.
@@ -373,16 +389,20 @@ struct MultiDimMultiReductionToElementWise
ArrayRef<int64_t>(offsets).drop_back(toDrop);
innermostVectors.push_back(
vector::ExtractOp::create(rewriter, loc, source, validOffsets));
+
+ // Slice the mask with the same offsets as the source; this assumes the
+ // mask has the source vector's shape.
+ if (isMasked)
+ masks.push_back(
+ vector::ExtractOp::create(rewriter, loc, mask, validOffsets));
}
Value result = multiReductionOp.getAcc();
- for (Value innermostVector : innermostVectors) {
- Value extractMask = nullptr;
+ for (auto [idx, innermostVector] : llvm::enumerate(innermostVectors)) {
+ Value extractMask = isMasked ? masks[idx] : nullptr;
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
innermostVector, result, /*fastmath=*/nullptr,
extractMask);
}
- rewriter.replaceOp(multiReductionOp, result);
+ rewriter.replaceOp(rootOp, result);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir b/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
index 923cb7b0e4abe..33598c8709c50 100644
--- a/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-dim-multi-reduction-to-elementwise.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/multi-dim-multi-reduction-elementwise.mlir' \
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/multi-dim-multi-reduction-elementwise.mlir' \
// RUN: -transform-interpreter=entry-point=multi_dim_multi_reduction_to_elementwise | FileCheck %s
//===----------------------------------------------------------------------===//
@@ -6,8 +6,8 @@
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @multi_reduction_multi_dimension_elementwise_unroll(
-// CHECK: %[[ARG:.+]]: vector<2x3x5xf32>,
-// CHECK: %[[ACC:.+]]: vector<5xf32>
+// CHECK-SAME: %[[ARG:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<5xf32>
func.func @multi_reduction_multi_dimension_elementwise_unroll(%arg0: vector<2x3x5xf32>, %acc: vector<5xf32>) -> (vector<5xf32>) {
// CHECK: %[[VEC_0_0:.+]] = vector.extract %[[ARG]][0, 0] : vector<5xf32> from vector<2x3x5xf32>
// CHECK: %[[VEC_0_1:.+]] = vector.extract %[[ARG]][0, 1] : vector<5xf32> from vector<2x3x5xf32>
@@ -25,3 +25,46 @@ func.func @multi_reduction_multi_dimension_elementwise_unroll(%arg0: vector<2x3x
// CHECK: return %[[RES_5]]
return %1 : vector<5xf32>
}
+
+// -----
+
+// Masked variant: each source slice and its matching mask slice are
+// extracted, and every combining step keeps the previous accumulator value in
+// lanes whose mask bit is off (the arith.select after each arith.addf).
+// CHECK-LABEL: func @multi_reduction_masked_multi_dim(
+// CHECK-SAME: %[[ARG:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<5xf32>
+func.func @multi_reduction_masked_multi_dim(%arg0: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<5xf32>) -> (vector<5xf32>) {
+
+ // CHECK-DAG: %[[VEC_0_0:.+]] = vector.extract %[[ARG]][0, 0] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_0_1:.+]] = vector.extract %[[ARG]][0, 1] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_0_2:.+]] = vector.extract %[[ARG]][0, 2] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1_0:.+]] = vector.extract %[[ARG]][1, 0] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1_1:.+]] = vector.extract %[[ARG]][1, 1] : vector<5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1_2:.+]] = vector.extract %[[ARG]][1, 2] : vector<5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[MASK_0_0:.+]] = vector.extract %[[MASK]][0, 0] : vector<5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_0_1:.+]] = vector.extract %[[MASK]][0, 1] : vector<5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_0_2:.+]] = vector.extract %[[MASK]][0, 2] : vector<5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1_0:.+]] = vector.extract %[[MASK]][1, 0] : vector<5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1_1:.+]] = vector.extract %[[MASK]][1, 1] : vector<5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1_2:.+]] = vector.extract %[[MASK]][1, 2] : vector<5xi1> from vector<2x3x5xi1>
+
+ // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0_0]], %[[ACC]] : vector<5xf32>
+ // CHECK: %[[MASKED_RES_0:.+]] = arith.select %[[MASK_0_0]], %[[RES_0]], %[[ACC]]
+ // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_0_1]], %[[MASKED_RES_0]] : vector<5xf32>
+ // CHECK: %[[MASKED_RES_1:.+]] = arith.select %[[MASK_0_1]], %[[RES_1]], %[[MASKED_RES_0]]
+ // CHECK: %[[RES_2:.+]] = arith.addf %[[VEC_0_2]], %[[MASKED_RES_1]] : vector<5xf32>
+ // CHECK: %[[MASKED_RES_2:.+]] = arith.select %[[MASK_0_2]], %[[RES_2]], %[[MASKED_RES_1]]
+ // CHECK: %[[RES_3:.+]] = arith.addf %[[VEC_1_0]], %[[MASKED_RES_2]] : vector<5xf32>
+ // CHECK: %[[MASKED_RES_3:.+]] = arith.select %[[MASK_1_0]], %[[RES_3]], %[[MASKED_RES_2]]
+ // CHECK: %[[RES_4:.+]] = arith.addf %[[VEC_1_1]], %[[MASKED_RES_3]] : vector<5xf32>
+ // CHECK: %[[MASKED_RES_4:.+]] = arith.select %[[MASK_1_1]], %[[RES_4]], %[[MASKED_RES_3]]
+ // CHECK: %[[RES_5:.+]] = arith.addf %[[VEC_1_2]], %[[MASKED_RES_4]] : vector<5xf32>
+ // CHECK: %[[MASKED_RES_5:.+]] = arith.select %[[MASK_1_2]], %[[RES_5]], %[[MASKED_RES_4]]
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3x5xf32> to vector<5xf32>
+ } : vector<2x3x5xi1> -> vector<5xf32>
+
+ // CHECK: return %[[MASKED_RES_5]]
+ return %0 : vector<5xf32>
+}
>From f7551e478d921bb908f355da964786da7aeb31b2 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 16:16:03 -0500
Subject: [PATCH 6/6] Remove unnecessary comment
---
.../Vector/Transforms/LowerVectorMultiReduction.cpp | 9 ---------
1 file changed, 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 61f1b5bd7f35f..a43838cce4398 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -345,15 +345,6 @@ struct MultiDimMultiReductionToElementWise
return rewriter.notifyMatchFailure(
multiReductionOp, "expected outermost dimensions to be reduced.");
- // ```mlir
- // %0 = vector.multi_reduction <add>, %source, %acc [0, ..., N] :
- // vector<z... x o x n x...x b x a xT> to vector<m...xbxaxT>
- // ```
- // to
- // ```mlir
- // %tmp0 = vector.extract %source[...] : vector<m...xbxaxT> from
- // vector<zx...axT>
- // ```
Value source = multiReductionOp.getSource();
// srcShape = zx...xa
More information about the Mlir-commits
mailing list