[Mlir-commits] [mlir] [mlir][vector] Add pattern to break down reductions into arith ops (PR #75727)

Jakub Kuderski llvmlistbot at llvm.org
Mon Dec 18 14:49:18 PST 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/75727

>From 111ead93642eeb10feb7454129f1754651d44e17 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 16 Dec 2023 21:05:48 -0500
Subject: [PATCH] [mlir][vector] Add pattern to break down small reductions
 into arith ops

The number of vector elements considered 'small' enough to extract is
parameterized.

This avoids going through specialized reduction lowering when one or two
arith ops suffice. Targets without dedicated reduction intrinsics can also
use this as an emulation path.

Depends on https://github.com/llvm/llvm-project/pull/75846.
---
 .../Vector/Transforms/VectorRewritePatterns.h |  19 +++
 .../Vector/Transforms/VectorTransforms.cpp    |  63 +++++++++
 .../Vector/break-down-vector-reduction.mlir   | 126 ++++++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  23 ++++
 4 files changed, 231 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/break-down-vector-reduction.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 17173c01ab762a..49b74c0c466d2f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -166,6 +166,25 @@ void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
 void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit = 1);
 
+/// Patterns to break down vector reductions into a series of arith reductions
+/// over vector elements. This is intended to simplify code with reductions
+/// over small vectors (up to `maxNumElementsToExtract` elements) and to avoid
+/// more specialized reduction lowering when possible. Masked and scalable
+/// reductions are left unchanged.
+/// Example:
+/// ```
+/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+void populateBreakDownVectorReductionPatterns(
+    RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
+    PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// [DecomposeDifferentRankInsertStridedSlice]
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5936b0b54af4e3..661674dd74c0cd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
+#include <cassert>
 #include <cstdint>
 #include <functional>
 #include <optional>
@@ -44,6 +45,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "vector-to-vector"
@@ -1578,6 +1580,60 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Breaks a small 1-D `vector.reduction` into extracts + arith ops. Example:
+/// ```
+/// %a = vector.reduction <add>, %x : vector<2xf32> into f32
+/// ```
+/// is transformed into:
+/// ```
+/// %y = vector.extract %x[0] : f32 from vector<2xf32>
+/// %z = vector.extract %x[1] : f32 from vector<2xf32>
+/// %a = arith.addf %y, %z : f32
+/// ```
+struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
+  BreakDownVectorReduction(MLIRContext *context,
+                           unsigned maxNumElementsToExtract,
+                           PatternBenefit benefit)
+      : OpRewritePattern(context, benefit),
+        maxNumElementsToExtract(maxNumElementsToExtract) {}
+
+  LogicalResult matchAndRewrite(vector::ReductionOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType type = op.getSourceVectorType();
+    // Handle only unmasked, fixed-size 1-D sources; 0-D ones are rejected too.
+    if (type.getRank() != 1 || type.isScalable() || op.isMasked())
+      return failure();
+
+    int64_t numElems = type.getNumElements();
+    if (numElems > maxNumElementsToExtract) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("has too many vector elements ({0}) to break down "
+                            "(max allowed: {1})",
+                            numElems, maxNumElementsToExtract));
+    }
+    // Extract each element of the 1-D source vector.
+    Location loc = op.getLoc();
+    SmallVector<Value> extracted(numElems, nullptr);
+    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
+      extractedElem = rewriter.create<vector::ExtractOp>(
+          loc, op.getVector(), static_cast<int64_t>(idx));
+    // Fold elements left-to-right; an accumulator, if any, is folded in last.
+    Value res = extracted.front();
+    for (auto extractedElem : llvm::drop_begin(extracted))
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
+                                       extractedElem, op.getFastmathAttr());
+    if (Value acc = op.getAcc())
+      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
+                                       op.getFastmathAttr());
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  unsigned maxNumElementsToExtract = 0;
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1656,6 +1712,13 @@ void mlir::vector::populateChainedVectorReductionFoldingPatterns(
                                     PatternBenefit(benefit.getBenefit() + 1));
 }
 
+void mlir::vector::populateBreakDownVectorReductionPatterns(
+    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
+    PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<BreakDownVectorReduction>(ctx, maxNumElementsToExtract, benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
new file mode 100644
index 00000000000000..34234591b79cab
--- /dev/null
+++ b/mlir/test/Dialect/Vector/break-down-vector-reduction.mlir
@@ -0,0 +1,126 @@
+// RUN: mlir-opt %s --test-vector-break-down-reduction-patterns --cse | FileCheck %s
+
+// NOTE: This test pass is set to break down vector reductions of size 2 or fewer.
+
+// CHECK-LABEL:   func.func @reduce_2x_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[R0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R1:.+]] = arith.mulf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R2:.+]] = arith.minnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R3:.+]] = arith.maxnumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R4:.+]] = arith.minimumf %[[E0]], %[[E1]] : f32
+// CHECK-DAG:      %[[R5:.+]] = arith.maximumf %[[E0]], %[[E1]] : f32
+// CHECK:          return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]]
+func.func @reduce_2x_f32(%arg0: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %0 = vector.reduction <add>, %arg0 : vector<2xf32> into f32
+  %1 = vector.reduction <mul>, %arg0 : vector<2xf32> into f32
+  %2 = vector.reduction <minf>, %arg0 : vector<2xf32> into f32
+  %3 = vector.reduction <maxf>, %arg0 : vector<2xf32> into f32
+  %4 = vector.reduction <minimumf>, %arg0 : vector<2xf32> into f32
+  %5 = vector.reduction <maximumf>, %arg0 : vector<2xf32> into f32
+  return %0, %1, %2, %3, %4, %5 : f32, f32, f32, f32, f32, f32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<2xi32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : i32 from vector<2xi32>
+// CHECK-DAG:      %[[R0:.+]] = arith.addi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R1:.+]] = arith.muli %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R2:.+]] = arith.minsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R3:.+]] = arith.maxsi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R4:.+]] = arith.minui %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R5:.+]] = arith.maxui %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R6:.+]] = arith.andi %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R7:.+]] = arith.ori %[[E0]], %[[E1]] : i32
+// CHECK-DAG:      %[[R8:.+]] = arith.xori %[[E0]], %[[E1]] : i32
+// CHECK:          return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]]
+func.func @reduce_2x_i32(%arg0: vector<2xi32>) -> (i32, i32, i32, i32, i32, i32, i32, i32, i32) {
+  %0 = vector.reduction <add>, %arg0 : vector<2xi32> into i32
+  %1 = vector.reduction <mul>, %arg0 : vector<2xi32> into i32
+  %2 = vector.reduction <minsi>, %arg0 : vector<2xi32> into i32
+  %3 = vector.reduction <maxsi>, %arg0 : vector<2xi32> into i32
+  %4 = vector.reduction <minui>, %arg0 : vector<2xi32> into i32
+  %5 = vector.reduction <maxui>, %arg0 : vector<2xi32> into i32
+  %6 = vector.reduction <and>, %arg0 : vector<2xi32> into i32
+  %7 = vector.reduction <or>, %arg0 : vector<2xi32> into i32
+  %8 = vector.reduction <xor>, %arg0 : vector<2xi32> into i32
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8 : i32, i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xf32>) -> f32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     return %[[E0]] : f32
+func.func @reduce_1x_f32(%arg0: vector<1xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<1xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_acc_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: f32) -> f32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+// CHECK-NEXT:     %[[R0:.+]] = arith.addf %[[E0]], %[[ARG1]] : f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_1x_acc_f32(%arg0: vector<1xf32>, %arg1: f32) -> f32 {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xf32> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @reduce_1x_acc_i32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: i32) -> i32 {
+// CHECK-NEXT:     %[[E0:.+]] = vector.extract %[[ARG0]][0] : i32 from vector<1xi32>
+// CHECK-NEXT:     %[[R0:.+]] = arith.addi %[[E0]], %[[ARG1]] : i32
+// CHECK-NEXT:     return %[[R0]] : i32
+func.func @reduce_1x_acc_i32(%arg0: vector<1xi32>, %arg1: i32) -> i32 {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<1xi32> into i32
+  return %0 : i32
+}
+
+// CHECK-LABEL:   func.func @reduce_2x_acc_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: f32) -> (f32, f32) {
+// CHECK-DAG:      %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<2xf32>
+// CHECK-DAG:      %[[E1:.+]] = vector.extract %[[ARG0]][1] : f32 from vector<2xf32>
+// CHECK:          %[[A0:.+]] = arith.addf %[[E0]], %[[E1]] : f32
+// CHECK:          %[[R0:.+]] = arith.addf %[[A0]], %[[ARG1]] : f32
+// CHECK:          %[[M0:.+]] = arith.mulf %[[E0]], %[[E1]] fastmath<nnan> : f32
+// CHECK:          %[[R1:.+]] = arith.mulf %[[M0]], %[[ARG1]] fastmath<nnan> : f32
+// CHECK-NEXT:     return %[[R0]], %[[R1]] : f32, f32
+func.func @reduce_2x_acc_f32(%arg0: vector<2xf32>, %arg1: f32) -> (f32, f32) {
+  %0 = vector.reduction <add>, %arg0, %arg1 : vector<2xf32> into f32
+  %1 = vector.reduction <mul>, %arg0, %arg1 fastmath<nnan> : vector<2xf32> into f32
+  return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL:   func.func @reduce_3x_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<3xf32>) -> f32 {
+// CHECK-NEXT:     %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<3xf32> into f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_3x_f32(%arg0: vector<3xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<3xf32> into f32
+  return %0 : f32
+}
+
+// Masking is not handled yet.
+// CHECK-LABEL:   func.func @reduce_mask_3x_f32
+// CHECK-NEXT:     %[[M:.+]] = vector.create_mask
+// CHECK-NEXT:     %[[R:.+]] = vector.mask %[[M]]
+// CHECK-SAME:       vector.reduction <add>
+// CHECK-NEXT:     return %[[R]] : f32
+func.func @reduce_mask_3x_f32(%arg0: vector<3xf32>, %arg1: index) -> f32 {
+  %mask = vector.create_mask %arg1 : vector<3xi1>
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<3xf32> into f32 } : vector<3xi1> -> f32
+  return %0 : f32
+}
+
+// Scalable vectors are not supported.
+// CHECK-LABEL:   func.func @reduce_scalable_f32(
+// CHECK-SAME:     %[[ARG0:.+]]: vector<[1]xf32>) -> f32 {
+// CHECK-NEXT:     %[[R0:.+]] = vector.reduction <add>, %[[ARG0]] : vector<[1]xf32> into f32
+// CHECK-NEXT:     return %[[R0]] : f32
+func.func @reduce_scalable_f32(%arg0: vector<[1]xf32>) -> f32 {
+  %0 = vector.reduction <add>, %arg0 : vector<[1]xf32> into f32
+  return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 03ddebe82344d8..126d65b1b8487f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -439,6 +439,27 @@ struct TestVectorChainedReductionFoldingPatterns
   }
 };
 
+/// Test-only pass: breaks down vector reductions over at most 2 elements.
+struct TestVectorBreakDownReductionPatterns
+    : public PassWrapper<TestVectorBreakDownReductionPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorBreakDownReductionPatterns)
+  StringRef getDescription() const final {
+    return "Test patterns to break down vector reductions "
+           "into arith reductions";
+  }
+  StringRef getArgument() const final {
+    return "test-vector-break-down-reduction-patterns";
+  }
+  void runOnOperation() override {
+    const unsigned maxNumElementsToExtract = 2;
+    RewritePatternSet patterns(&getContext());
+    populateBreakDownVectorReductionPatterns(patterns, maxNumElementsToExtract);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFlattenVectorTransferPatterns
     : public PassWrapper<TestFlattenVectorTransferPatterns,
                          OperationPass<func::FuncOp>> {
@@ -827,6 +848,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorChainedReductionFoldingPatterns>();
 
+  PassRegistration<TestVectorBreakDownReductionPatterns>();
+
   PassRegistration<TestFlattenVectorTransferPatterns>();
 
   PassRegistration<TestVectorScanLowering>();



More information about the Mlir-commits mailing list