[Mlir-commits] [mlir] [mlir][vector] Add patterns to simplify chained reductions (PR #73048)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 21 14:41:08 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
Chained reductions get created during vector unrolling. These patterns simplify them into a series of adds followed by a final reduction.
This is preferred on GPU targets like SPIR-V/Vulkan where vector reduction gets lowered into subgroup operations that are generally more expensive than simple vector additions.
For now, only the `add` combining kind is handled.
---
Full diff: https://github.com/llvm/llvm-project/pull/73048.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+17)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+99)
- (added) mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir (+116)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+21)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f8221ba0ff283ce..e373e99fb35ec7d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -147,6 +147,23 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Patterns that fold chained vector reductions. These patterns assume that
+/// vector addition (e.g., `arith.addf` with vector operands) is cheaper than
+/// vector reduction.
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add> %x, %acc
+/// %b = vector.reduction <add> %y, %a
+/// ```
+/// is transformed into:
+/// ```
+/// %a = arith.addf %x, %y
+/// %b = vector.reduction <add> %a, %acc
+/// ```
+void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a20c8aeeb6f7108..bdb1e3815372318 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1402,6 +1402,98 @@ struct FoldArithExtIntoContractionOp
}
};
+/// Pattern to fold chained reductions into a series of vector additions and a
+/// final reduction. This form should require fewer subgroup operations.
+///
+/// ```mlir
+/// %a = vector.reduction <add> %x, %acc
+/// %b = vector.reduction <add> %y, %a
+/// ==>
+/// %a = arith.addf %x, %y
+/// %b = vector.reduction <add> %a, %acc
+/// ```
+struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ReductionOp op,
+ PatternRewriter &rewriter) const override {
+ // TODO: Handle other combining kinds.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return failure();
+
+ // Accumulator is optional.
+ Value acc = op.getAcc();
+ if (!acc)
+ return failure();
+
+ if (!acc.getType().isIntOrFloat())
+ return failure();
+
+ auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
+ if (!parentReduction)
+ return failure();
+
+ Location loc = op.getLoc();
+ Value vAdd;
+ if (isa<IntegerType>(acc.getType())) {
+ vAdd = rewriter.createOrFold<arith::AddIOp>(
+ loc, parentReduction.getVector(), op.getVector());
+ } else {
+ vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
+ op.getVector());
+ }
+ rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
+ parentReduction.getAcc());
+ return success();
+ }
+};
+
+/// Pattern to eliminate redundant zero-constants added to reduction operands.
+/// It's enough for there to be one initial zero value, so we can eliminate the
+/// extra ones that feed into `vector.reduction <add>`. These get created by the
+/// `ChainedReduction` pattern.
+///
+/// ```mlir
+/// %a = arith.addf %x, %zero
+/// %b = arith.addf %a, %y
+/// %c = vector.reduction <add> %b, %acc
+/// ==>
+/// %b = arith.addf %a, %y
+/// %c = vector.reduction <add> %b, %acc
+/// ```
+struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ReductionOp op,
+ PatternRewriter &rewriter) const override {
+ // TODO: Handle other reduction kinds and their identity values.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return failure();
+
+ Type elemType = op.getSourceVectorType().getElementType();
+ // The integer case should be handled by `arith.addi` folders, only check
+ // for floats here.
+ if (!isa<FloatType>(elemType))
+ return failure();
+
+ auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
+ if (!vAdd)
+ return failure();
+ auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
+ if (!addLhs)
+ return failure();
+
+ if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
+ return failure();
+
+ auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
+ vAdd.getRhs());
+ rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
+ op.getAcc());
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1467,6 +1559,13 @@ void mlir::vector::populateSinkVectorBroadcastPatterns(
patterns.getContext(), benefit);
}
+void mlir::vector::populateChainedVectorReductionFoldingPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<ChainedReduction>(patterns.getContext(), benefit);
+ patterns.add<ReduceRedundantZero>(patterns.getContext(),
+ PatternBenefit(benefit.getBenefit() + 1));
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir b/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir
new file mode 100644
index 000000000000000..699a8fefd68ca38
--- /dev/null
+++ b/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --test-vector-chained-reduction-folding-patterns | FileCheck %s
+
+// CHECK-LABEL: func.func @reduce_1x_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>) -> f32 {
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0
+// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NEXT: return %[[RES]] : f32
+func.func @reduce_1x_fp32(%arg0: vector<8xf32>) -> f32 {
+ %cst0 = arith.constant 0.0 : f32
+ %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0
+// CHECK-DAG: %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
+// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NEXT: return %[[RES]] : f32
+func.func @reduce_2x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
+ %cst0 = arith.constant 0.0 : f32
+ %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
+ %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_no_acc_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
+// CHECK: %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
+// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
+// CHECK-NEXT: return %[[RES]] : f32
+func.func @reduce_2x_no_acc_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
+ %0 = vector.reduction <add>, %arg0 : vector<8xf32> into f32
+ %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_zero_add_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
+// CHECK: %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
+// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
+// CHECK-NEXT: return %[[RES]] : f32
+func.func @reduce_2x_zero_add_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
+ %cst0 = arith.constant dense<0.0> : vector<8xf32>
+ %x = arith.addf %arg0, %cst0 : vector<8xf32>
+ %0 = vector.reduction <add>, %x : vector<8xf32> into f32
+ %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: func.func @reduce_3x_fp32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>,
+// CHECK-SAME: %[[ARG2:.+]]: vector<8xf32>) -> f32 {
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0
+// CHECK-DAG: %[[ADD0:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : vector<8xf32>
+// CHECK-DAG: %[[ADD1:.+]] = arith.addf %[[ARG0]], %[[ADD0]] : vector<8xf32>
+// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD1]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NEXT: return %[[RES]] : f32
+func.func @reduce_3x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>,
+ %arg2: vector<8xf32>) -> f32 {
+ %cst0 = arith.constant 0.0 : f32
+ %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
+ %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+ %2 = vector.reduction <add>, %arg2, %1 : vector<8xf32> into f32
+ return %2 : f32
+}
+
+// CHECK-LABEL: func.func @reduce_1x_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<8xi32>) -> i32 {
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0
+// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NEXT: return %[[RES]] : i32
+func.func @reduce_1x_i32(%arg0: vector<8xi32>) -> i32 {
+ %cst0 = arith.constant 0 : i32
+ %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0
+// CHECK-DAG: %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
+// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NEXT: return %[[RES]] : i32
+func.func @reduce_2x_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
+ %cst0 = arith.constant 0 : i32
+ %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
+ %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
+ return %1 : i32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_no_acc_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
+// CHECK: %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
+// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xi32> into i32
+// CHECK-NEXT: return %[[RES]] : i32
+func.func @reduce_2x_no_acc_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
+ %0 = vector.reduction <add>, %arg0 : vector<8xi32> into i32
+ %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
+ return %1 : i32
+}
+
+// CHECK-LABEL: func.func @reduce_2x_zero_add_i32(
+// CHECK-SAME: %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0
+// CHECK-DAG: %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
+// CHECK-NEXT: %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NEXT: return %[[RES]] : i32
+func.func @reduce_2x_zero_add_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
+ %cst0 = arith.constant 0 : i32
+ %cstV = arith.constant dense<0> : vector<8xi32>
+ %x = arith.addi %arg0, %cstV : vector<8xi32>
+ %0 = vector.reduction <add>, %x, %cst0 : vector<8xi32> into i32
+ %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
+ return %1 : i32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 1a177fa31de37ce..feb716cdbf404eb 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -420,6 +420,25 @@ struct TestVectorReduceToContractPatternsPatterns
}
};
+struct TestVectorChainedReductionFoldingPatterns
+ : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorChainedReductionFoldingPatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-chained-reduction-folding-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns to fold chained vector reductions";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateChainedVectorReductionFoldingPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
@@ -773,6 +792,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
+ PassRegistration<TestVectorChainedReductionFoldingPatterns>();
+
PassRegistration<TestFlattenVectorTransferPatterns>();
PassRegistration<TestVectorScanLowering>();
``````````
</details>
https://github.com/llvm/llvm-project/pull/73048
More information about the Mlir-commits
mailing list