[Mlir-commits] [mlir] [mlir][vector] Add pattern to break down small reductions into arith ops (PR #75727)
Jakub Kuderski
llvmlistbot at llvm.org
Sat Dec 16 18:21:00 PST 2023
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/75727
>From 693ecfbb497e5c6bd0dd6276947e7f078c6bdb94 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 16 Dec 2023 21:05:48 -0500
Subject: [PATCH 1/2] [mlir][vector] Add pattern to break down small reductions
into arith ops
The number of vector elements considered 'small' is parameterized.
The pattern avoids going into specialized reduction lowering when a
single arith op, or a couple of them, will do.
Also update the `vector::makeArithReduction` helper to distinguish
`minf`/`maxf` and `minimumf`/`maximumf`, and to propagate fast math
flags.
---
.../mlir/Dialect/Vector/IR/VectorOps.h | 6 +-
.../Vector/Transforms/VectorRewritePatterns.h | 19 ++++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 +++--
.../Vector/Transforms/LowerVectorContract.cpp | 3 +-
.../Vector/Transforms/VectorTransforms.cpp | 53 +++++++++
.../Vector/break-down-vector-reduction.mlir | 104 ++++++++++++++++++
mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++
...vector-multi-reduction-outer-lowering.mlir | 16 +--
.../Dialect/Vector/TestVectorTransforms.cpp | 23 ++++
9 files changed, 243 insertions(+), 18 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..a28b27e4e15816 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -123,10 +123,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB,
bool testDynamicValueUsingBounds = false);
-/// Return the result value of reducing two scalar/vector values with the
+/// Returns the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
- Value v1, Value acc, Value mask = Value());
+ Value v1, Value acc,
+ arith::FastMathFlagsAttr fastmath = nullptr,
+ Value mask = nullptr);
/// Returns true if `attr` has "parallel" iterator type semantics.
inline bool isParallelIterator(Attribute attr) {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 17173c01ab762a..49b74c0c466d2f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Patterns to break down vector reductions into a series of arith reductions
+/// over vector elements. This is intended to simplify code with reductions
+/// over small vector types and avoid more specialized reduction lowering when
+/// possible.
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+void populateBreakDownVectorReductionPatterns(
+ RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
+ PatternBenefit benefit = 1);
+
/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..9f3e13c90a624d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -507,8 +507,9 @@ struct ElideUnitDimsInMultiDimReduction
zeroIdx);
}
- Value result = vector::makeArithReduction(
- rewriter, loc, reductionOp.getKind(), acc, cast, mask);
+ Value result =
+ vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
+ cast, /*fastmath=*/nullptr, mask);
rewriter.replaceOp(rootOp, result);
return success();
}
@@ -650,7 +651,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
- result, acc, mask);
+ result, acc,
+ reductionOp.getFastmathAttr(), mask);
rewriter.replaceOp(rootOp, result);
return success();
@@ -6212,6 +6214,7 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
+ arith::FastMathFlagsAttr fastmath,
Value mask) {
Type t1 = getElementTypeOrSelf(v1.getType());
Type tAcc = getElementTypeOrSelf(acc.getType());
@@ -6222,7 +6225,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
- result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
else
llvm_unreachable("invalid value types for ADD reduction");
break;
@@ -6231,16 +6234,24 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
break;
case CombiningKind::MAXF:
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+ "expected float values");
+ result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
+ break;
case CombiningKind::MAXIMUMF:
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
- result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
break;
case CombiningKind::MINF:
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
+ "expected float values");
+ result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
+ break;
case CombiningKind::MINIMUMF:
assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
- result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
break;
case CombiningKind::MAXSI:
assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
@@ -6262,7 +6273,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
- result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+ result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
else
llvm_unreachable("invalid value types for MUL reduction");
break;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6dbe36e605e9a7..41ff0c18fe6258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -167,7 +167,8 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
if (!acc)
return std::optional<Value>(mul);
- return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+ return makeArithReduction(rewriter, loc, kind, mul, acc,
+ /*fastmath=*/nullptr, mask);
}
/// Return the positions of the reductions in the given map.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5936b0b54af4e3..143360079916a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
@@ -44,6 +45,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "vector-to-vector"
@@ -1578,6 +1580,50 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
}
};
+struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
+ BreakDownVectorReduction(MLIRContext *context,
+ unsigned maxNumElementsToExtract,
+ PatternBenefit benefit)
+ : OpRewritePattern(context, benefit),
+ maxNumElementsToExtract(maxNumElementsToExtract) {}
+
+ LogicalResult matchAndRewrite(vector::ReductionOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType type = op.getSourceVectorType();
+ if (type.isScalable() || op.isMasked())
+ return failure();
+ assert(type.getRank() == 1 && "Expected a 1-d vector");
+
+ int64_t numElems = type.getNumElements();
+ if (numElems > maxNumElementsToExtract) {
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv("has too many vector elements ({0}) to break down "
+ "(max allowed: {1})",
+ numElems, maxNumElementsToExtract));
+ }
+
+ Location loc = op.getLoc();
+ SmallVector<Value> extracted(numElems, nullptr);
+ for (auto [idx, extractedElem] : llvm::enumerate(extracted))
+ extractedElem = rewriter.create<vector::ExtractOp>(
+ loc, op.getVector(), static_cast<int64_t>(idx));
+
+ Value res = extracted.front();
+ for (auto extractedElem : llvm::drop_begin(extracted))
+ res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
+ extractedElem, op.getFastmathAttr());
+ if (Value acc = op.getAcc())
+ res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
+ op.getFastmathAttr());
+
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+private:
+ unsigned maxNumElementsToExtract = 0;
+};
+
} // namespace
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1656,6 +1702,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
PatternBenefit(benefit.getBenefit() + 1));
}
+void mlir::vector::populateBreakDownVectorReductionPatterns(
+ RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
+ PatternBenefit benefit) {
+ patterns.add<BreakDownVectorReduction>(patterns.getContext(),
+ maxNumElementsToExtract, benefit);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
new file mode 100644
index 00000000000000..b7bc19594491bd
--- /dev/null
+++ b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s
+
+// NOTE: This test pass is set break down vector reductions of size 2 or fewer.
+
+// CHECK-LABEL: func.func @reduce_2x_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK-DAG: %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
+func.func @reduce_2x_fp32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+ %0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
+ %1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
+ %2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
+ %3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
+ %4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
+ %5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
+ return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-DAG: %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
+func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+ %0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
+ %1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
+ %2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
+ %3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
+ %4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
+ %5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
+ %6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
+ %7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
+ %8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
+ return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL: func.func @reduce_1x_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>) -> f32 {
+// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT: return %[[E0]] : f32
+func.func @reduce_1x_fp32(%arg0: vector<1xf32>) -> f32 {
+ %0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @reduce_1x_acc_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
+// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT: %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
+// CHECK-NEXT: return %[[R0]] : f32
+func.func @reduce_1x_acc_fp32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @reduce_1x_acc_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
+// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
+// CHECK-NEXT: %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
+// CHECK-NEXT: return %[[R0]] : i32
+func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_acc_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
+// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK: %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK: %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
+// CHECK: %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
+// CHECK: %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
+// CHECK-NEXT: return %[[R0]], %[[R1]] : f32, f32
+func.func @reduce_2x_acc_fp32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
+ %1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
+ return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL: func.func @reduce_3x_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xf32>) -> f32 {
+// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
+// CHECK-NEXT: return %[[R0]] : f32
+func.func @reduce_3x_fp32(%arg0: vector<3xf32>) -> f32 {
+ %0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..b5164b66817352 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2172,6 +2172,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
// -----
+// CHECK-LABEL: func @reduce_one_element_vector_addf_fastmath
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32>
+// CHECK: %[[S:.+]] = arith.addf %[[A]], %arg1 fastmath<nnan,ninf> : f32
+// CHECK: return %[[S]]
+func.func @reduce_one_element_vector_addf_fastmath(%a : vector<1xf32>, %b: f32) -> f32 {
+ %s = vector.reduction <add>, %a, %b fastmath<nnan,ninf> : vector<1xf32> into f32
+ return %s : f32
+}
+
+// -----
+
// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>, %[[VAL_1:.*]]: f32,
// CHECK-SAME: %[[VAL_2:.*]]: vector<1xi1>)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 12ea87ffb1413f..614a97fe4d6777 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,13 +27,13 @@ func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.minimumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.minimumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.minimumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.minimumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
@@ -45,13 +45,13 @@ func.func @vector_multi_reduction_max(%arg0: vector<2x4xf32>, %acc: vector<2xf32
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.maximumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[RV0:.+]] = arith.maxnumf %[[V0]], %[[ACC]] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.maximumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.maxnumf %[[V1]], %[[RV0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.maximumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[RV012:.+]] = arith.maxnumf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.maximumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.maxnumf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func.func @vector_multi_reduction_and(%arg0: vector<2x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03ddebe82344d8..126d65b1b8487f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
}
};
+struct TestVectorBreakDownReductionPatterns
+ : public PassWrapper<TestVectorBreakDownReductionPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorBreakDownReductionPatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-break-down-reduction-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns to break down vector reductions into arith "
+ "reductions";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateBreakDownVectorReductionPatterns(patterns,
+ /*maxNumElementsToExtract=*/2);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
@@ -827,6 +848,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorChainedReductionFoldingPatterns>();
+ PassRegistration<TestVectorBreakDownReductionPatterns>();
+
PassRegistration<TestFlattenVectorTransferPatterns>();
PassRegistration<TestVectorScanLowering>();
>From 2ca70b4fabb41f1352781de8fad56081cd4eee60 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 16 Dec 2023 21:20:49 -0500
Subject: [PATCH 2/2] Update test
---
.../VectorToLLVM/vector-to-llvm.mlir | 4 ++--
.../Vector/break-down-vector-reduction.mlir | 20 +++++++++----------
2 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..7353d16d79cea0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -450,7 +450,7 @@ func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
// CHECK-LABEL: func.func @masked_float_max_outerprod(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK: %[[VAL_9:.*]] = arith.maximumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
// -----
@@ -463,7 +463,7 @@ func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: v
// CHECK-LABEL: func.func @masked_float_min_outerprod(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
-// CHECK: %[[VAL_9:.*]] = arith.minimumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.minnumf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
// -----
diff --git a/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
index b7bc19594491bd..f45326c085a34f 100644
--- a/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
+++ b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
@@ -2,7 +2,7 @@
// NOTE: This test pass is set break down vector reductions of size 2 or fewer.
-// CHECK-LABEL: func.func @reduce_2x_fp32(
+// CHECK-LABEL: func.func @reduce_2x_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
@@ -13,7 +13,7 @@
// CHECK-DAG: %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
// CHECK-DAG: %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
-func.func @reduce_2x_fp32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+func.func @reduce_2x_f32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
%0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
%1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
%2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
@@ -50,21 +50,21 @@ func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32,
return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
}
-// CHECK-LABEL: func.func @reduce_1x_fp32(
+// CHECK-LABEL: func.func @reduce_1x_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>) -> f32 {
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
// CHECK-NEXT: return %[[E0]] : f32
-func.func @reduce_1x_fp32(%arg0: vector<1xf32>) -> f32 {
+func.func @reduce_1x_f32(%arg0: vector<1xf32>) -> f32 {
%0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
return %0 : f32
}
-// CHECK-LABEL: func.func @reduce_1x_acc_fp32(
+// CHECK-LABEL: func.func @reduce_1x_acc_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
// CHECK-NEXT: %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
// CHECK-NEXT: return %[[R0]] : f32
-func.func @reduce_1x_acc_fp32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
+func.func @reduce_1x_acc_f32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
%0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
return %0 : f32
}
@@ -79,7 +79,7 @@ func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
return %0 : i32
}
-// CHECK-LABEL: func.func @reduce_2x_acc_fp32(
+// CHECK-LABEL: func.func @reduce_2x_acc_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
@@ -88,17 +88,17 @@ func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
// CHECK: %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
// CHECK: %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
// CHECK-NEXT: return %[[R0]], %[[R1]] : f32, f32
-func.func @reduce_2x_acc_fp32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
+func.func @reduce_2x_acc_f32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
%0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
%1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
return %0, %1 : f32, f32
}
-// CHECK-LABEL: func.func @reduce_3x_fp32(
+// CHECK-LABEL: func.func @reduce_3x_f32(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xf32>) -> f32 {
// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
// CHECK-NEXT: return %[[R0]] : f32
-func.func @reduce_3x_fp32(%arg0: vector<3xf32>) -> f32 {
+func.func @reduce_3x_f32(%arg0: vector<3xf32>) -> f32 {
%0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
return %0 : f32
}
More information about the Mlir-commits
mailing list