[Mlir-commits] [mlir] [mlir][vector] Add elementwise unrolling for vector.multi_reduction (PR #176036)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Jan 15 14:08:47 PST 2026


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/176036

>From 64119f03a8233db5345c9f355fc50d2e91ac401a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 11:36:25 -0500
Subject: [PATCH 1/3] [mlir][vector] Add
 populateVectorInnerOuterDimReductionConversionPatterns

This method will be used for testing and for unrolling
vector.multi_reduction in a rank-reducing way that does not involve
flattening.
---
 .../mlir/Dialect/Vector/Transforms/LoweringPatterns.h      | 3 +++
 .../Vector/Transforms/LowerVectorMultiReduction.cpp        | 7 +++++++
 2 files changed, 10 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..658365b97d721 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -328,6 +328,9 @@ void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
 void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 100);
 
+void populateVectorInnerOuterDimReductionConversionPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 1);
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index e86e2a97038db..b5660efc8f4cf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -511,6 +511,13 @@ struct LowerVectorMultiReductionPass
 
 } // namespace
 
+void mlir::vector::populateVectorInnerOuterDimReductionConversionPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
+                                                 benefit);
+}
+
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {

>From faa27de4bb0ec4d9d66b97e4f83492f5e72912ce Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 11:46:50 -0500
Subject: [PATCH 2/3] [mlir][vector] use TD Op for testing multi_reduction

* Test vector.multi_reduction's patterns using the transform dialect.
  This will be useful as we will be creating new patterns to unroll
  vector.multi_reduction without flattening.
---
 .../Vector/TransformOps/VectorTransformOps.td | 15 +++++++++++++++
 .../TransformOps/VectorTransformOps.cpp       |  8 ++++++++
 .../Vector/td/inner-outer-dim-conversion.mlir | 14 ++++++++++++++
 ...-inner-outer-dim-reduction-conversion.mlir | 19 +++++++++++++++++++
 4 files changed, 56 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
 create mode 100644 mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 03d25505dc65c..e2423593075f6 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -539,4 +539,19 @@ def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyInnerOuterDimReductionConversionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.inner_outer_dim_reduction_conversion",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Inserts a vector.transpose so vector.multi_reduction reduces only outermost
+    dims (default, InnerParallel) or only innermost dims (InnerReduction).
+  }];
+
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 7faa222a9e574..53284c5873841 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -227,6 +227,14 @@ void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
   vector::populateSinkVectorMemOpsPatterns(patterns);
 }
 
+void transform::ApplyInnerOuterDimReductionConversionPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorInnerOuterDimReductionConversionPatterns(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir b/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
new file mode 100644
index 0000000000000..e965e414ddeae
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/inner-outer-dim-conversion.mlir
@@ -0,0 +1,14 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @inner_outer_dim_reduction_conversion(%module_op: !transform.any_op {transform.readonly}) {
+    // Apply the patterns to every func.func in the payload module.
+    %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func_op {
+      // Only the InnerOuterDimReductionConversion pattern under test.
+      transform.apply_patterns.vector.inner_outer_dim_reduction_conversion
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
new file mode 100644
index 0000000000000..bc99e64e2f909
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/inner-outer-dim-conversion.mlir' \
+// RUN: -transform-interpreter=entry-point=inner_outer_dim_reduction_conversion | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test InnerOuterDimReductionConversion (default InnerParallel strategy)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @inner_outer_dim_reduction_conversion(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x3x5x7xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @inner_outer_dim_reduction_conversion(%arg0: vector<2x3x5x7xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+  // CHECK: %[[TRANSPOSE:.+]] = vector.transpose %[[ARG0]], [0, 3, 1, 2] : vector<2x3x5x7xf32> to vector<2x7x3x5xf32>
+  // CHECK: %[[RES:.+]] = vector.multi_reduction <add>, %[[TRANSPOSE]], %[[ACC]] [0, 1] : vector<2x7x3x5xf32> to vector<3x5xf32>
+  %1 = vector.multi_reduction <add>, %arg0, %acc [0, 3] : vector<2x3x5x7xf32> to vector<3x5xf32>
+  // Above: the transpose moves the reduced dims (0 and 3) to the outermost.
+  // CHECK: return %[[RES]]
+  return %1 : vector<3x5xf32>
+}
+

>From 8ccd08384ba16af8b8d0de02b2e31e8a7249fb78 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 14 Jan 2026 14:53:25 -0500
Subject: [PATCH 3/3] [mlir][vector] Add unrolling pattern for
 vector.multi_reduction

---
 .../Vector/TransformOps/VectorTransformOps.td |  10 +
 .../Vector/Transforms/LoweringPatterns.h      |   4 +
 .../TransformOps/VectorTransformOps.cpp       |   5 +
 .../Transforms/LowerVectorMultiReduction.cpp  | 194 ++++++++++++++++++
 .../Vector/td/unroll-multi-reduction.mlir     |  14 ++
 .../Vector/unroll-vector-multi-reduction.mlir |  92 +++++++++
 ...-inner-outer-dim-reduction-conversion.mlir |   1 -
 7 files changed, 319 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
 create mode 100644 mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index e2423593075f6..52205f7b1f81c 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -554,4 +554,14 @@ def ApplyInnerOuterDimReductionConversionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyUnrollMultiReductionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.unroll_multi_reduction",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Unrolls the reduced outermost dimension of vector.multi_reduction ops.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 658365b97d721..f37e45c09ebd1 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -331,6 +331,10 @@ void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
 void populateVectorInnerOuterDimReductionConversionPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
+
+/// Adds the outermost-dimension unrolling patterns for vector.multi_reduction.
+void populateVectorUnrollMultiReduction(RewritePatternSet &patterns,
+                                        PatternBenefit benefit = 1);
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 53284c5873841..337f18c7d68c9 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -235,6 +235,11 @@ void transform::ApplyInnerOuterDimReductionConversionPatternsOp::
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
+void transform::ApplyUnrollMultiReductionPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorUnrollMultiReduction(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b5660efc8f4cf..3b8fba0d08801 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -300,6 +300,194 @@ class ReduceMultiDimReductionRank
   const bool useInnerDimsForReduction;
 };
 
+/// Unrolls the outermost dimension of a vector.multi_reduction.
+/// Only operations that reduce the outermost dimension are matched.
+///
+/// There are two cases to consider:
+/// 1. The base case is when the outermost dimension is the only reduction
+/// dimension (UnrollMultiReductionBaseCase).
+/// 2. The general case is when the outermost dimension is not the only
+/// reduction dimension (UnrollMultiReductionGeneralCase).
+///
+/// The base case transformation:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0] : vector<NxMx...xf32>
+/// to vector<Mx...xf32>
+/// ```
+///
+/// extracts N vectors from %src and then performs elementwise operations:
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %res0 = arith.addf %0, %acc : vector<Mx...xf32>
+/// ...
+/// %res = arith.addf %Nminus1, %resNminus2 : vector<Mx...xf32>
+/// ```
+///
+/// For the general case, N vectors are still extracted and chained through N
+/// lower-rank vector.multi_reduction ops (INNER_DIMS = REDUCTION_DIMS - 1):
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %red0 = vector.multi_reduction <add>, %0, %acc [ [[INNER_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2
+/// [ [[INNER_DIMS]] ] : vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// When the op is wrapped in vector.mask, the mask is sliced along the
+/// outermost dimension in the same way and applied to each step.
+struct UnrollMultiReductionBaseCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+
+    // A scalable outermost dimension has no static trip count to unroll.
+    VectorType srcType = multiReductionOp.getSourceVectorType();
+    if (srcType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "cannot unroll a scalable outermost dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    if (multiReductionOp.getReductionDims().size() > 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected only one reduction dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    int64_t numElementwiseOps = srcType.getShape().front();
+
+    // If masked, replace the enclosing vector.mask and insert before it.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    bool isMasked = maskableOp.isMasked();
+    Operation *rootOp = multiReductionOp;
+    Value mask = nullptr;
+    if (isMasked) {
+      rootOp = maskableOp.getMaskingOp();
+      rewriter.setInsertionPoint(rootOp);
+      mask = maskableOp.getMaskingOp().getMask();
+    }
+
+    SmallVector<Value> vectors;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+    SmallVector<Value> masks(numElementwiseOps); // Null when unmasked.
+    if (isMasked)
+      for (int64_t i = 0; i < numElementwiseOps; ++i)
+        masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
+
+    // Fold the slices into the accumulator; a non-null mask makes each step
+    // keep the previous value for masked-off lanes.
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks))
+      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
+                                  innerVector, result, /*fastmath=*/nullptr,
+                                  innerMask);
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
+/// Unrolls a vector.multi_reduction that reduces the outermost dimension
+/// together with other dimensions into a chain of lower-rank
+/// vector.multi_reduction ops, one per slice of the outermost dimension.
+struct UnrollMultiReductionGeneralCase
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be reduced dimension.");
+
+    // A scalable outermost dimension has no static trip count to unroll.
+    VectorType srcType = multiReductionOp.getSourceVectorType();
+    if (srcType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "cannot unroll a scalable outermost dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    if (multiReductionOp.getReductionDims().size() <= 1)
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected more than one reduction dimension.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    int64_t numElementwiseOps = srcType.getShape().front();
+
+    // If masked, replace the enclosing vector.mask and insert before it.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    bool isMasked = maskableOp.isMasked();
+    Operation *rootOp = multiReductionOp;
+    Value mask = nullptr;
+    if (isMasked) {
+      rootOp = maskableOp.getMaskingOp();
+      rewriter.setInsertionPoint(rootOp);
+      mask = maskableOp.getMaskingOp().getMask();
+    }
+
+    SmallVector<Value> vectors;
+    for (int64_t i = 0; i < numElementwiseOps; ++i)
+      vectors.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+    SmallVector<Value> masks(numElementwiseOps); // Null when unmasked.
+    if (isMasked)
+      for (int64_t i = 0; i < numElementwiseOps; ++i)
+        masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
+
+    // getReductionMask() returns by value: keep an owned copy (an ArrayRef
+    // into the temporary would dangle) and drop the unrolled dimension.
+    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
+    reductionMask.erase(reductionMask.begin());
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip(vectors, masks)) {
+      auto reductionOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, innerVector, result, reductionMask,
+          multiReductionOp.getKind());
+      if (isMasked)
+        result = vector::maskOperation(rewriter, reductionOp, innerMask)
+                     ->getResult(0);
+      else
+        result = reductionOp.getResult();
+    }
+
+    rewriter.replaceOp(rootOp, result);
+    return success();
+  }
+};
+
 /// Unrolls vector.multi_reduction with outermost reductions
 /// and combines results
 struct TwoDimMultiReductionToElementWise
@@ -518,6 +706,12 @@ void mlir::vector::populateVectorInnerOuterDimReductionConversionPatterns(
                                                  benefit);
 }
 
+void mlir::vector::populateVectorUnrollMultiReduction(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollMultiReductionBaseCase, UnrollMultiReductionGeneralCase>(
+      patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
diff --git a/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
new file mode 100644
index 0000000000000..89d311f3a60e5
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/unroll-multi-reduction.mlir
@@ -0,0 +1,14 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @unroll_multi_reduction(%module_op: !transform.any_op {transform.readonly}) {
+    // Apply the patterns to every func.func in the payload module.
+    %func_op = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func_op {
+      // Only the two unrolling patterns (base case and general case).
+      transform.apply_patterns.vector.unroll_multi_reduction
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
+
diff --git a/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
new file mode 100644
index 0000000000000..79086e2b0b9ad
--- /dev/null
+++ b/mlir/test/Dialect/Vector/unroll-vector-multi-reduction.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt --split-input-file %s -transform-preload-library='transform-library-paths=%p/td/unroll-multi-reduction.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_multi_reduction | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test UnrollMultiReductionBaseCase and UnrollMultiReductionGeneralCase
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction(%source: vector<2x3x5xf32>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+  // One arith.addf per slice, threading the accumulator through.
+  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_0]] : vector<3x5xf32>
+  %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+
+  // CHECK: return %[[RES_1]]
+  return %1 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3x5xf32>
+func.func @unroll_vector_multi_reduction_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3x5xf32>) -> (vector<3x5xf32>) {
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+  // A masked step selects the previous accumulator for masked-off lanes.
+  // CHECK: %[[RES_0:.+]] = arith.addf %[[VEC_0]], %[[ACC]] : vector<3x5xf32>
+  // CHECK: %[[RES_MASKED_0:.+]] = arith.select %[[MASK_0]], %[[RES_0]], %[[ACC]] : vector<3x5xi1>, vector<3x5xf32>
+
+  // CHECK: %[[RES_1:.+]] = arith.addf %[[VEC_1]], %[[RES_MASKED_0]] : vector<3x5xf32>
+  // CHECK: %[[RES_MASKED_1:.+]] = arith.select %[[MASK_1]], %[[RES_1]], %[[RES_MASKED_0]] : vector<3x5xi1>, vector<3x5xf32>
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %source, %acc [0] : vector<2x3x5xf32> to vector<3x5xf32>
+  } : vector<2x3x5xi1> -> vector<3x5xf32>
+
+  // CHECK: return %[[RES_MASKED_1]]
+  return %0 : vector<3x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general(%source: vector<2x3x5xf32>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+  // Dim 0 is unrolled; reduction dim 2 becomes dim 1 of each 3x5 slice.
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK: %[[RES_0:.+]] = vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32>
+  // CHECK: %[[RES_1:.+]] = vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32>
+
+  %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+
+  // CHECK: return %[[RES_1]]
+  return %1 : vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unroll_vector_multi_reduction_general_masked(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>,
+// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>,
+// CHECK-SAME: %[[ACC:.+]]: vector<3xf32>
+func.func @unroll_vector_multi_reduction_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<3xf32>) -> (vector<3xf32>) {
+  // As above, with each slice reduction wrapped in a sliced vector.mask.
+  // CHECK-DAG: %[[VEC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32>
+  // CHECK-DAG: %[[VEC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32>
+
+  // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1>
+  // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1>
+
+  // CHECK: %[[RES_0:.+]] = vector.mask %[[MASK_0]] { vector.multi_reduction <add>, %[[VEC_0]], %[[ACC]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+  // CHECK: %[[RES_1:.+]] = vector.mask %[[MASK_1]] { vector.multi_reduction <add>, %[[VEC_1]], %[[RES_0]] [1] : vector<3x5xf32> to vector<3xf32> } : vector<3x5xi1> -> vector<3xf32>
+
+  %0 = vector.mask %mask {
+    %1 = vector.multi_reduction <add>, %source, %acc [0, 2] : vector<2x3x5xf32> to vector<3xf32>
+  } : vector<2x3x5xi1> -> vector<3xf32>
+
+  // CHECK: return %[[RES_1]]
+  return %0 : vector<3xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
index bc99e64e2f909..9c508eb3d53c0 100644
--- a/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
+++ b/mlir/test/Dialect/Vector/vector-inner-outer-dim-reduction-conversion.mlir
@@ -16,4 +16,3 @@ func.func @inner_outer_dim_reduction_conversion(%arg0: vector<2x3x5x7xf32>, %acc
   // CHECK: return %[[RES]]
   return %1 : vector<3x5xf32>
 }
-



More information about the Mlir-commits mailing list