[Mlir-commits] [mlir] 8d6469b - [mlir][vector] Add lower-vector-multi-reduction pass (#87333)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 9 10:04:29 PDT 2024


Author: xiaoleis-nv
Date: 2024-04-09T10:04:25-07:00
New Revision: 8d6469b0e02c4c7bd1d496972c63ed3e2de0e077

URL: https://github.com/llvm/llvm-project/commit/8d6469b0e02c4c7bd1d496972c63ed3e2de0e077
DIFF: https://github.com/llvm/llvm-project/commit/8d6469b0e02c4c7bd1d496972c63ed3e2de0e077.diff

LOG: [mlir][vector] Add lower-vector-multi-reduction pass (#87333)

This MR adds the `lower-vector-multi-reduction` pass to lower the
`vector.multi_reduction` operation.

While the Transform Dialect includes an operation,
`transform.apply_patterns.vector.lower_multi_reduction`, intended for a
similar purpose, its utility is limited to projects that have adopted
the Transform Dialect. Recognizing that not all projects are equipped to
integrate this dialect, the proposed pass serves as a vital standalone
alternative. It ensures that projects solely dependent on the
traditional pass infrastructure can also benefit from the optimized
lowering of the `vector.multi_reduction` operation.

---------

Co-authored-by: Xiaolei Shi <xiaoleis at nvidia.com>

Added: 
    mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
    mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index bf89b01e2b60c5..911402551e14d4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -22,6 +23,11 @@ std::unique_ptr<Pass> createVectorBufferizePass();
 /// Creates an instance of the `vector.mask` lowering pass.
 std::unique_ptr<Pass> createLowerVectorMaskPass();
 
+/// Creates an instance of the `vector.multi_reduction` lowering pass.
+std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
+    VectorMultiReductionLowering option =
+        VectorMultiReductionLowering::InnerParallel);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..31a0b3b2f0c53d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,22 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
 }
 
+def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::FuncOp"> {
+  let summary = "Lower 'vector.multi_reduction' operations";
+  let constructor = "mlir::vector::createLowerVectorMultiReductionPass()";
+  let options = [
+    Option<"loweringStrategy", "lowering-strategy", "mlir::vector::VectorMultiReductionLowering",
+           /*default=*/"mlir::vector::VectorMultiReductionLowering::InnerParallel",
+           "Select the strategy to control how multi_reduction is lowered.",
+           [{::llvm::cl::values(
+            clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerParallel,
+                       "inner-parallel",
+                       "Lower multi_reduction into outer-reduction and inner-parallel ops."),
+            clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerReduction,
+                       "inner-reduction",
+                       "Lower multi_reduction into outer-parallel and inner-reduction ops.")
+        )}]>
+  ];
+}
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index bed2c2496719dd..2f21c50c63473b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -12,9 +12,19 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
 
 #define DEBUG_TYPE "vector-multi-reduction"
 
@@ -461,6 +471,31 @@ struct OneDimMultiReductionToTwoDim
     return success();
   }
 };
+
+struct LowerVectorMultiReductionPass
+    : public vector::impl::LowerVectorMultiReductionBase<
+          LowerVectorMultiReductionPass> {
+  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
+    this->loweringStrategy = option;
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+
+    RewritePatternSet loweringPatterns(context);
+    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
+                                                 this->loweringStrategy);
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
@@ -476,3 +511,8 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                     benefit);
 }
+
+std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
+    vector::VectorMultiReductionLowering option) {
+  return std::make_unique<LowerVectorMultiReductionPass>(option);
+}

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
new file mode 100644
index 00000000000000..4cb6fba9b691a6
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-REDUCTION
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-PARALLEL
+// RUN: mlir-opt -lower-vector-multi-reduction -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-PARALLEL
+
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+//           ALL-LABEL: func @vector_multi_reduction
+//            ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
+// INNER-REDUCTION-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
+// INNER-REDUCTION-DAG: %[[C0:.+]] = arith.constant 0 : index
+// INNER-REDUCTION-DAG: %[[C1:.+]] = arith.constant 1 : index
+//     INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0]
+//     INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+//     INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
+//     INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+//     INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+//     INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+//     INNER-REDUCTION: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
+//     INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+//     INNER-REDUCTION: return %[[RESULT_VEC]]
+
+//      INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//      INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+//      INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+//      INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+//      INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+//      INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+//      INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
+//      INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+//      INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
+//      INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>
+
+// -----
+
+func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+//       ALL-LABEL: func @vector_multi_reduction_parallel_middle
+//        ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+// INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
+//  INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>


        


More information about the Mlir-commits mailing list