[Mlir-commits] [mlir] [mlir][vector] Add pattern to break down small reductions into arith ops (PR #75727)

Jakub Kuderski llvmlistbot at llvm.org
Sat Dec 16 18:12:32 PST 2023


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/75727

The number of vector elements considered 'small' is parameterized. This avoids going through specialized reduction lowering when one or a couple of arith ops suffice.

Also update the `vector::makeArithReduction` helper to distinguish `minf`/`maxf` and `minimumf`/`maximumf`, and to propagate fast math flags.

>From 693ecfbb497e5c6bd0dd6276947e7f078c6bdb94 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 16 Dec 2023 21:05:48 -0500
Subject: [PATCH] [mlir][vector] Add pattern to break down small reductions
 into arith ops

The number of vector elements considered 'small' is parameterized.
This avoids going through specialized reduction lowering when one or
a couple of arith ops suffice.

Also update the `vector::makeArithReduction` helper to distinguish
`minf`/`maxf` and `minimumf`/`maximumf`, and to propagate fast math
flags.
---
 .../mlir/Dialect/Vector/IR/VectorOps.h        |   6 +-
 .../Vector/Transforms/VectorRewritePatterns.h |  19 ++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  25 +++--
 .../Vector/Transforms/LowerVectorContract.cpp |   3 +-
 .../Vector/Transforms/VectorTransforms.cpp    |  53 +++++++++
 .../Vector/break-down-vector-reduction.mlir   | 104 ++++++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    |  12 ++
 ...vector-multi-reduction-outer-lowering.mlir |  16 +--
 .../Dialect/Vector/TestVectorTransforms.cpp   |  23 ++++
 9 files changed, 243 insertions(+), 18 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/break-down-vector-reduction.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..a28b27e4e15816 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                            VectorTransferOpInterface transferB,
                            bool testDynamicValueUsingBounds = false);
 
-/// Return the result value of reducing two scalar/vector values with the
+/// Returns the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value acc, Value mask = Value());
+                         Value v1, Value acc,
+                         arith::FastMathFlagsAttr fastmath = nullptr,
+                         Value mask = nullptr);
 
 /// Returns true if `attr` has "parallel" iterator type semantics.
 inline bool isParallelIterator(Attribute attr) {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 17173c01ab762a..49b74c0c466d2f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
 void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit = 1);
 
+/// Patterns to break down vector reductions into a series of arith reductions
+/// over vector elements. This is intended to simplify code with reductions
+/// over small vector types and avoid more specialized reduction lowering when
+/// possible.
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+void populateBreakDownVectorReductionPatterns(
+    RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
+    PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// [DecomposeDifferentRankInsertStridedSlice]
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..9f3e13c90a624d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
                                                 zeroIdx);
     }
 
-    Value result = vector::makeArithReduction(
-        rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+    Value result =
+        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
+                                   cast, /*fastmath=*/nullptr, mask);
     rewriter.replaceOp(rootOp, result);
     return success();
   }
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
 
     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
-                                          result, acc, mask);
+                                          result, acc,
+                                          reductionOp.getFastmathAttr(), mask);
 
     rewriter.replaceOp(rootOp, result);
     return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
 
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value acc,
+                                       arith::FastMathFlagsAttr fastmath,
                                        Value mask) {
   Type t1 = getElementTypeOrSelf(v1.getType());
   Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for ADD reduction");
     break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
     break;
   case CombiningKind::MAXF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MAXIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MINF:
+    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+           "expected float values");
+    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
+    break;
   case CombiningKind::MINIMUMF:
     assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
            "expected float values");
-    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
+    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
     break;
   case CombiningKind::MAXSI:
     assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
     if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
       result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
     else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
-      result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
     else
       llvm_unreachable("invalid value types for MUL reduction");
     break;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6dbe36e605e9a7..41ff0c18fe6258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
   if (!acc)
     return std::optional<Value>(mul);
 
-  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+  return makeArithReduction(rewriter, loc, kind, mul, acc,
+                            /*fastmath=*/nullptr, mask);
 }
 
 /// Return the positions of the reductions in the given map.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5936b0b54af4e3..143360079916a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
+#include <cassert>
 #include <cstdint>
 #include <functional>
 #include <optional>
@@ -44,6 +45,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "vector-to-vector"
@@ -1578,6 +1580,50 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
+  BreakDownVectorReduction(MLIRContext *context,
+                           unsigned maxNumElementsToExtract,
+                           PatternBenefit benefit)
+      : OpRewritePattern(context, benefit),
+        maxNumElementsToExtract(maxNumElementsToExtract) {}
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType type = op.getSourceVectorType();
+    if (type.isScalable() || op.isMasked())
+      return failure();
+    assert(type.getRank() == 1 && "Expected a 1-d vector");
+
+    int64_t numElems = type.getNumElements();
+    if (numElems > maxNumElementsToExtract) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("has too many vector elements ({0}) to break down "
+                            "(max allowed: {1})",
+                            numElems, maxNumElementsToExtract));
+    }
+
+    Location loc = op.getLoc();
+    SmallVector<Value> extracted(numElems, nullptr);
+    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
+      extractedElem = rewriter.create<vector::ExtractOp>(
+          loc, op.getVector(), static_cast<int64_t>(idx));
+
+    Value res = extracted.front();
+    for (auto extractedElem : llvm::drop_begin(extracted))
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
+                                       extractedElem, op.getFastmathAttr());
+    if (Value acc = op.getAcc())
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
+                                       op.getFastmathAttr());
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  unsigned maxNumElementsToExtract = 0;
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1656,6 +1702,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
                                     PatternBenefit(benefit.getBenefit() + 1));
 }
 
+void mlir::vector::populateBreakDownVectorReductionPatterns(
+    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
+    PatternBenefit benefit) {
+  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
+                                         maxNumElementsToExtract, benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
new file mode 100644
index 00000000000000..b7bc19594491bd
--- /dev/null
+++ b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s
+
+// NOTE: This test pass is set to break down vector reductions of size 2 or fewer.
+
+// CHECK-LABEL:   func.func @reduce_2x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
+// CHECK:          return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
+func.func @reduce_2x_fp32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
+  %1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
+  %2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
+  %3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
+  %4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
+  %5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
+  return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-DAG:      %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
+// CHECK:          return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
+func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+  %0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
+  %1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
+  %2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
+  %3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
+  %4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
+  %5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
+  %6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
+  %7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
+  %8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xf32>) -> f32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     return %[[E0]] : f32
+func.func @reduce_1x_fp32(%arg0: vector<1xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_acc_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_1x_acc_fp32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_acc_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
+// CHECK-NEXT:     %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
+// CHECK-NEXT:     return %[[R0]] : i32
+func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
+  return %0 : i32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_acc_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK:          %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK:          %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
+// CHECK:          %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
+// CHECK:          %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
+// CHECK-NEXT:     return %[[R0]], %[[R1]] : f32, f32
+func.func @reduce_2x_acc_fp32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
+  %1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
+  return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL:   func.func @reduce_3x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<3xf32>) -> f32 {
+// CHECK-NEXT:     %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_3x_fp32(%arg0: vector<3xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
+  return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..b5164b66817352 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+//       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+//       CHECK:   %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
+//       CHECK:   return %[[S]]
+func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @masked_reduce_one_element_vector_addf
 //  CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
 //  CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 12ea87ffb1413f..614a97fe4d6777 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
 //       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
 //       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
 //       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV01:.+]] = arith.maximumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
 //       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RV012:.+]] = arith.maximumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
 //       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-//       CHECK:   %[[RESULT_VEC:.+]] = arith.maximumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
 //       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
 
 func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03ddebe82344d8..126d65b1b8487f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
   }
 };
 
+struct TestVectorBreakDownReductionPatterns
+    : public PassWrapper<TestVectorBreakDownReductionPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorBreakDownReductionPatterns)
+
+  StringRef getArgument() const final {
+    return "test-vector-break-down-reduction-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to break down vector reductions into arith "
+           "reductions";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateBreakDownVectorReductionPatterns(patterns,
+                                             /*maxNumElementsToExtract=*/2);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFlattenVectorTransferPatterns
     : public PassWrapper<TestFlattenVectorTransferPatterns,
                          OperationPass<func::FuncOp>> {
@@ -827,6 +848,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorChainedReductionFoldingPatterns>();
 
+  PassRegistration<TestVectorBreakDownReductionPatterns>();
+
   PassRegistration<TestFlattenVectorTransferPatterns>();
 
   PassRegistration<TestVectorScanLowering>();



More information about the Mlir-commits mailing list