[Mlir-commits] [mlir] 0767711 - [mlir][vector] Add pattern to break down reductions into arith ops (#75727)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 18 14:54:58 PST 2023
Author: Jakub Kuderski
Date: 2023-12-18T17:54:54-05:00
New Revision: 07677113ffeb3744df350ef7c4ece1a93f7a5e1f
URL: https://github.com/llvm/llvm-project/commit/07677113ffeb3744df350ef7c4ece1a93f7a5e1f
DIFF: https://github.com/llvm/llvm-project/commit/07677113ffeb3744df350ef7c4ece1a93f7a5e1f.diff
LOG: [mlir][vector] Add pattern to break down reductions into arith ops (#75727)
The number of vector elements considered 'small' enough to extract is
parameterized.
This avoids going through specialized reduction lowering when a single arith
op or a couple of them suffice. Targets without dedicated reduction
intrinsics can also use this as an emulation path.
Depends on https://github.com/llvm/llvm-project/pull/75846.
Added:
mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 17173c01ab762a..49b74c0c466d2f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Patterns to break down vector reductions into a series of arith reductions
+/// over vector elements. This is intended to simplify code with reductions
+/// over small vector types and avoid more specialized reduction lowering when
+/// possible.
+///
+/// Only reductions whose source vector has at most `maxNumElementsToExtract`
+/// elements are broken down; scalable and masked reductions are left
+/// unchanged. If the reduction has an accumulator, it is combined with the
+/// reduced elements using the same arith op. `benefit` is the benefit given to
+/// the added pattern(s).
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+void populateBreakDownVectorReductionPatterns(
+ RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
+ PatternBenefit benefit = 1);
+
/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5936b0b54af4e3..661674dd74c0cd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
@@ -44,6 +45,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "vector-to-vector"
@@ -1578,6 +1580,60 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
}
};
+/// Breaks down a `vector.reduction` over a small, fixed-size 1-D vector into
+/// `vector.extract` ops combined by the matching arith op (built via
+/// `vector::makeArithReduction`).
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+///
+/// Elements are combined left to right. An accumulator, if present, is
+/// combined last: `((x0 op x1) op ...) op acc`. The reduction's fastmath flags
+/// are propagated to every generated arith op. Scalable, masked, and non-1-D
+/// reductions are not matched. Reductions with more than
+/// `maxNumElementsToExtract` elements are not matched either.
+struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
+  BreakDownVectorReduction(MLIRContext *context,
+                           unsigned maxNumElementsToExtract,
+                           PatternBenefit benefit)
+      : OpRewritePattern(context, benefit),
+        maxNumElementsToExtract(maxNumElementsToExtract) {}
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType type = op.getSourceVectorType();
+    if (type.isScalable() || op.isMasked())
+      return failure();
+    // `vector.reduction` also permits 0-D sources (e.g. `vector<f32>`). Those
+    // cannot be indexed with the 1-D positions used below, so reject them
+    // gracefully instead of asserting on valid IR.
+    if (type.getRank() != 1)
+      return rewriter.notifyMatchFailure(op, "expected a 1-D source vector");
+
+    int64_t numElems = type.getNumElements();
+    if (numElems > maxNumElementsToExtract) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("has too many vector elements ({0}) to break down "
+                            "(max allowed: {1})",
+                            numElems, maxNumElementsToExtract));
+    }
+
+    // Materialize every element first, then fold them together in order.
+    Location loc = op.getLoc();
+    SmallVector<Value> extracted(numElems, nullptr);
+    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
+      extractedElem = rewriter.create<vector::ExtractOp>(
+          loc, op.getVector(), static_cast<int64_t>(idx));
+
+    Value res = extracted.front();
+    for (Value extractedElem : llvm::drop_begin(extracted))
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
+                                       extractedElem, op.getFastmathAttr());
+    if (Value acc = op.getAcc())
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
+                                       op.getFastmathAttr());
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  /// Reductions over more source elements than this are left untouched.
+  unsigned maxNumElementsToExtract = 0;
+};
+
} // namespace
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1656,6 +1712,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
PatternBenefit(benefit.getBenefit() + 1));
}
+void mlir::vector::populateBreakDownVectorReductionPatterns(
+    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
+    PatternBenefit benefit) {
+  // A single pattern does all the work; the element-count threshold and the
+  // benefit are forwarded to it unchanged.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<BreakDownVectorReduction>(ctx, maxNumElementsToExtract,
+                                         benefit);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
new file mode 100644
index 00000000000000..34234591b79cab
--- /dev/null
+++ b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
@@ -0,0 +1,126 @@
+// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s
+
+// NOTE: This test pass is set to break down vector reductions of size 2 or fewer.
+
+// Every floating-point reduction kind over a 2-element vector becomes two
+// `vector.extract` ops plus one arith op; `--cse` makes all six reductions
+// share the same two extracts. `<minf>`/`<maxf>` map to
+// `arith.minnumf`/`arith.maxnumf`.
+// CHECK-LABEL: func.func @reduce_2x_f32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK-DAG: %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG: %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
+func.func @reduce_2x_f32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+ %0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
+ %1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
+ %2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
+ %3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
+ %4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
+ %5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
+ return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
+}
+
+// Every integer reduction kind over a 2-element vector becomes two
+// `vector.extract` ops plus one arith op; `--cse` makes all nine reductions
+// share the same two extracts.
+// CHECK-LABEL: func.func @reduce_2x_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-DAG: %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
+// CHECK-DAG: %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
+func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+ %0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
+ %1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
+ %2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
+ %3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
+ %4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
+ %5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
+ %6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
+ %7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
+ %8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
+ return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// A single-element reduction without an accumulator is just one extract; no
+// arith op is emitted.
+// CHECK-LABEL: func.func @reduce_1x_f32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>) -> f32 {
+// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT: return %[[E0]] : f32
+func.func @reduce_1x_f32(%arg0: vector<1xf32>) -> f32 {
+ %0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+ return %0 : f32
+}
+
+// With an accumulator, the single extracted f32 element is combined with it
+// using the reduction's arith op.
+// CHECK-LABEL: func.func @reduce_1x_acc_f32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
+// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT: %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
+// CHECK-NEXT: return %[[R0]] : f32
+func.func @reduce_1x_acc_f32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
+ return %0 : f32
+}
+
+// Integer variant: the single extracted i32 element is combined with the
+// accumulator using the reduction's arith op.
+// CHECK-LABEL: func.func @reduce_1x_acc_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
+// CHECK-NEXT: %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
+// CHECK-NEXT: %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
+// CHECK-NEXT: return %[[R0]] : i32
+func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
+ return %0 : i32
+}
+
+// The accumulator is combined last, after the vector elements. The
+// reduction's fastmath flags are carried over to every generated arith op.
+// CHECK-LABEL: func.func @reduce_2x_acc_f32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
+// CHECK-DAG: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG: %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK: %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK: %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
+// CHECK: %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
+// CHECK: %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
+// CHECK-NEXT: return %[[R0]], %[[R1]] : f32, f32
+func.func @reduce_2x_acc_f32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
+ %1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
+ return %0, %1 : f32, f32
+}
+
+// Three elements exceed the test pass's threshold of two, so the reduction is
+// left unchanged.
+// CHECK-LABEL: func.func @reduce_3x_f32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xf32>) -> f32 {
+// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
+// CHECK-NEXT: return %[[R0]] : f32
+func.func @reduce_3x_f32(%arg0: vector<3xf32>) -> f32 {
+ %0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
+ return %0 : f32
+}
+
+// Masking is not handled yet: masked reductions are left unchanged. The vector
+// has only two elements, within the test pass's threshold, so it is the mask
+// alone that prevents the rewrite here.
+// CHECK-LABEL: func.func @reduce_mask_2x_f32
+// CHECK-NEXT: %[[M:.+]] = vector.create_mask
+// CHECK-NEXT: %[[R:.+]] = vector.mask %[[M]]
+// CHECK-SAME: vector.reduction <add>
+// CHECK-NEXT: return %[[R]] : f32
+func.func @reduce_mask_2x_f32(%arg0: vector<2xf32>, %arg1: index) -> f32 {
+ %mask = vector.create_mask %arg1 : vector<2xi1>
+ %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<2xf32> into f32 } : vector<2xi1> -> f32
+ return %0 : f32
+}
+
+// Scalable vectors are not supported: the reduction is left unchanged. The
+// scalability check runs before the element-count threshold is considered.
+// CHECK-LABEL: func.func @reduce_scalable_f32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<[1]xf32>) -> f32 {
+// CHECK-NEXT: %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<[1]xf32> into f32
+// CHECK-NEXT: return %[[R0]] : f32
+func.func @reduce_scalable_f32(%arg0: vector<[1]xf32>) -> f32 {
+ %0 = vector.reduction <add>, %arg0 : vector<[1]xf32> into f32
+ return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03ddebe82344d8..126d65b1b8487f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
}
};
+/// Test pass that greedily applies `populateBreakDownVectorReductionPatterns`
+/// to each `func.func`, breaking down reductions of at most two elements.
+struct TestVectorBreakDownReductionPatterns
+    : public PassWrapper<TestVectorBreakDownReductionPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorBreakDownReductionPatterns)
+
+  StringRef getArgument() const final {
+    return "test-vector-break-down-reduction-patterns";
+  }
+
+  StringRef getDescription() const final {
+    return "Test patterns to break down vector reductions into arith "
+           "reductions";
+  }
+
+  void runOnOperation() override {
+    // The lit test for this pass expects reductions of size 2 or fewer to be
+    // broken down and larger ones to be left alone.
+    constexpr unsigned kMaxNumElementsToExtract = 2;
+    RewritePatternSet breakDownPatterns(&getContext());
+    populateBreakDownVectorReductionPatterns(breakDownPatterns,
+                                             kMaxNumElementsToExtract);
+    (void)applyPatternsAndFoldGreedily(getOperation(),
+                                       std::move(breakDownPatterns));
+  }
+};
+
struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
@@ -827,6 +848,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorChainedReductionFoldingPatterns>();
+ PassRegistration<TestVectorBreakDownReductionPatterns>();
+
PassRegistration<TestFlattenVectorTransferPatterns>();
PassRegistration<TestVectorScanLowering>();
More information about the Mlir-commits
mailing list