[Mlir-commits] [mlir] Refactor ConvertVectorToLLVMPass options (PR #128219)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 21 11:18:11 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-spirv
Author: Artemiy Bulavin (abulavin)
<details>
<summary>Changes</summary>
The `vector-transform-options` option on `ConvertVectorToLLVMPass` is currently typed as the `VectorTransformsOptions` struct, which cannot be serialised. As a result, a pass pipeline that contains this pass cannot be printed in textual form, which breaks reproducer generation and flags such as `--dump-pass-pipeline`.
This PR expands the `VectorTransformsOptions` struct into the two options that the pass's patterns actually use: `vector-contract-lowering` and `vector-transpose-lowering`. The other fields of `VectorTransformsOptions` are not used by any pattern in this pass.
Additionally, I have changed some interfaces to take only the specific lowering option they need rather than the full options struct since, again, many of the conversion patterns use only one of the options.
Finally, I have added a simple lit test that just prints the pass pipeline using `--dump-pass-pipeline` to ensure the options on this pass remain serialisable.
---
Patch is 22.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128219.diff
8 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+27-5)
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+7-4)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+2-2)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+2-3)
- (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+3-6)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+31-33)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+13-15)
- (added) mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir (+16)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..606a38f7d98eb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -10,7 +10,7 @@
#define MLIR_CONVERSION_PASSES
include "mlir/Pass/PassBase.td"
-
+include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
//===----------------------------------------------------------------------===//
// ToLLVM
@@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
"dialect.">,
- Option<"vectorTransformsOptions", "vector-transform-options",
- "vector::VectorTransformsOptions",
- /*default=*/"vector::VectorTransformsOptions()",
- "Options to lower some operations like contractions and transposes.">,
+ Option<"vectorContractLowering", "vector-contract-lowering",
+ "vector::VectorContractLowering",
+ /*default=*/"vector::VectorContractLowering::Dot",
+ VectorContractLoweringAttr.summary, [{::llvm::cl::values(
+ clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
+ "Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
+ clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
+ "Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
+ clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
+ "Lower to `vector.outerproduct`."),
+ clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
+ "Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
+ )}]>,
+ Option<"vectorTransposeLowering", "vector-transpose-lowering",
+ "vector::VectorTransposeLowering",
+ /*default=*/"vector::VectorTransposeLowering::EltWise",
+ VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
+ clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
+ "Lower transpose into element-wise extract and inserts (default)"),
+ clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
+ "Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
+ clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
+ "Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
+ clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
+ "Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
+ )}]>,
];
}
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 6aeae30a0a6c0..601a65333d026 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
namespace mlir {
@@ -47,7 +48,8 @@ namespace vector {
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
void populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions options,
+ RewritePatternSet &patterns,
+ VectorContractLowering vectorContractLoweringOption,
PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
/// Populate the pattern set with the following patterns:
@@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
///
/// [TransposeOp2DToShuffleLowering]
///
-void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
- VectorTransformsOptions options,
- PatternBenefit benefit = 1);
+void populateVectorTransposeLoweringPatterns(
+ RewritePatternSet &patterns,
+ VectorTransposeLowering vectorTransposeLowering,
+ PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a81bd20212d..eb1555df5d574 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorBitCastLoweringPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
- populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
+ populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorInterleaveLoweringPatterns(patterns);
- populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
+ populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
populateVectorMaskMaterializationPatterns(patterns,
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c56dbcca2175d..a60410d01ac57 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
// further transformations to canonicalize/cancel.
{
RewritePatternSet patterns(context);
- auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
- vector::VectorTransposeLowering::EltWise);
- vector::populateVectorTransposeLoweringPatterns(patterns, options);
+ vector::populateVectorTransposeLoweringPatterns(
+ patterns, vector::VectorTransposeLowering::EltWise);
vector::populateVectorShapeCastLoweringPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 241e83e234d62..20c577273d786 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
void transform::ApplyLowerContractionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::VectorTransformsOptions vectorTransformOptions;
- vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
- populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
+ populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
/*benefit=*/1,
/*disableOuterProductLowering=*/true);
}
@@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(
void transform::ApplyLowerTransposePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorTransposeLoweringPatterns(
- patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
- getLoweringStrategy()));
+ vector::populateVectorTransposeLoweringPatterns(patterns,
+ getLoweringStrategy());
if (getAvx2LoweringStrategy()) {
auto avx2LoweringOptions =
x86vector::avx2::LoweringOptions().setTransposeOptions(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648..d2f60a55fb4a6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -221,7 +221,7 @@ namespace {
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// This only kicks in when the VectorContractLowering option is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
}
ContractionOpToMatmulOpLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions),
+ vectorContractLowering(vectorContractLowering),
filter(std::move(constraint)) {}
FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};
@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
}
ContractionOpToOuterProductOpLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions),
+ vectorContractLowering(vectorContractLowering),
filter(std::move(constraint)) {}
FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};
@@ -329,11 +329,11 @@ class ContractionOpToDotLowering
}
ContractionOpToDotLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+ vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -341,7 +341,7 @@ class ContractionOpToDotLowering
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};
@@ -370,11 +370,12 @@ class ContractionOpLowering
return success();
}
- ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, PatternBenefit benefit = 1,
- FilterConstraintType constraint = defaultFilter)
+ ContractionOpLowering(
+ vector::VectorContractLowering vectorContractLoweringOption,
+ MLIRContext *context, PatternBenefit benefit = 1,
+ FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions),
+ vectorContractLoweringOption(vectorContractLoweringOption),
filter(std::move(constraint)) {}
FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLoweringOption;
FilterConstraintType filter;
// Lower one parallel dimension.
FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
@@ -641,8 +642,7 @@ FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const {
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::OuterProduct)
+ if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
return failure();
if (failed(filter(op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
if (failed(filter(op)))
return failure();
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::Dot)
+ if (vectorContractLowering != vector::VectorContractLowering::Dot)
return failure();
auto iteratorTypes = op.getIteratorTypes().getValue();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
return success();
}
ContractOpToElementwise(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+ vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
if (failed(filter(contractOp)))
return failure();
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::ParallelArith)
+ if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
return failure();
ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};
@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
// TODO: implement benefits, cost models.
MLIRContext *ctx = op.getContext();
- ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
+ ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal1 =
pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal1))
return newVal1;
- ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
+ ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal2 =
pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal2))
return newVal2;
- ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
+ ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal3 =
pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal3))
return newVal3;
- ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+ ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal4 =
pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal4))
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
if (maskOp)
return failure();
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::Matmul)
+ if (vectorContractLowering != vector::VectorContractLowering::Matmul)
return failure();
if (failed(filter(op)))
return failure();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
} // namespace
void mlir::vector::populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions options,
- PatternBenefit benefit, bool disableOuterProductLowering) {
+ RewritePatternSet &patterns,
+ VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
+ bool disableOuterProductLowering) {
if (!disableOuterProductLowering)
patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
ContractionOpToOuterProductOpLowering>(
- options, patterns.getContext(), benefit);
+ vectorContractLoweringOption, patterns.getContext(), benefit);
}
void mlir::vector::populateVectorOuterProductLoweringPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index fb4dee33bc5f5..732e316c93381 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -304,10 +304,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
- TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+ TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions) {}
+ vectorTransposeLowering(vectorTransposeLowering) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
@@ -324,14 +324,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();
- if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
+ if (isShuffleLike(vectorTransposeLowering) &&
succeeded(isTranspose2DSlice(op)))
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
// Handle a true 2-D matrix transpose differently when requested.
- if (vectorTransformOptions.vectorTransposeLowering ==
- vector::VectorTransposeLowering::Flat &&
+ if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
Type flattenedType =
VectorType::get(resType.getNumElements(), resType.getElementType());
@@ -380,7 +379,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorTransposeLowering vectorTransposeLowering;
};
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
@@ -454,14 +453,14 @@ class TransposeOp2DToShuffleLowering
using OpRewritePattern::OpRewritePattern;
TransposeOp2DToShuffleLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorTransposeLowering vectorTransposeLowering,
MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions) {}
+ vectorTransposeLower...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/128219
More information about the Mlir-commits
mailing list