[Mlir-commits] [mlir] Add lower-vector-multi-reduction pass (PR #87333)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 8 18:25:10 PDT 2024
https://github.com/xiaoleis-nv updated https://github.com/llvm/llvm-project/pull/87333
>From b9e5018943971d16e2880fb63d6a31f24da353dc Mon Sep 17 00:00:00 2001
From: Xiaolei Shi <xiaoleis at nvidia.com>
Date: Tue, 2 Apr 2024 18:53:37 +0800
Subject: [PATCH 1/4] add lower-vector-multi-reduction pass
---
.../mlir/Dialect/Vector/Transforms/Passes.h | 6 +++
.../mlir/Dialect/Vector/Transforms/Passes.td | 18 +++++++++
.../Transforms/LowerVectorMultiReduction.cpp | 40 +++++++++++++++++++
...eduction-inner-parallel-pass-lowering.mlir | 39 ++++++++++++++++++
...duction-inner-reduction-pass-lowering.mlir | 34 ++++++++++++++++
5 files changed, 137 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index bf89b01e2b60c5..911402551e14d4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -22,6 +23,11 @@ std::unique_ptr<Pass> createVectorBufferizePass();
/// Creates an instance of the `vector.mask` lowering pass.
std::unique_ptr<Pass> createLowerVectorMaskPass();
+/// Creates an instance of the `vector.multi_reduction` lowering pass.
+std::unique_ptr<Pass> createLowerVectorMultiReductionPass(
+ VectorMultiReductionLowering option =
+ VectorMultiReductionLowering::InnerParallel);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..31a0b3b2f0c53d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,22 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
let constructor = "mlir::vector::createLowerVectorMaskPass()";
}
+def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::FuncOp"> {
+ let summary = "Lower 'vector.multi_reduction' operations";
+ let constructor = "mlir::vector::createLowerVectorMultiReductionPass()";
+ let options = [
+ Option<"loweringStrategy", "lowering-strategy", "mlir::vector::VectorMultiReductionLowering",
+ /*default=*/"mlir::vector::VectorMultiReductionLowering::InnerParallel",
+ "Select the strategy to control how multi_reduction is lowered.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerParallel,
+ "inner-parallel",
+ "Lower multi_reduction into outer-reduction and inner-parallel ops."),
+ clEnumValN(mlir::vector::VectorMultiReductionLowering::InnerReduction,
+ "inner-reduction",
+ "Lower multi_reduction into outer-parallel and inner-reduction ops.")
+ )}]>
+ ];
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index bed2c2496719dd..2f21c50c63473b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -12,9 +12,19 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace vector {
+#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace vector
+} // namespace mlir
#define DEBUG_TYPE "vector-multi-reduction"
@@ -461,6 +471,31 @@ struct OneDimMultiReductionToTwoDim
return success();
}
};
+
+struct LowerVectorMultiReductionPass
+ : public vector::impl::LowerVectorMultiReductionBase<
+ LowerVectorMultiReductionPass> {
+ LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
+ this->loweringStrategy = option;
+ }
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ MLIRContext *context = op->getContext();
+
+ RewritePatternSet loweringPatterns(context);
+ populateVectorMultiReductionLoweringPatterns(loweringPatterns,
+ this->loweringStrategy);
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
+ signalPassFailure();
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
+};
+
} // namespace
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
@@ -476,3 +511,8 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns(
patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
benefit);
}
+
+std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
+ vector::VectorMultiReductionLowering option) {
+ return std::make_unique<LowerVectorMultiReductionPass>(option);
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
new file mode 100644
index 00000000000000..502cd7a1cbbcbc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s
+
+// -----
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
+
+// -----
+func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <minnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_min
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
+// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
new file mode 100644
index 00000000000000..f051ce73fc49b4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s
+
+// -----
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
+// CHECK-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
+// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
+// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
+// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]]
+
+// -----
+func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+ return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+// CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
+// CHECK: %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : f32 from vector<1xf32>
+// CHECK: return %[[RES]]
>From 7a26bc351ddcfa1fd9d3874ef3c23fab2b36fb4e Mon Sep 17 00:00:00 2001
From: Xiaolei Shi <xiaoleis at nvidia.com>
Date: Wed, 3 Apr 2024 09:43:56 +0800
Subject: [PATCH 2/4] merge test files
---
...eduction-inner-parallel-pass-lowering.mlir | 39 ---------------
...duction-inner-reduction-pass-lowering.mlir | 34 -------------
.../vector-multi-reduction-pass-lowering.mlir | 48 +++++++++++++++++++
3 files changed, 48 insertions(+), 73 deletions(-)
delete mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
delete mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
create mode 100644 mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
deleted file mode 100644
index 502cd7a1cbbcbc..00000000000000
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-parallel-pass-lowering.mlir
+++ /dev/null
@@ -1,39 +0,0 @@
-// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s
-
-// -----
-func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @vector_multi_reduction
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
-// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
-// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
-// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
-// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
-
-// -----
-func.func @vector_multi_reduction_min(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <minnumf>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @vector_multi_reduction_min
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>
-// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
-// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV0:.+]] = arith.minnumf %[[V0]], %[[ACC]] : vector<2xf32>
-// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = arith.minnumf %[[V1]], %[[RV0]] : vector<2xf32>
-// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = arith.minnumf %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = arith.minnumf %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
deleted file mode 100644
index f051ce73fc49b4..00000000000000
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-inner-reduction-pass-lowering.mlir
+++ /dev/null
@@ -1,34 +0,0 @@
-// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s
-
-// -----
-func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
- return %0 : vector<2xf32>
-}
-// CHECK-LABEL: func @vector_multi_reduction
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
-// CHECK-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
-// CHECK: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
-// CHECK: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
-// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
-// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
-// CHECK: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
-// CHECK: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
-// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
-// CHECK: return %[[RESULT_VEC]]
-
-// -----
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
- %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
- return %0 : f32
-}
-// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
-// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
-// CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
-// CHECK: %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
-// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : f32 from vector<1xf32>
-// CHECK: return %[[RES]]
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
new file mode 100644
index 00000000000000..7317c5958f73be
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s --check-prefix=CHECK-RED
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s --check-prefix=CHECK-PAR
+// RUN: mlir-opt -lower-vector-multi-reduction -split-input-file %s | FileCheck %s --check-prefix=CHECK-PAR
+
+func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
+ return %0 : vector<2xf32>
+}
+// CHECK-RED-LABEL: func @vector_multi_reduction
+// CHECK-RED-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
+// CHECK-RED-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
+// CHECK-RED-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-RED-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-RED: %[[V0:.+]] = vector.extract %[[INPUT]][0]
+// CHECK-RED: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+// CHECK-RED: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
+// CHECK-RED: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+// CHECK-RED: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+// CHECK-RED: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+// CHECK-RED: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
+// CHECK-RED: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+// CHECK-RED: return %[[RESULT_VEC]]
+
+// CHECK-PAR-LABEL: func @vector_multi_reduction
+// CHECK-PAR-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
+// CHECK-PAR: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// CHECK-PAR: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+// CHECK-PAR: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+// CHECK-PAR: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+// CHECK-PAR: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+// CHECK-PAR: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+// CHECK-PAR: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK-PAR: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// CHECK-PAR: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK-PAR: return %[[RESULT_VEC]] : vector<2xf32>
+
+func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-RED-LABEL: func @vector_multi_reduction_parallel_middle
+// CHECK-RED-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+// CHECK-RED: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
+
+// CHECK-PAR-LABEL: func @vector_multi_reduction_parallel_middle
+// CHECK-PAR-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+// CHECK-PAR: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
>From c0310e54115dc21b8dfb6cd6f94b06612166956a Mon Sep 17 00:00:00 2001
From: Xiaolei Shi <xiaoleis at nvidia.com>
Date: Tue, 9 Apr 2024 09:18:02 +0800
Subject: [PATCH 3/4] add test separator | remove duplicated check-label |
rename check prefixes
---
.../vector-multi-reduction-pass-lowering.mlir | 69 +++++++++----------
1 file changed, 33 insertions(+), 36 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index 7317c5958f73be..e6f5eda28b1a27 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -1,48 +1,45 @@
-// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s --check-prefix=CHECK-RED
-// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s --check-prefix=CHECK-PAR
-// RUN: mlir-opt -lower-vector-multi-reduction -split-input-file %s | FileCheck %s --check-prefix=CHECK-PAR
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-reduction" -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-REDUCTION
+// RUN: mlir-opt -lower-vector-multi-reduction="lowering-strategy=inner-parallel" -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-PARALLEL
+// RUN: mlir-opt -lower-vector-multi-reduction -split-input-file %s | FileCheck %s --check-prefixes=ALL,INNER-PARALLEL
func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
-// CHECK-RED-LABEL: func @vector_multi_reduction
-// CHECK-RED-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
-// CHECK-RED-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
-// CHECK-RED-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-RED-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-RED: %[[V0:.+]] = vector.extract %[[INPUT]][0]
-// CHECK-RED: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
-// CHECK-RED: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
-// CHECK-RED: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
-// CHECK-RED: %[[V1:.+]] = vector.extract %[[INPUT]][1]
-// CHECK-RED: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
-// CHECK-RED: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
-// CHECK-RED: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
-// CHECK-RED: return %[[RESULT_VEC]]
+// ALL-LABEL: func @vector_multi_reduction
+// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
+// INNER-REDUCTION-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
+// INNER-REDUCTION-DAG: %[[C0:.+]] = arith.constant 0 : index
+// INNER-REDUCTION-DAG: %[[C1:.+]] = arith.constant 1 : index
+// INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0]
+// INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
+// INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
+// INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+// INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+// INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
+// INNER-REDUCTION: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
+// INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+// INNER-REDUCTION: return %[[RESULT_VEC]]
-// CHECK-PAR-LABEL: func @vector_multi_reduction
-// CHECK-PAR-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
-// CHECK-PAR: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
-// CHECK-PAR: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// CHECK-PAR: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
-// CHECK-PAR: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// CHECK-PAR: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
-// CHECK-PAR: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// CHECK-PAR: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK-PAR: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// CHECK-PAR: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK-PAR: return %[[RESULT_VEC]] : vector<2xf32>
+// INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+// INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+// INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
+// INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
+// INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>
+
+// -----
func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
return %0 : vector<4xf32>
}
-// CHECK-RED-LABEL: func @vector_multi_reduction_parallel_middle
-// CHECK-RED-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
-// CHECK-RED: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
-
-// CHECK-PAR-LABEL: func @vector_multi_reduction_parallel_middle
-// CHECK-PAR-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
-// CHECK-PAR: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
+// ALL-LABEL: func @vector_multi_reduction_parallel_middle
+// ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+// INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
+// INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
>From f6bde73e2b9a0ad809e5269b216f4e5525b45d53 Mon Sep 17 00:00:00 2001
From: Xiaolei Shi <xiaoleis at nvidia.com>
Date: Tue, 9 Apr 2024 09:24:58 +0800
Subject: [PATCH 4/4] format the test
---
.../vector-multi-reduction-pass-lowering.mlir | 20 +++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index e6f5eda28b1a27..4cb6fba9b691a6 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -21,16 +21,16 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
// INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
// INNER-REDUCTION: return %[[RESULT_VEC]]
-// INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
-// INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
-// INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
-// INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
-// INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
-// INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
-// INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>
+// INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// INNER-PARALLEL: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RV0:.+]] = arith.mulf %[[V0]], %[[ACC]] : vector<2xf32>
+// INNER-PARALLEL: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RV01:.+]] = arith.mulf %[[V1]], %[[RV0]] : vector<2xf32>
+// INNER-PARALLEL: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RV012:.+]] = arith.mulf %[[V2]], %[[RV01]] : vector<2xf32>
+// INNER-PARALLEL: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[RESULT_VEC:.+]] = arith.mulf %[[V3]], %[[RV012]] : vector<2xf32>
+// INNER-PARALLEL: return %[[RESULT_VEC]] : vector<2xf32>
// -----
More information about the Mlir-commits
mailing list