[Mlir-commits] [mlir] [mlir][vector] Add elementwise unrolling for vector.multi_reduction (PR #176036)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Thu Jan 15 14:08:47 PST 2026
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/176036
>From 64119f03a8233db5345c9f355fc50d2e91ac401a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 11:36:25 -0500
Subject: [PATCH 1/3] [mlir][vector] Add
populateVectorInnerOuterDimReductionConversionPatterns
This method will be used for testing and for unrolling
vector.multi_reduction in a rank-reducing way that does not involve
flattening.
---
.../mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 3 +++
.../Vector/Transforms/LowerVectorMultiReduction.cpp | 7 +++++++
2 files changed, 10 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..658365b97d721 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -328,6 +328,12 @@ void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
PatternBenefit benefit = 100);
+/// Populate the given pattern set with a pattern that uses vector.transpose to
+/// move the reduction dimensions of vector.multi_reduction ops to the front
+/// (`InnerParallel`) or to the back (`InnerReduction`), as selected by `options`.
+void populateVectorInnerOuterDimReductionConversionPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e86e2a97038db..b5660efc8f4cf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -511,6 +511,13 @@ struct LowerVectorMultiReductionPass
} // namespace
+void mlir::vector::populateVectorInnerOuterDimReductionConversionPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<InnerOuterDimReductionConversion>(context, options, benefit);
+}
+
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
>From faa27de4bb0ec4d9d66b97e4f83492f5e72912ce Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 11:46:50 -0500
Subject: [PATCH 2/3] [mlir][vector] use TD Op for testing multi_reduction
* Test vector.multi_reduction's patterns using the transform dialect.
This will be used to test the upcoming patterns that unroll
vector.multi_reduction without flattening.
---
.../Vector/TransformOps/VectorTransformOps.td | 15 +++++++++++++++
.../TransformOps/VectorTransformOps.cpp | 8 ++++++++
.../Vector/td/inner-outer-dim-conversion.mlir | 14 ++++++++++++++
...-inner-outer-dim-reduction-conversion.mlir | 19 +++++++++++++++++++
4 files changed, 56 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
create mode 100644 mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 03d25505dc65c..e2423593075f6 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -539,4 +539,19 @@ def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyInnerOuterDimReductionConversionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.inner_outer_dim_reduction_conversion",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Converts vector.multi_reduction into inner-most/outer-most reduction form
+    by using vector.transpose. With `InnerParallel` (the default
+    `lowering_strategy`), the reduction dims are moved to the front; with
+    `InnerReduction`, they are moved to the back.
+  }];
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+    "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+  let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 7faa222a9e574..53284c5873841 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -227,6 +227,14 @@ void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
vector::populateSinkVectorMemOpsPatterns(patterns);
}
+void transform::ApplyInnerOuterDimReductionConversionPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  // `lowering_strategy` decides whether reduction dims end up outer or inner.
+  vector::VectorMultiReductionLowering strategy = getLoweringStrategy();
+  vector::populateVectorInnerOuterDimReductionConversionPatterns(patterns,
+                                                                 strategy);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir b/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
new file mode 100644
index 0000000000000..e965e414ddeae
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
@@ -0,0 +1,14 @@
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @inner_outer_dim_reduction_conversion(%module_op: !transform.any_op {transform.readonly}) {
+ // Match every func.func in the payload module and apply the patterns to it.
+ %func_op = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ // Pattern under test, with its default lowering strategy (InnerParallel).
+ transform.apply_patterns.vector.inner_outer_dim_reduction_conversion
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
new file mode 100644
index 0000000000000..bc99e64e2f909
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/inner-outer-dim-conversion.mlir' \
+// RUN: -transform-interpreter=entry-point=inner_outer_dim_reduction_conversion | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test InnerOuterDimReductionConversion
+//===----------------------------------------------------------------------===//
+// The default InnerParallel strategy moves the reduced dims [0, 3] to the front.
+// CHECK-LABEL: func @inner_outer_dim_reduction_conversion(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x3x5x7xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @inner_outer_dim_reduction_conversion(%arg0: vector<2x3x5x7xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+ // CHECK: %[[TRANSPOSE:.+]] = vector.transpose %[[ARG0]], [0, 3, 1, 2] : vector<2x3x5x7xf32> to vector<2x7x3x5xf32>
+ // CHECK: %[[RES:.+]] = vector.multi_reduction <add>, %[[TRANSPOSE]], %[[ACC]] [0, 1] : vector<2x7x3x5xf32> to vector<3x5xf32>
+ %1 = vector.multi_reduction <add>, %arg0, %acc [0, 3] : vector<2x3x5x7xf32> to vector<3x5xf32>
+
+ // CHECK: return %[[RES]]
+ return %1 : vector<3x5xf32>
+}
+
>From 8ccd08384ba16af8b8d0de02b2e31e8a7249fb78 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 14:53:25 -0500
Subject: [PATCH 3/3] [mlir][vector] Add unrolling pattern for
vector.multi_reduction
---
.../Vector/TransformOps/VectorTransformOps.td | 10 +
.../Vector/Transforms/LoweringPatterns.h | 4 +
.../TransformOps/VectorTransformOps.cpp | 5 +
.../Transforms/LowerVectorMultiReduction.cpp | 194 ++++++++++++++++++
.../Vector/td/unroll-multi-reduction.mlir | 14 ++
.../Vector/unroll-vector-multi-reduction.mlir | 92 +++++++++
...-inner-outer-dim-reduction-conversion.mlir | 1 -
7 files changed, 319 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
create mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index e2423593075f6..52205f7b1f81c 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -554,4 +554,14 @@ def ApplyInnerOuterDimReductionConversionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyUnrollMultiReductionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.unroll_multi_reduction",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Unrolls the reduced outermost dimension of vector.multi_reduction ops into
+    per-slice elementwise ops or lower-rank vector.multi_reduction ops.
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 658365b97d721..f37e45c09ebd1 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -331,6 +331,12 @@ void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
void populateVectorInnerOuterDimReductionConversionPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
+
+/// Populate the given pattern set with patterns that unroll the outermost
+/// dimension of vector.multi_reduction ops that reduce it, into slices
+/// combined elementwise or via lower-rank vector.multi_reduction ops.
+void populateVectorUnrollMultiReduction(RewritePatternSet &patterns,
+                                        PatternBenefit benefit = 1);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 53284c5873841..337f18c7d68c9 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -235,6 +235,11 @@ void transform::ApplyInnerOuterDimReductionConversionPatternsOp::
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+void transform::ApplyUnrollMultiReductionPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorUnrollMultiReduction(patterns, /*benefit=*/1);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b5660efc8f4cf..3b8fba0d08801 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -300,6 +300,194 @@ class ReduceMultiDimReductionRank
const bool useInnerDimsForReduction;
};
+/// Returns success if the outermost dimension of `op` can be unrolled: it
+/// must be a reduced, fixed-size (non-scalable) dimension, and the element
+/// type must be an integer, index or float type.
+static LogicalResult
+checkOutermostDimUnrollable(vector::MultiDimReductionOp op,
+                            PatternRewriter &rewriter) {
+  if (!op.isReducedDim(0))
+    return rewriter.notifyMatchFailure(
+        op, "expected outermost dimension to be reduced dimension.");
+
+  // A scalable dimension has a runtime size (a multiple of vscale), so it
+  // cannot be unrolled into a compile-time number of extracts.
+  if (op.getSourceVectorType().getScalableDims().front())
+    return rewriter.notifyMatchFailure(
+        op, "expected outermost dimension to not be scalable.");
+
+  Type elementType = getElementTypeOrSelf(op.getDestType());
+  if (!elementType.isIntOrIndexOrFloat())
+    return rewriter.notifyMatchFailure(
+        op, "expected integer or float element type.");
+
+  return success();
+}
+
+/// Extracts one slice per index of the outermost dimension from the source of
+/// `op` into `vectors`, and from the mask into `masks` when `op` is wrapped
+/// in vector.mask (otherwise `masks` receives null values). The new ops are
+/// inserted right before the op to be replaced, which is returned: the
+/// vector.mask op when masked, `op` itself otherwise. Callers must hold an
+/// InsertionGuard, since the rewriter's insertion point is left modified.
+static Operation *extractOutermostSlices(vector::MultiDimReductionOp op,
+                                         PatternRewriter &rewriter,
+                                         SmallVectorImpl<Value> &vectors,
+                                         SmallVectorImpl<Value> &masks) {
+  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+  Operation *rootOp = op;
+  Value mask = nullptr;
+  if (maskableOp.isMasked()) {
+    rootOp = maskableOp.getMaskingOp();
+    mask = maskableOp.getMaskingOp().getMask();
+  }
+  rewriter.setInsertionPoint(rootOp);
+
+  Location loc = op.getLoc();
+  Value source = op.getSource();
+  int64_t numSlices = op.getSourceVectorType().getShape().front();
+  for (int64_t i = 0; i < numSlices; ++i)
+    vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+  for (int64_t i = 0; i < numSlices; ++i) {
+    if (mask)
+      masks.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+    else
+      masks.push_back(Value());
+  }
+  return rootOp;
+}
+
+/// Unrolls the outermost dimension of a vector.multi_reduction.
+///
+/// These patterns only match operations that reduce the outermost dimension;
+/// operations whose outermost dimension is parallel are not transformed, nor
+/// are operations whose outermost dimension is scalable.
+///
+/// There are two cases to consider:
+/// 1. The base case is when the outermost dimension is the only reduction
+///    dimension (UnrollMultiReductionBaseCase).
+/// 2. The general case is when other dimensions are reduced as well
+///    (UnrollMultiReductionGeneralCase).
+///
+/// The base case transformation:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0] :
+///   vector<NxMx...xf32> to vector<Mx...xf32>
+/// ```
+///
+/// extracts N vectors from %src and then performs elementwise operations:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from
+///   vector<NxMx...xf32>
+///
+/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
+/// ...
+/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
+/// ```
+///
+/// For the general case, N vectors are still extracted, but each one is
+/// reduced by a lower-rank vector.multi_reduction over INNER_DIMS (the
+/// remaining reduction dims, renumbered after dropping dimension 0), and the
+/// results are chained through the accumulator:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, REDUCTION_DIMS...] :
+///   vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+/// becomes:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[N-1] : vector<Mx...xf32> from
+///   vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [INNER_DIMS...] :
+///   vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2
+///   [INNER_DIMS...] : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// When the op is wrapped in vector.mask, the mask is sliced the same way:
+/// the base case selects between each partial result and the accumulator,
+/// and the general case wraps each lower-rank reduction in vector.mask.
+struct UnrollMultiReductionBaseCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(checkOutermostDimUnrollable(multiReductionOp, rewriter)))
+      return failure();
+    if (multiReductionOp.getReductionDims().size() > 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected only one reduction dimension.");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    SmallVector<Value> vectors, masks;
+    Operation *rootOp =
+        extractOutermostSlices(multiReductionOp, rewriter, vectors, masks);
+
+    // Fold the slices into the accumulator one at a time. With a mask,
+    // makeArithReduction keeps the running value for masked-off lanes.
+    Location loc = multiReductionOp.getLoc();
+    Value result = multiReductionOp.getAcc();
+    for (auto [slice, sliceMask] : llvm::zip(vectors, masks))
+      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+                                  slice, result, /*fastmath=*/nullptr,
+                                  sliceMask);
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
+/// General case of the unrolling above: the outermost dimension is reduced
+/// together with other dimensions, so each slice is reduced by a lower-rank
+/// vector.multi_reduction that accumulates into the previous slice's result.
+struct UnrollMultiReductionGeneralCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(checkOutermostDimUnrollable(multiReductionOp, rewriter)))
+      return failure();
+    if (multiReductionOp.getReductionDims().size() <= 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected more than one reduction dimension.");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    SmallVector<Value> vectors, masks;
+    Operation *rootOp =
+        extractOutermostSlices(multiReductionOp, rewriter, vectors, masks);
+
+    // getReductionMask() returns by value: keep it alive for the ArrayRef.
+    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
+    ArrayRef<bool> innerReductionMask =
+        ArrayRef<bool>(reductionMask).drop_front();
+
+    Location loc = multiReductionOp.getLoc();
+    Value result = multiReductionOp.getAcc();
+    for (auto [slice, sliceMask] : llvm::zip(vectors, masks)) {
+      Operation *sliceOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, slice, result, innerReductionMask,
+          multiReductionOp.getKind());
+      if (sliceMask)
+        sliceOp = vector::maskOperation(rewriter, sliceOp, sliceMask);
+      result = sliceOp->getResult(0);
+    }
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
/// Unrolls vector.multi_reduction with outermost reductions
/// and combines results
struct TwoDimMultiReductionToElementWise
@@ -518,6 +706,12 @@ void mlir::vector::populateVectorInnerOuterDimReductionConversionPatterns(
benefit);
}
+void mlir::vector::populateVectorUnrollMultiReduction(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollMultiReductionBaseCase, UnrollMultiReductionGeneralCase>(
+      patterns.getContext(), benefit);
+}
+
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
new file mode 100644
index 0000000000000..89d311f3a60e5
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
@@ -0,0 +1,14 @@
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @unroll_multi_reduction(%module_op: !transform.any_op {transform.readonly}) {
+ // Match every func.func in the payload module and apply the patterns to it.
+ %func_op = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op {
+ // Patterns under test: vector.multi_reduction outermost-dim unrolling.
+ transform.apply_patterns.vector.unroll_multi_reduction
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
+
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
new file mode 100644
index 0000000000000..79086e2b0b9ad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_multi_reduction | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test populateVectorUnrollMultiReduction patterns
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction(%source: vector<2x3x5xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+ // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_0]] : vector<3x5xf32>
+ %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+
+ // CHECK: return %[[RES_1]]
+ return %1 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+ // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+ // CHECK: %[[RES_MASKED_0:.+]] = arith.select %[[MASK_0]], %[[RES_0]], %[[ACC]] : vector<3x5xi1>, vector<3x5xf32>
+
+ // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_MASKED_0]] : vector<3x5xf32>
+ // CHECK: %[[RES_MASKED_1:.+]] = arith.select %[[MASK_1]], %[[RES_1]], %[[RES_MASKED_0]] : vector<3x5xi1>, vector<3x5xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+ } : vector<2x3x5xi1> -> vector<3x5xf32>
+
+ // CHECK: return %[[RES_MASKED_1]]
+ return %0 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK: %[[RES_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32>
+ // CHECK: %[[RES_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32>
+
+ %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+
+ // CHECK: return %[[RES_1]]
+ return %1 : vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+
+ // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+ // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+ // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+ // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+ // CHECK: %[[RES_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+ // CHECK: %[[RES_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+
+ %0 = vector.mask %mask {
+ %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+ } : vector<2x3x5xi1> -> vector<3xf32>
+
+ // CHECK: return %[[RES_1]]
+ return %0 : vector<3xf32>
+}
+
+// -----
+
+// Negative test: the outermost dimension is not reduced, so nothing is unrolled.
+// CHECK-LABEL: func @negative_unroll_vector_multi_reduction_outer_parallel(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<2xf32>
+func.func @negative_unroll_vector_multi_reduction_outer_parallel(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) {
+  // CHECK-NOT: vector.extract
+  // CHECK: %[[RES:.+]] = vector.multi_reduction <add>, %[[SOURCE]], %[[ACC]] [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+  %1 = vector.multi_reduction <add>, %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32>
+
+  // CHECK: return %[[RES]]
+  return %1 : vector<2xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
index bc99e64e2f909..9c508eb3d53c0 100644
--- a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
+++ b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
@@ -16,4 +16,3 @@ func.func @inner_outer_dim_reduction_conversion(%arg0: vector<2x3x5x7xf32>, %acc
// CHECK: return %[[RES]]
return %1 : vector<3x5xf32>
}
-
More information about the Mlir-commits
mailing list