[Mlir-commits] [mlir] Add lower-vector-multi-reduction pass (PR #87333)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 2 03:57:17 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector
Author: None (xiaoleis-nv)
<details>
<summary>Changes</summary>
This PR adds the `lower-vector-multi-reduction` pass, which lowers `vector.multi_reduction` operations. Two test files are added to verify that both lowering strategies (`inner-parallel` and `inner-reduction`) work as expected.
---
Full diff: https://github.com/llvm/llvm-project/pull/87333.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.h (+6)
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.td (+18)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+40)
- (added) mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir (+39)
- (added) mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir (+34)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index bf89b01e2b60c5..911402551e14d4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -22,6 +23,11 @@ std::unique_ptr<Pass> createVectorBufferizePass();
/// Creates an instance of the `vector.mask` lowering pass.
std::unique_ptr<Pass> createLowerVectorMaskPass();
+/// Creates an instance of the `vector.multi_reduction` lowering pass.
+/// `option` selects the lowering strategy the pass uses; it defaults to
+/// `VectorMultiReductionLowering::InnerParallel`.
+std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
+ VectorMultiReductionLowering option =
+ VectorMultiReductionLowering::InnerParallel);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..31a0b3b2f0c53d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,22 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
let constructor = "mlir::vector::createLowerVectorMaskPass()";
}
+def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::FuncOp"> {
+  let summary = "Lower 'vector.multi_reduction' operations";
+  let description = [{
+    Lowers `vector.multi_reduction` operations by applying the patterns from
+    `populateVectorMultiReductionLoweringPatterns`. The `lowering-strategy`
+    option selects the shape of the lowered IR: `inner-parallel` (the default)
+    produces outer-reduction and inner-parallel ops, while `inner-reduction`
+    produces outer-parallel and inner-reduction ops.
+  }];
+  let constructor = "mlir::vector::createLowerVectorMultiReductionPass()";
+  let options = [
+    Option<"loweringStrategy", "lowering-strategy", "mlir::vector::VectorMultiReductionLowering",
+      /*default=*/"mlir::vector::VectorMultiReductionLowering::InnerParallel",
+      "Select the strategy to control how multi_reduction is lowered.",
+      [{::llvm::cl::values(
+        clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerParallel,
+                   "inner-parallel",
+                   "Lower multi_reduction into outer-reduction and inner-parallel ops."),
+        clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerReduction,
+                   "inner-reduction",
+                   "Lower multi_reduction into outer-parallel and inner-reduction ops.")
+      )}]>
+  ];
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index bed2c2496719dd..2f21c50c63473b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -12,9 +12,19 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace vector {
+// Emit the TableGen-generated definition of the pass base class
+// (`impl::LowerVectorMultiReductionBase`) that `LowerVectorMultiReductionPass`
+// derives from later in this file.
+#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
#define DEBUG_TYPE "vector-multi-reduction"
@@ -461,6 +471,31 @@ struct OneDimMultiReductionToTwoDim
return success();
}
};
+
+/// Function pass that lowers `vector.multi_reduction` ops by greedily applying
+/// the patterns from `populateVectorMultiReductionLoweringPatterns`.
+///
+/// The constructor argument seeds the `loweringStrategy` option, which is
+/// declared in Passes.td and exposed as `lowering-strategy`.
+struct LowerVectorMultiReductionPass
+    : public vector::impl::LowerVectorMultiReductionBase<
+          LowerVectorMultiReductionPass> {
+  explicit LowerVectorMultiReductionPass(
+      vector::VectorMultiReductionLowering option) {
+    this->loweringStrategy = option;
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+
+    RewritePatternSet loweringPatterns(context);
+    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
+                                                 this->loweringStrategy);
+
+    // A failed greedy rewrite is reported as a pass failure.
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The lowering materializes vector ops and arith ops (e.g. `arith.mulf`,
+    // `arith.constant` in the pass tests), so make both dialects available.
+    registry.insert<arith::ArithDialect, vector::VectorDialect>();
+  }
+};
+
} // namespace
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
@@ -476,3 +511,8 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
benefit);
}
+
+/// Builds the `vector.multi_reduction` lowering pass, configured to start with
+/// the lowering strategy given by `option`.
+std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
+    vector::VectorMultiReductionLowering option) {
+  std::unique_ptr<Pass> pass =
+      std::make_unique<LowerVectorMultiReductionPass>(option);
+  return pass;
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
new file mode 100644
index 00000000000000..502cd7a1cbbcbc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s
+
+// -----
+// Reduction over dim 1 with <mul>. The inner-parallel strategy transposes the
+// input, extracts each row of the transposed vector and folds it into the
+// accumulator with an elementwise arith.mulf chain.
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
+
+// -----
+// Reduction over dim 1 with the <minnumf> combining kind. The inner-parallel
+// strategy transposes the input and folds each extracted row into the
+// accumulator with an elementwise arith.minnumf chain.
+func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <minnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_min
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
new file mode 100644
index 00000000000000..f051ce73fc49b4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s
+
+// -----
+// Reduction over dim 1 with <mul>. The inner-reduction strategy reduces each
+// row with vector.reduction and inserts the scalar into the result vector.
+// The insertelement operands use (not re-define) the captured reduction
+// results, so the test verifies the value that is actually inserted.
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
+// CHECK-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
+// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
+// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
+// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]]
+
+// -----
+// Reduction over all dims to a scalar. The input is flattened with
+// vector.shape_cast and reduced by a single vector.reduction.
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+ return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+// CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
+// CHECK: %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : f32 from vector<1xf32>
+// CHECK: return %[[RES]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/87333
More information about the Mlir-commits mailing list