[Mlir-commits] [mlir] d33bad6 - [mlir][vector] Add patterns to simplify chained reductions (#73048)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 22 07:30:08 PST 2023


Author: Jakub Kuderski
Date: 2023-11-22T10:30:04-05:00
New Revision: d33bad66d86a6fdb443c59561f9524f451a82db0

URL: https://github.com/llvm/llvm-project/commit/d33bad66d86a6fdb443c59561f9524f451a82db0
DIFF: https://github.com/llvm/llvm-project/commit/d33bad66d86a6fdb443c59561f9524f451a82db0.diff

LOG: [mlir][vector] Add patterns to simplify chained reductions (#73048)

Chained reductions get created during vector unrolling. These patterns
simplify them into a series of adds followed by a final reduction.

This is preferred on GPU targets like SPIR-V/Vulkan where vector
reduction gets lowered into subgroup operations that are generally more
expensive than simple vector additions.

For now, only the `add` combining kind is handled.

Added: 
    mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f8221ba0ff283ce..08c08172d0531e4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -147,6 +147,25 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);
 
+/// Patterns that fold chained vector reductions. These patterns assume that
+/// elementwise operations (e.g., `arith.addf` with vector operands) are
+/// cheaper than vector reduction.
+/// Note that these patterns change the order of reduction which may not always
+/// produce bit-identical results on some floating point inputs.
+/// For now, only reductions with the `add` combining kind are folded.
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add>, %x, %acc : vector<8xf32> into f32
+/// %b = vector.reduction <add>, %y, %a : vector<8xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %sum = arith.addf %x, %y : vector<8xf32>
+/// %b = vector.reduction <add>, %sum, %acc : vector<8xf32> into f32
+/// ```
+///
+/// `benefit` is used for the main folding pattern; a helper pattern that drops
+/// redundant zero additions feeding the folded reduction is registered with a
+/// benefit one higher.
+void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// [DecomposeDifferentRankInsertStridedSlice]

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a20c8aeeb6f7108..582d627d1ce4ac0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1402,6 +1402,98 @@ struct FoldArithExtIntoContractionOp
   }
 };
 
+/// Pattern to fold chained reduction to a series of vector additions and a
+/// final reduction. This form should require fewer subgroup operations.
+///
+/// ```mlir
+/// %a = vector.reduction <add>, %x, %acc
+/// %b = vector.reduction <add>, %y, %a
+///  ==>
+/// %sum = arith.addf %x, %y
+/// %b = vector.reduction <add>, %sum, %acc
+/// ```
+///
+/// Both the matched reduction and the reduction producing its accumulator must
+/// use the `add` combining kind and reduce vectors of the same type, so that
+/// their source vectors can be combined with a single elementwise add.
+struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Handle other combining kinds.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Accumulator is optional.
+    Value acc = op.getAcc();
+    if (!acc)
+      return failure();
+
+    if (!acc.getType().isIntOrFloat())
+      return failure();
+
+    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
+    if (!parentReduction)
+      return failure();
+
+    // The parent must be an `add` reduction as well: folding e.g. a `mul` or
+    // `maxf` reduction into an elementwise add would change the result.
+    if (parentReduction.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Elementwise arith ops require identical operand types, so reductions of
+    // differently-shaped vectors into the same scalar cannot be combined.
+    if (parentReduction.getSourceVectorType() != op.getSourceVectorType())
+      return failure();
+
+    Location loc = op.getLoc();
+    Value vAdd;
+    if (isa<IntegerType>(acc.getType())) {
+      vAdd = rewriter.createOrFold<arith::AddIOp>(
+          loc, parentReduction.getVector(), op.getVector());
+    } else {
+      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
+                                            op.getVector());
+    }
+    // The parent's (possibly null) accumulator becomes the new accumulator.
+    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
+                                                     parentReduction.getAcc());
+    return success();
+  }
+};
+
+/// Pattern to eliminate redundant zero-constants added to reduction operands.
+/// It's enough for there to be one initial zero value, so we can eliminate the
+/// extra ones that feed into `vector.reduction <add>`. This shape is exposed
+/// by the `ChainedReduction` pattern, which combines the source vectors of
+/// chained reductions with `arith.addf`.
+///
+/// ```mlir
+/// %a = arith.addf %x, %zero
+/// %b = arith.addf %a, %y
+/// %c = vector.reduction <add>, %b, %acc
+///  ==>
+/// %b = arith.addf %x, %y
+/// %c = vector.reduction <add>, %b, %acc
+/// ```
+///
+/// Only a zero on the right-hand side of the inner `arith.addf` is matched,
+/// and that inner add must be the left-hand operand of the add feeding the
+/// reduction. The fast-math flags of the outer add are preserved.
+struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Handle other reduction kinds and their identity values.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    Type elemType = op.getSourceVectorType().getElementType();
+    // The integer case should be handled by `arith.addi` folders, only check
+    // for floats here.
+    if (!isa<FloatType>(elemType))
+      return failure();
+
+    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
+    if (!vAdd)
+      return failure();
+    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
+    if (!addLhs)
+      return failure();
+
+    if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
+      return failure();
+
+    // The new add stands in for `vAdd`, so keep its fast-math flags rather
+    // than silently resetting them to `none`.
+    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
+                                                 vAdd.getRhs(),
+                                                 vAdd.getFastmathAttr());
+    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
+                                                     op.getAcc());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1467,6 +1559,13 @@ void mlir::vector::populateSinkVectorBroadcastPatterns(
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateChainedVectorReductionFoldingPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  MLIRContext *context = patterns.getContext();
+  // Both patterns root on `vector.reduction`. Registering the zero-elimination
+  // pattern with a strictly higher benefit makes it the one tried first when
+  // both could apply to the same op.
+  PatternBenefit zeroElimBenefit(benefit.getBenefit() + 1);
+  patterns.add<ReduceRedundantZero>(context, zeroElimBenefit);
+  patterns.add<ChainedReduction>(context, benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir b/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir
new file mode 100644
index 000000000000000..3048f53345e5cd2
--- /dev/null
+++ b/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --test-vector-chained-reduction-folding-patterns | FileCheck %s
+
+// A lone f32 reduction has no reduction producing its accumulator, so the
+// patterns must leave it unchanged.
+// CHECK-LABEL:   func.func @reduce_1x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>) -> f32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0.0
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
+func.func @reduce_1x_fp32(%arg0: vector<8xf32>) -> f32 {
+  %cst0 = arith.constant 0.0 : f32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
+  return %0 : f32
+}
+
+// Two chained f32 add-reductions become one elementwise arith.addf of the
+// source vectors, reduced once with the first reduction's accumulator.
+// CHECK-LABEL:   func.func @reduce_2x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0.0
+// CHECK-DAG:        %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
+func.func @reduce_2x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
+  %cst0 = arith.constant 0.0 : f32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+  return %1 : f32
+}
+
+// The first reduction has no accumulator, so the folded reduction is also
+// emitted without one.
+// CHECK-LABEL:   func.func @reduce_2x_no_acc_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
+// CHECK:            %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
+func.func @reduce_2x_no_acc_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<8xf32> into f32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+  return %1 : f32
+}
+
+// After the chain is folded, the addf of a zero vector that feeds it is
+// removed by the redundant-zero pattern, leaving a single addf of the inputs.
+// CHECK-LABEL:   func.func @reduce_2x_zero_add_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
+// CHECK:            %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
+func.func @reduce_2x_zero_add_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
+  %cst0 = arith.constant dense<0.0> : vector<8xf32>
+  %x = arith.addf %arg0, %cst0 : vector<8xf32>
+  %0 = vector.reduction <add>, %x : vector<8xf32> into f32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+  return %1 : f32
+}
+
+// A chain of three reductions collapses into two elementwise adds and a
+// single reduction using the innermost accumulator.
+// CHECK-LABEL:   func.func @reduce_3x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>,
+// CHECK-SAME:     %[[ARG2:.+]]: vector<8xf32>) -> f32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0.0
+// CHECK-DAG:        %[[ADD0:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : vector<8xf32>
+// CHECK-DAG:        %[[ADD1:.+]] = arith.addf %[[ARG0]], %[[ADD0]] : vector<8xf32>
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD1]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
+func.func @reduce_3x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>,
+                          %arg2: vector<8xf32>) -> f32 {
+  %cst0 = arith.constant 0.0 : f32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+  %2 = vector.reduction <add>, %arg2, %1 : vector<8xf32> into f32
+  return %2 : f32
+}
+
+// Integer counterpart of reduce_1x_fp32: a lone reduction is left unchanged.
+// CHECK-LABEL:   func.func @reduce_1x_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>) -> i32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NEXT:       return %[[RES]] : i32
+func.func @reduce_1x_i32(%arg0: vector<8xi32>) -> i32 {
+  %cst0 = arith.constant 0 : i32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
+  return %0 : i32
+}
+
+// Integer chains are combined with arith.addi instead of arith.addf.
+// CHECK-LABEL:   func.func @reduce_2x_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0
+// CHECK-DAG:        %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NEXT:       return %[[RES]] : i32
+func.func @reduce_2x_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
+  %cst0 = arith.constant 0 : i32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
+  return %1 : i32
+}
+
+// Integer chain whose first reduction has no accumulator; the folded
+// reduction is emitted without one.
+// CHECK-LABEL:   func.func @reduce_2x_no_acc_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
+// CHECK:            %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xi32> into i32
+// CHECK-NEXT:       return %[[RES]] : i32
+func.func @reduce_2x_no_acc_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
+  %0 = vector.reduction <add>, %arg0 : vector<8xi32> into i32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
+  return %1 : i32
+}
+
+// The integer add of a zero vector is expected to disappear through
+// arith.addi folding (the redundant-zero pattern only handles floats).
+// CHECK-LABEL:   func.func @reduce_2x_zero_add_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0
+// CHECK-DAG:        %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NEXT:       return %[[RES]] : i32
+func.func @reduce_2x_zero_add_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
+  %cst0 = arith.constant 0 : i32
+  %cstV = arith.constant dense<0> : vector<8xi32>
+  %x = arith.addi %arg0, %cstV : vector<8xi32>
+  %0 = vector.reduction <add>, %x, %cst0 : vector<8xi32> into i32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
+  return %1 : i32
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 1a177fa31de37ce..feb716cdbf404eb 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -420,6 +420,25 @@ struct TestVectorReduceToContractPatternsPatterns
   }
 };
 
+/// Test pass that runs the chained vector reduction folding patterns over a
+/// `func.func`.
+struct TestVectorChainedReductionFoldingPatterns
+    : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorChainedReductionFoldingPatterns)
+
+  /// Command-line name under which this pass is registered.
+  StringRef getArgument() const final {
+    return "test-vector-chained-reduction-folding-patterns";
+  }
+
+  /// Human-readable summary of the pass.
+  StringRef getDescription() const final {
+    return "Test patterns to fold chained vector reductions";
+  }
+
+  /// Greedily applies the folding patterns. Whether the greedy driver
+  /// converged is intentionally ignored.
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet foldingPatterns(ctx);
+    populateChainedVectorReductionFoldingPatterns(foldingPatterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(),
+                                       std::move(foldingPatterns));
+  }
+};
+
 struct TestFlattenVectorTransferPatterns
     : public PassWrapper<TestFlattenVectorTransferPatterns,
                          OperationPass<func::FuncOp>> {
@@ -773,6 +792,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
 
+  PassRegistration<TestVectorChainedReductionFoldingPatterns>();
+
   PassRegistration<TestFlattenVectorTransferPatterns>();
 
   PassRegistration<TestVectorScanLowering>();


        


More information about the Mlir-commits mailing list