[Mlir-commits] [mlir] Add lower-vector-multi-reduction pass (PR #87333)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 2 03:56:50 PDT 2024


https://github.com/xiaoleis-nv created https://github.com/llvm/llvm-project/pull/87333

This PR adds the `lower-vector-multi-reduction` pass, which lowers `vector.multi_reduction` operations. Two test files have been added to verify that both lowering strategies work as expected.

>From b9e5018943971d16e2880fb63d6a31f24da353dc Mon Sep 17 00:00:00 2001
From: Xiaolei Shi <xiaoleis at nvidia.com>
Date: Tue, 2 Apr 2024 18:53:37 +0800
Subject: [PATCH] add lower-vector-multi-reduction pass

---
 .../mlir/Dialect/Vector/Transforms/Passes.h   |  6 +++
 .../mlir/Dialect/Vector/Transforms/Passes.td  | 18 +++++++++
 .../Transforms/LowerVectorMultiReduction.cpp  | 40 +++++++++++++++++++
 ...eduction-inner-parallel-pass-lowering.mlir | 39 ++++++++++++++++++
 ...duction-inner-reduction-pass-lowering.mlir | 34 ++++++++++++++++
 5 files changed, 137 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
 create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index bf89b01e2b60c5..911402551e14d4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -22,6 +23,11 @@ std::unique_ptr<Pass> createVectorBufferizePass();
 /// Creates an instance of the `vector.mask` lowering pass.
 std::unique_ptr<Pass> createLowerVectorMaskPass();
 
+/// Creates an instance of the `vector.multi_reduction` lowering pass.
+/// `option` selects the lowering strategy used by the pass (see
+/// `VectorMultiReductionLowering`). It defaults to `InnerParallel`, which is
+/// also the default of the pass's `lowering-strategy` option.
+std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
+    VectorMultiReductionLowering option =
+        VectorMultiReductionLowering::InnerParallel);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..31a0b3b2f0c53d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,22 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
 }
 
+def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::FuncOp"> {
+  let summary = "Lower 'vector.multi_reduction' operations";
+  let description = [{
+    This pass lowers `vector.multi_reduction` operations into lower-level
+    `vector` and `arith` operations. The `lowering-strategy` option selects
+    how each reduction is decomposed:
+
+    * `inner-parallel` (default): lower into outer-reduction and
+      inner-parallel ops.
+    * `inner-reduction`: lower into outer-parallel and inner-reduction ops.
+  }];
+  let constructor = "mlir::vector::createLowerVectorMultiReductionPass()";
+  let options = [
+    Option<"loweringStrategy", "lowering-strategy", "mlir::vector::VectorMultiReductionLowering",
+           /*default=*/"mlir::vector::VectorMultiReductionLowering::InnerParallel",
+           "Select the strategy to control how multi_reduction is lowered.",
+           [{::llvm::cl::values(
+            clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerParallel,
+                       "inner-parallel",
+                       "Lower multi_reduction into outer-reduction and inner-parallel ops."),
+            clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerReduction,
+                       "inner-reduction",
+                       "Lower multi_reduction into outer-parallel and inner-reduction ops.")
+        )}]>
+  ];
+}
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index bed2c2496719dd..2f21c50c63473b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -12,9 +12,19 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace vector {
+// Pull in the TableGen-generated pass base class
+// (`impl::LowerVectorMultiReductionBase`) that the pass below derives from.
+#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
 
 #define DEBUG_TYPE "vector-multi-reduction"
 
@@ -461,6 +471,31 @@ struct OneDimMultiReductionToTwoDim
     return success();
   }
 };
+
+/// Function pass that rewrites `vector.multi_reduction` ops with the patterns
+/// from `populateVectorMultiReductionLoweringPatterns`, using the strategy
+/// held in the `lowering-strategy` pass option.
+struct LowerVectorMultiReductionPass
+    : public vector::impl::LowerVectorMultiReductionBase<
+          LowerVectorMultiReductionPass> {
+  /// Initializes the `lowering-strategy` option to `option`.
+  explicit LowerVectorMultiReductionPass(
+      vector::VectorMultiReductionLowering option) {
+    this->loweringStrategy = option;
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+
+    RewritePatternSet loweringPatterns(context);
+    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
+                                                 this->loweringStrategy);
+
+    // Surface a failure reported by the greedy rewrite driver as a pass
+    // failure.
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
+      signalPassFailure();
+  }
+
+  /// The lowering patterns create `arith` ops (e.g. `arith.mulf`,
+  /// `arith.constant`) in addition to `vector` ops. Both dialects are
+  /// declared so they are loaded before the pass runs, without relying on
+  /// either being loaded transitively.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, vector::VectorDialect>();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
@@ -476,3 +511,8 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                     benefit);
 }
+
+std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
+    vector::VectorMultiReductionLowering strategy) {
+  // Build the pass with the requested strategy and hand ownership back to the
+  // caller as a generic `Pass`.
+  std::unique_ptr<Pass> pass =
+      std::make_unique<LowerVectorMultiReductionPass>(strategy);
+  return pass;
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
new file mode 100644
index 00000000000000..502cd7a1cbbcbc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s
+
+// -----
+// Multiplicative reduction over the inner dimension. With the inner-parallel
+// strategy, the input is transposed and its rows are folded one at a time
+// into the accumulator with elementwise `arith.mulf`.
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
+
+// -----
+// Same lowering shape as @vector_multi_reduction, but for the `minnumf`
+// combining kind, whose elementwise step is `arith.minnumf`.
+func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <minnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_min
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+//       CHECK:   %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//       CHECK:   %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
+//       CHECK:   %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
+//       CHECK:   %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
+//       CHECK:   %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+//       CHECK:   %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
+//       CHECK:   return %[[RESULT_VEC]] : vector<2xf32>
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
new file mode 100644
index 00000000000000..f051ce73fc49b4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s
+
+// -----
+// Inner-reduction strategy on a reduction over dim 1. Each row is extracted
+// and reduced with `vector.reduction`, and each scalar result is inserted
+// back into the result vector at its row index. The insertelement lines
+// reuse the captured RV0/RV1 values, so the test checks that the reduced
+// values are the ones being inserted.
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
+//   CHECK-DAG:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
+//   CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:       %[[V0:.+]] = vector.extract %[[INPUT]][0]
+//       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+//       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
+//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+//       CHECK:       %[[V1:.+]] = vector.extract %[[INPUT]][1]
+//       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+//       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
+//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+//       CHECK:       return %[[RESULT_VEC]]
+
+// -----
+// Reduction over every dimension. The input is flattened with
+// `vector.shape_cast` and reduced to a scalar by a single `vector.reduction`.
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+    return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+//       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+//       CHECK:   %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
+//       CHECK:   %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
+//       CHECK:   %[[RES:.*]] = vector.extract %[[INSERTED]][0] : f32 from vector<1xf32>
+//       CHECK:   return %[[RES]]



More information about the Mlir-commits mailing list