[Mlir-commits] [mlir] [mlir][vector] Add patterns to simplify chained reductions (PR #73048)

Jakub Kuderski llvmlistbot at llvm.org
Tue Nov 21 18:46:40 PST 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/73048

>From 24b3833415132b4a1989fa04c5c367c35c6bad2a Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 21 Nov 2023 17:36:49 -0500
Subject: [PATCH 1/4] [mlir][vector] Add patterns to simplify chained
 reductions

Chained reductions get created during vector unrolling. These patterns
simplify them into a series of adds followed by a final reduction.

This is preferred on GPU targets like SPIR-V/Vulkan where vector
reduction gets lowered into subgroup operations that are generally more
expensive than simple vector additions.

For now, only the `add` combining kind is handled.
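
For illustration, a minimal before/after sketch of the rewrite (mirroring the
example in the new doc comments; the value names and vector type are
placeholders):

```mlir
// Before: each unrolled slice feeds its partial result into the next reduction.
%0 = vector.reduction <add>, %x, %acc : vector<8xf32> into f32
%1 = vector.reduction <add>, %y, %0 : vector<8xf32> into f32

// After: add the vectors elementwise first, then reduce once.
%sum = arith.addf %x, %y : vector<8xf32>
%res = vector.reduction <add>, %sum, %acc : vector<8xf32> into f32
```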
---
 .../Vector/Transforms/VectorRewritePatterns.h |  17 +++
 .../Vector/Transforms/VectorTransforms.cpp    |  99 +++++++++++++++
 .../chained-vector-reduction-folding.mlir     | 116 ++++++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  21 ++++
 4 files changed, 253 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f8221ba0ff283ce..e373e99fb35ec7d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -147,6 +147,23 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);
 
+/// Patterns that fold chained vector reductions. These patterns assume that
+/// vector addition (e.g., `arith.addf` with vector operands) is cheaper than
+/// vector reduction.
+///
+/// Example:
+/// ```
+/// %a = vector.reduction <add> %a, %acc
+/// %b = vector.reduction <add> %b, %a
+/// ```
+/// is transformed into:
+/// ```
+/// %a = arith.addf %a, %b
+/// %b = vector.reduction <add> %a, %acc
+/// ```
+void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
+                                                   PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// [DecomposeDifferentRankInsertStridedSlice]
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a20c8aeeb6f7108..bdb1e3815372318 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1402,6 +1402,98 @@ struct FoldArithExtIntoContractionOp
   }
 };
 
+/// Pattern to fold chained to reduction to a series of vector additions and a
+/// final reduction. This form should require fewer subgroup operations.
+///
+/// ```mlir
+/// %a = vector.reduction <add> %x, %acc
+/// %b = vector.reduction <add> %y, %a
+///  ==>
+/// %a = arith.addf %x, %y
+/// %b = vector.reduction <add> %a, %acc
+/// ```
+struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Handle other combining kinds.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    // Accumulator is optional.
+    Value acc = op.getAcc();
+    if (!acc)
+      return failure();
+
+    if (!acc.getType().isIntOrFloat())
+      return failure();
+
+    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
+    if (!parentReduction)
+      return failure();
+
+    Location loc = op.getLoc();
+    Value vAdd;
+    if (isa<IntegerType>(acc.getType())) {
+      vAdd = rewriter.createOrFold<arith::AddIOp>(
+          loc, parentReduction.getVector(), op.getVector());
+    } else {
+      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
+                                            op.getVector());
+    }
+    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
+                                                     parentReduction.getAcc());
+    return success();
+  }
+};
+
+/// Pattern to eliminate redundant zero-constants added to reduction operands.
+/// It's enough for there to be one initial zero value, so we can eliminate the
+/// extra ones that feed into `vector.reduction <add>`. These get created by the
+/// `ChainedReduction` pattern.
+///
+/// ```mlir
+/// %a = arith.addf %x, %zero
+/// %b = arith.addf %a, %y
+/// %c = vector.reduction <add> %b, %acc
+///  ==>
+/// %b = arith.addf %x, %y
+/// %c = vector.reduction <add> %b, %acc
+/// ```
+struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Handle other reduction kinds and their identity values.
+    if (op.getKind() != vector::CombiningKind::ADD)
+      return failure();
+
+    Type elemType = op.getSourceVectorType().getElementType();
+    // The integer case should be handled by `arith.addi` folders, only check
+    // for floats here.
+    if (!isa<FloatType>(elemType))
+      return failure();
+
+    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
+    if (!vAdd)
+      return failure();
+    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
+    if (!addLhs)
+      return failure();
+
+    if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
+      return failure();
+
+    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
+                                                 vAdd.getRhs());
+    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
+                                                     op.getAcc());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1467,6 +1559,13 @@ void mlir::vector::populateSinkVectorBroadcastPatterns(
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateChainedVectorReductionFoldingPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
+  patterns.add<ReduceRedundantZero>(patterns.getContext(),
+                                    PatternBenefit(benefit.getBenefit() + 1));
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir b/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir
new file mode 100644
index 000000000000000..699a8fefd68ca38
--- /dev/null
+++ b/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --test-vector-chained-reduction-folding-patterns | FileCheck %s
+
+// CHECK-LABEL:   func.func @reduce_1x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>) -> f32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0.0
+// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NE          return %[[RES]] : f32
+func.func @reduce_1x_fp32(%arg0: vector<8xf32>) -> f32 {
+  %cst0 = arith.constant 0.0 : f32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0.0
+// CHECK-DAG:        %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
+// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NE          return %[[RES]] : f32
+func.func @reduce_2x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
+  %cst0 = arith.constant 0.0 : f32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+  return %1 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_no_acc_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
+// CHECK:            %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
+// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
+// CHECK-NE          return %[[RES]] : f32
+func.func @reduce_2x_no_acc_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<8xf32> into f32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+  return %1 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_zero_add_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
+// CHECK:            %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
+// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
+// CHECK-NE          return %[[RES]] : f32
+func.func @reduce_2x_zero_add_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
+  %cst0 = arith.constant dense<0.0> : vector<8xf32>
+  %x = arith.addf %arg0, %cst0 : vector<8xf32>
+  %0 = vector.reduction <add>, %x : vector<8xf32> into f32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+  return %1 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_3x_fp32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>,
+// CHECK-SAME:     %[[ARG2:.+]]: vector<8xf32>) -> f32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0.0
+// CHECK-DAG:        %[[ADD0:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : vector<8xf32>
+// CHECK-DAG:        %[[ADD1:.+]] = arith.addf %[[ARG0]], %[[ADD0]] : vector<8xf32>
+// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD1]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NE          return %[[RES]] : f32
+func.func @reduce_3x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>,
+                          %arg2: vector<8xf32>) -> f32 {
+  %cst0 = arith.constant 0.0 : f32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
+  %2 = vector.reduction <add>, %arg2, %1 : vector<8xf32> into f32
+  return %2 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>) -> i32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0
+// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NE          return %[[RES]] : i32
+func.func @reduce_1x_i32(%arg0: vector<8xi32>) -> i32 {
+  %cst0 = arith.constant 0 : i32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
+  return %0 : i32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0
+// CHECK-DAG:        %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
+// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NE          return %[[RES]] : i32
+func.func @reduce_2x_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
+  %cst0 = arith.constant 0 : i32
+  %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
+  return %1 : i32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_no_acc_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
+// CHECK:            %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
+// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xi32> into i32
+// CHECK-NE          return %[[RES]] : i32
+func.func @reduce_2x_no_acc_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
+  %0 = vector.reduction <add>, %arg0 : vector<8xi32> into i32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
+  return %1 : i32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_zero_add_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0
+// CHECK-DAG:        %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
+// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NE          return %[[RES]] : i32
+func.func @reduce_2x_zero_add_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
+  %cst0 = arith.constant 0 : i32
+  %cstV = arith.constant dense<0> : vector<8xi32>
+  %x = arith.addi %arg0, %cstV : vector<8xi32>
+  %0 = vector.reduction <add>, %x, %cst0 : vector<8xi32> into i32
+  %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
+  return %1 : i32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 1a177fa31de37ce..feb716cdbf404eb 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -420,6 +420,25 @@ struct TestVectorReduceToContractPatternsPatterns
   }
 };
 
+struct TestVectorChainedReductionFoldingPatterns
+    : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorChainedReductionFoldingPatterns)
+
+  StringRef getArgument() const final {
+    return "test-vector-chained-reduction-folding-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to fold chained vector reductions";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateChainedVectorReductionFoldingPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFlattenVectorTransferPatterns
     : public PassWrapper<TestFlattenVectorTransferPatterns,
                          OperationPass<func::FuncOp>> {
@@ -773,6 +792,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
 
+  PassRegistration<TestVectorChainedReductionFoldingPatterns>();
+
   PassRegistration<TestFlattenVectorTransferPatterns>();
 
   PassRegistration<TestVectorScanLowering>();

>From 55567b5e3c8e9ae63beb1746b5a6348ecfb2b7ba Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 21 Nov 2023 17:43:11 -0500
Subject: [PATCH 2/4] Fix typo

---
 .../chained-vector-reduction-folding.mlir     | 36 +++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir b/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir
index 699a8fefd68ca38..3048f53345e5cd2 100644
--- a/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir
+++ b/mlir/test/Dialect/Vector/chained-vector-reduction-folding.mlir
@@ -3,8 +3,8 @@
 // CHECK-LABEL:   func.func @reduce_1x_fp32(
 // CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>) -> f32 {
 // CHECK-DAG:        %[[CST:.+]] = arith.constant 0.0
-// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xf32> into f32
-// CHECK-NE          return %[[RES]] : f32
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
 func.func @reduce_1x_fp32(%arg0: vector<8xf32>) -> f32 {
   %cst0 = arith.constant 0.0 : f32
   %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
@@ -15,8 +15,8 @@ func.func @reduce_1x_fp32(%arg0: vector<8xf32>) -> f32 {
 // CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
 // CHECK-DAG:        %[[CST:.+]] = arith.constant 0.0
 // CHECK-DAG:        %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
-// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xf32> into f32
-// CHECK-NE          return %[[RES]] : f32
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
 func.func @reduce_2x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
   %cst0 = arith.constant 0.0 : f32
   %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xf32> into f32
@@ -27,8 +27,8 @@ func.func @reduce_2x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
 // CHECK-LABEL:   func.func @reduce_2x_no_acc_fp32(
 // CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
 // CHECK:            %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
-// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
-// CHECK-NE          return %[[RES]] : f32
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
 func.func @reduce_2x_no_acc_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
   %0 = vector.reduction <add>, %arg0 : vector<8xf32> into f32
   %1 = vector.reduction <add>, %arg1, %0 : vector<8xf32> into f32
@@ -38,8 +38,8 @@ func.func @reduce_2x_no_acc_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) ->
 // CHECK-LABEL:   func.func @reduce_2x_zero_add_fp32(
 // CHECK-SAME:     %[[ARG0:.+]]: vector<8xf32>, %[[ARG1:.+]]: vector<8xf32>) -> f32 {
 // CHECK:            %[[ADD:.+]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<8xf32>
-// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
-// CHECK-NE          return %[[RES]] : f32
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
 func.func @reduce_2x_zero_add_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -> f32 {
   %cst0 = arith.constant dense<0.0> : vector<8xf32>
   %x = arith.addf %arg0, %cst0 : vector<8xf32>
@@ -54,8 +54,8 @@ func.func @reduce_2x_zero_add_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>) -
 // CHECK-DAG:        %[[CST:.+]] = arith.constant 0.0
 // CHECK-DAG:        %[[ADD0:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : vector<8xf32>
 // CHECK-DAG:        %[[ADD1:.+]] = arith.addf %[[ARG0]], %[[ADD0]] : vector<8xf32>
-// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD1]], %[[CST]] : vector<8xf32> into f32
-// CHECK-NE          return %[[RES]] : f32
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD1]], %[[CST]] : vector<8xf32> into f32
+// CHECK-NEXT:       return %[[RES]] : f32
 func.func @reduce_3x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>,
                           %arg2: vector<8xf32>) -> f32 {
   %cst0 = arith.constant 0.0 : f32
@@ -68,8 +68,8 @@ func.func @reduce_3x_fp32(%arg0: vector<8xf32>, %arg1: vector<8xf32>,
 // CHECK-LABEL:   func.func @reduce_1x_i32(
 // CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>) -> i32 {
 // CHECK-DAG:        %[[CST:.+]] = arith.constant 0
-// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xi32> into i32
-// CHECK-NE          return %[[RES]] : i32
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ARG0]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NEXT:       return %[[RES]] : i32
 func.func @reduce_1x_i32(%arg0: vector<8xi32>) -> i32 {
   %cst0 = arith.constant 0 : i32
   %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
@@ -80,8 +80,8 @@ func.func @reduce_1x_i32(%arg0: vector<8xi32>) -> i32 {
 // CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
 // CHECK-DAG:        %[[CST:.+]] = arith.constant 0
 // CHECK-DAG:        %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
-// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
-// CHECK-NE          return %[[RES]] : i32
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NEXT:       return %[[RES]] : i32
 func.func @reduce_2x_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
   %cst0 = arith.constant 0 : i32
   %0 = vector.reduction <add>, %arg0, %cst0 : vector<8xi32> into i32
@@ -92,8 +92,8 @@ func.func @reduce_2x_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
 // CHECK-LABEL:   func.func @reduce_2x_no_acc_i32(
 // CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
 // CHECK:            %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
-// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xi32> into i32
-// CHECK-NE          return %[[RES]] : i32
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]] : vector<8xi32> into i32
+// CHECK-NEXT:       return %[[RES]] : i32
 func.func @reduce_2x_no_acc_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
   %0 = vector.reduction <add>, %arg0 : vector<8xi32> into i32
   %1 = vector.reduction <add>, %arg1, %0 : vector<8xi32> into i32
@@ -104,8 +104,8 @@ func.func @reduce_2x_no_acc_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i
 // CHECK-SAME:     %[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>) -> i32 {
 // CHECK-DAG:        %[[CST:.+]] = arith.constant 0
 // CHECK-DAG:        %[[ADD:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : vector<8xi32>
-// CHECK-NEXT        %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
-// CHECK-NE          return %[[RES]] : i32
+// CHECK-NEXT:       %[[RES:.+]] = vector.reduction <add>, %[[ADD]], %[[CST]] : vector<8xi32> into i32
+// CHECK-NEXT:       return %[[RES]] : i32
 func.func @reduce_2x_zero_add_i32(%arg0: vector<8xi32>, %arg1: vector<8xi32>) -> i32 {
   %cst0 = arith.constant 0 : i32
   %cstV = arith.constant dense<0> : vector<8xi32>

>From 83c2bd708aa5d33f3c98231f2e81013f4509a5eb Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 21 Nov 2023 17:45:38 -0500
Subject: [PATCH 3/4] Fix typo

---
 .../mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h  | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index e373e99fb35ec7d..5abb48f9be876e2 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -153,12 +153,12 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
 ///
 /// Example:
 /// ```
-/// %a = vector.reduction <add> %a, %acc
-/// %b = vector.reduction <add> %b, %a
+/// %a = vector.reduction <add> %x, %acc
+/// %b = vector.reduction <add> %y, %a
 /// ```
 /// is transformed into:
 /// ```
-/// %a = arith.addf %a, %b
+/// %a = arith.addf %x, %y
 /// %b = vector.reduction <add> %a, %acc
 /// ```
 void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,

>From 26c782553c63b828696f3bb0096ed241bece07b8 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 21 Nov 2023 21:46:28 -0500
Subject: [PATCH 4/4] Improve comments

---
 .../mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h  | 6 ++++--
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp     | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 5abb48f9be876e2..08c08172d0531e4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -148,8 +148,10 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);
 
 /// Patterns that fold chained vector reductions. These patterns assume that
-/// vector addition (e.g., `arith.addf` with vector operands) is cheaper than
-/// vector reduction.
+/// elementwise operations (e.g., `arith.addf` with vector operands) are
+/// cheaper than vector reduction.
+/// Note that these patterns change the order of reduction which may not always
+/// produce bit-identical results on some floating point inputs.
 ///
 /// Example:
 /// ```
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bdb1e3815372318..582d627d1ce4ac0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1402,7 +1402,7 @@ struct FoldArithExtIntoContractionOp
   }
 };
 
-/// Pattern to fold chained to reduction to a series of vector additions and a
+/// Pattern to fold chained reduction to a series of vector additions and a
 /// final reduction. This form should require fewer subgroup operations.
 ///
 /// ```mlir


