[Mlir-commits] [mlir] [mlir][vector] Add pattern to break down reductions into arith ops (PR #75727)

Jakub Kuderski llvmlistbot at llvm.org
Mon Dec 18 11:47:32 PST 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/75727

>From c1bdff52068a5e0f19ffebbda8aff73e4dc34ef7 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 16 Dec 2023 21:05:48 -0500
Subject: [PATCH 1/2] [mlir][vector] Improve `makeArithReduction` expansion

Propagate fast math flags.
Distinguish `minf`/`maxf` and `minimumf`/`maximumf`.

Required for future patterns in
https://github.com/llvm/llvm-project/pull/75727.
---
 .../mlir/Dialect/Vector/IR/VectorOps.h        |  6 +++--
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 25 +++++++++++++------
 .../Vector/Transforms/LowerVectorContract.cpp |  3 ++-
 .../VectorToLLVM/vector-to-llvm.mlir          |  4 +--
 mlir/test/Dialect/Vector/canonicalize.mlir    | 12 +++++++++
 ...vector-multi-reduction-outer-lowering.mlir | 16 ++++++------
 6 files changed, 46 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..a28b27e4e15816 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                            VectorTransferOpInterface transferB,
                            bool testDynamicValueUsingBounds = false);
 
-/// Return the result value of reducing two scalar/vector values with the
+/// Returns the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value acc, Value mask = Value());
+                         Value v1, Value acc,
+                         arith::FastMathFlagsAttr fastmath = nullptr,
+                         Value mask = nullptr);
 
 /// Returns true if `attr` has "parallel" iterator type semantics.
 inline bool isParallelIterator(Attribute attr) {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..9f3e13c90a624d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
                                                 zeroIdx);
     }
 
-    Value result = vector::makeArithReduction(
-        rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+    Value result =
+        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
+                                   cast, /*fastmath=*/nullptr, mask);
     rewriter.replaceOp(rootOp, result);
     return success();
   }
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
 
     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
-                                          result, acc, mask);
+                                          result, acc,
+                                          reductionOp.getFastmathAttr(), mask);
 
     rewriter.replaceOp(rootOp, result);
     return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
 
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
+                                       arith::FastMathFlagsAttr fastmath,
                                        Value mask) {
   Type t1 = getElementTypeOrSelf(v1.getType());
   Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for ADD reduction");
     break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
     break;
   case CombiningKind::MAXF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MAXIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MINF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MINIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MAXSI:
     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for MUL reduction");
     break;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6dbe36e605e9a7..41ff0c18fe6258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
   if (!acc)
     return std::optional<Value>(mul);
 
-  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+  return makeArithReduction(rewriter, loc, kind, mul, acc,
+                            /*fastmath=*/nullptr, mask);
 }
 
 /// Return the positions of the reductions in the given map.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..7353d16d79cea0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -450,7 +450,7 @@ func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
 // CHECK-LABEL:   func.func @masked_float_max_outerprod(
 // CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
 // CHECK:           %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK:           %[[VAL_9:.*]] = arith.maximumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.maxnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
 
 // -----
@@ -463,7 +463,7 @@ func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
 // CHECK-LABEL:   func.func @masked_float_min_outerprod(
 // CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
 // CHECK:           %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK:           %[[VAL_9:.*]] = arith.minimumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.minnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
 
 // -----
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..b5164b66817352 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+//       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+//       CHECK:   %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
+//       CHECK:   return %[[S]]
+func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @masked_reduce_one_element_vector_addf
 //  CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
 //  CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 12ea87ffb1413f..614a97fe4d6777 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.maximumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV012:.+]] = arith.maximumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RESULT_VEC:.+]] = arith.maximumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {

>From 0c2a497481866bc5ea06c49372fc89ccd0e8cc0f Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 16 Dec 2023 21:05:48 -0500
Subject: [PATCH 2/2] [mlir][vector] Add pattern to break down small reductions
 into arith ops

The number of vector elements considered 'small' enough to extract is
parameterized.

This is to avoid going into specialized reduction lowering when a single
arith op (or a couple of them) will do. Targets without dedicated reduction
intrinsics can use that as an emulation path too.

Depends on https://github.com/llvm/llvm-project/pull/75846.

 Please enter the commit message for your changes. Lines starting
---
 .../Vector/Transforms/VectorRewritePatterns.h |  19 +++
 .../Vector/Transforms/VectorTransforms.cpp    |  63 +++++++++
 .../Vector/break-down-vector-reduction.mlir   | 126 ++++++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  23 ++++
 4 files changed, 231 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/break-down-vector-reduction.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 17173c01ab762a..49b74c0c466d2f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
 void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit = 1);
 
+/// Patterns to break down vector reductions into a series of arith reductions
+/// over vector elements. This is intended to simplify code with reductions
+/// over small vector types and avoid more specialized reduction lowering when
+/// possible.
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+void populateBreakDownVectorReductionPatterns(
+    RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
+    PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// [DecomposeDifferentRankInsertStridedSlice]
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5936b0b54af4e3..661674dd74c0cd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
+#include <cassert>
 #include <cstdint>
 #include <functional>
 #include <optional>
@@ -44,6 +45,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "vector-to-vector"
@@ -1578,6 +1580,60 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Example:
+/// ```
+/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
+  BreakDownVectorReduction(MLIRContext *context,
+                           unsigned maxNumElementsToExtract,
+                           PatternBenefit benefit)
+      : OpRewritePattern(context, benefit),
+        maxNumElementsToExtract(maxNumElementsToExtract) {}
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType type = op.getSourceVectorType();
+    if (type.isScalable() || op.isMasked())
+      return failure();
+    assert(type.getRank() == 1 && "Expected a 1-d vector");
+
+    int64_t numElems = type.getNumElements();
+    if (numElems > maxNumElementsToExtract) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("has too many vector elements ({0}) to break down "
+                            "(max allowed: {1})",
+                            numElems, maxNumElementsToExtract));
+    }
+
+    Location loc = op.getLoc();
+    SmallVector<Value> extracted(numElems, nullptr);
+    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
+      extractedElem = rewriter.create<vector::ExtractOp>(
+          loc, op.getVector(), static_cast<int64_t>(idx));
+
+    Value res = extracted.front();
+    for (auto extractedElem : llvm::drop_begin(extracted))
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
+                                       extractedElem, op.getFastmathAttr());
+    if (Value acc = op.getAcc())
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
+                                       op.getFastmathAttr());
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  unsigned maxNumElementsToExtract = 0;
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1656,6 +1712,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
                                     PatternBenefit(benefit.getBenefit() + 1));
 }
 
+void mlir::vector::populateBreakDownVectorReductionPatterns(
+    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
+    PatternBenefit benefit) {
+  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
+                                         maxNumElementsToExtract, benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
new file mode 100644
index 00000000000000..34234591b79cab
--- /dev/null
+++ b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
@@ -0,0 +1,126 @@
+// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s
+
+// NOTE: This test pass is set to break down vector reductions of size 2 or fewer.
+
+// CHECK-LABEL:   func.func @reduce_2x_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
+// CHECK:          return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
+func.func @reduce_2x_f32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
+  %1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
+  %2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
+  %3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
+  %4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
+  %5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
+  return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-DAG:      %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
+// CHECK:          return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
+func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+  %0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
+  %1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
+  %2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
+  %3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
+  %4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
+  %5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
+  %6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
+  %7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
+  %8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xf32>) -> f32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     return %[[E0]] : f32
+func.func @reduce_1x_f32(%arg0: vector<1xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_acc_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_1x_acc_f32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_acc_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
+// CHECK-NEXT:     %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
+// CHECK-NEXT:     return %[[R0]] : i32
+func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
+  return %0 : i32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_acc_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK:          %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK:          %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
+// CHECK:          %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
+// CHECK:          %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
+// CHECK-NEXT:     return %[[R0]], %[[R1]] : f32, f32
+func.func @reduce_2x_acc_f32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
+  %1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
+  return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL:   func.func @reduce_3x_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<3xf32>) -> f32 {
+// CHECK-NEXT:     %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_3x_f32(%arg0: vector<3xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
+  return %0 : f32
+}
+
+// Masking is not handled yet.
+// CHECK-LABEL:   func.func @reduce_mask_3x_f32
+// CHECK-NEXT:     %[[M:.+]] = vector.create_mask
+// CHECK-NEXT:     %[[R:.+]] = vector.mask %[[M]]
+// CHECK-SAME:       vector.reduction <add>
+// CHECK-NEXT:     return %[[R]] : f32
+func.func @reduce_mask_3x_f32(%arg0: vector<3xf32>, %arg1: index) -> f32 {
+  %mask = vector.create_mask %arg1 : vector<3xi1>
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<3xf32> into f32 } : vector<3xi1> -> f32
+  return %0 : f32
+}
+
+// Scalable vectors are not supported.
+// CHECK-LABEL:   func.func @reduce_scalable_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<[1]xf32>) -> f32 {
+// CHECK-NEXT:     %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<[1]xf32> into f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_scalable_f32(%arg0: vector<[1]xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<[1]xf32> into f32
+  return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03ddebe82344d8..126d65b1b8487f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
   }
 };
 
+struct TestVectorBreakDownReductionPatterns
+    : public PassWrapper<TestVectorBreakDownReductionPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorBreakDownReductionPatterns)
+
+  StringRef getArgument() const final {
+    return "test-vector-break-down-reduction-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to break down vector reductions into arith "
+           "reductions";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateBreakDownVectorReductionPatterns(patterns,
+                                             /*maxNumElementsToExtract=*/2);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFlattenVectorTransferPatterns
     : public PassWrapper<TestFlattenVectorTransferPatterns,
                          OperationPass<func::FuncOp>> {
@@ -827,6 +848,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorChainedReductionFoldingPatterns>();
 
+  PassRegistration<TestVectorBreakDownReductionPatterns>();
+
   PassRegistration<TestFlattenVectorTransferPatterns>();
 
   PassRegistration<TestVectorScanLowering>();



More information about the Mlir-commits mailing list