[Mlir-commits] [mlir] Refactor ConvertVectorToLLVMPass options (PR #128219)
Artemiy Bulavin
llvmlistbot at llvm.org
Fri Feb 21 11:17:20 PST 2025
https://github.com/abulavin created https://github.com/llvm/llvm-project/pull/128219
The `vector-transform-options` option on `ConvertVectorToLLVMPass` is currently represented as a `VectorTransformsOptions` struct, which makes it non-serialisable. This means a pass pipeline containing this pass cannot be represented in textual form, which breaks reproducer generation and options such as `--dump-pass-pipeline`.
This PR expands the `VectorTransformsOptions` struct into the two options that are actually used by the pass's patterns: `vector-contract-lowering` and `vector-transpose-lowering`. The other fields of `VectorTransformsOptions` are not used by any pattern in this pass.
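Additionally, I have changed some interfaces (e.g. `populateVectorContractLoweringPatterns` and `populateVectorTransposeLoweringPatterns`) to take only the specific option they need instead of the full options struct since, again, most of these patterns use only one of the options.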
Finally, I have added a simple lit test that just prints the pass pipeline using `--dump-pass-pipeline` to ensure the options on this pass remain serialisable.
>From 69100604c8883af26d64eb47495aa4401fdef86f Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin <artemiyb at graphcore.ai>
Date: Fri, 14 Feb 2025 15:02:11 +0000
Subject: [PATCH] Explicitly specify all vector transform options on
ConvertVectorToLLVMPass
Refactor ConvertVectorToLLVMPass options
---
mlir/include/mlir/Conversion/Passes.td | 32 ++++++++--
.../Vector/Transforms/LoweringPatterns.h | 11 ++--
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 4 +-
.../SPIRV/Transforms/SPIRVConversion.cpp | 5 +-
.../TransformOps/VectorTransformOps.cpp | 9 +--
.../Vector/Transforms/LowerVectorContract.cpp | 64 +++++++++----------
.../Transforms/LowerVectorTranspose.cpp | 28 ++++----
.../VectorToLLVM/test-serialisable.mlir | 16 +++++
8 files changed, 101 insertions(+), 68 deletions(-)
create mode 100644 mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..606a38f7d98eb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -10,7 +10,7 @@
#define MLIR_CONVERSION_PASSES
include "mlir/Pass/PassBase.td"
-
+include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
//===----------------------------------------------------------------------===//
// ToLLVM
@@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
"dialect.">,
- Option<"vectorTransformsOptions", "vector-transform-options",
- "vector::VectorTransformsOptions",
- /*default=*/"vector::VectorTransformsOptions()",
- "Options to lower some operations like contractions and transposes.">,
+ Option<"vectorContractLowering", "vector-contract-lowering",
+ "vector::VectorContractLowering",
+ /*default=*/"vector::VectorContractLowering::Dot",
+ VectorContractLoweringAttr.summary, [{::llvm::cl::values(
+ clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
+ "Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
+ clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
+ "Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
+ clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
+ "Lower to `vector.outerproduct`."),
+ clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
+ "Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
+ )}]>,
+ Option<"vectorTransposeLowering", "vector-transpose-lowering",
+ "vector::VectorTransposeLowering",
+ /*default=*/"vector::VectorTransposeLowering::EltWise",
+ VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
+ clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
+ "Lower transpose into element-wise extract and inserts (default)"),
+ clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
+ "Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
+ clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
+ "Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
+ clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
+ "Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
+ )}]>,
];
}
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 6aeae30a0a6c0..601a65333d026 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
namespace mlir {
@@ -47,7 +48,8 @@ namespace vector {
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
void populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions options,
+ RewritePatternSet &patterns,
+ VectorContractLowering vectorContractLoweringOption,
PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
/// Populate the pattern set with the following patterns:
@@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
///
/// [TransposeOp2DToShuffleLowering]
///
-void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
- VectorTransformsOptions options,
- PatternBenefit benefit = 1);
+void populateVectorTransposeLoweringPatterns(
+ RewritePatternSet &patterns,
+ VectorTransposeLowering vectorTransposeLowering,
+ PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a81bd20212d..eb1555df5d574 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorBitCastLoweringPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
- populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
+ populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorInterleaveLoweringPatterns(patterns);
- populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
+ populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
populateVectorMaskMaterializationPatterns(patterns,
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c56dbcca2175d..a60410d01ac57 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
// further transformations to canonicalize/cancel.
{
RewritePatternSet patterns(context);
- auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
- vector::VectorTransposeLowering::EltWise);
- vector::populateVectorTransposeLoweringPatterns(patterns, options);
+ vector::populateVectorTransposeLoweringPatterns(
+ patterns, vector::VectorTransposeLowering::EltWise);
vector::populateVectorShapeCastLoweringPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 241e83e234d62..20c577273d786 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
void transform::ApplyLowerContractionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::VectorTransformsOptions vectorTransformOptions;
- vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
- populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
+ populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
/*benefit=*/1,
/*disableOuterProductLowering=*/true);
}
@@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(
void transform::ApplyLowerTransposePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorTransposeLoweringPatterns(
- patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
- getLoweringStrategy()));
+ vector::populateVectorTransposeLoweringPatterns(patterns,
+ getLoweringStrategy());
if (getAvx2LoweringStrategy()) {
auto avx2LoweringOptions =
x86vector::avx2::LoweringOptions().setTransposeOptions(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648..d2f60a55fb4a6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -221,7 +221,7 @@ namespace {
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// This only kicks in when vectorContractLowering is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
}
ContractionOpToMatmulOpLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions),
+ vectorContractLowering(vectorContractLowering),
filter(std::move(constraint)) {}
FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};
@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
}
ContractionOpToOuterProductOpLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions),
+ vectorContractLowering(vectorContractLowering),
filter(std::move(constraint)) {}
FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};
@@ -329,11 +329,11 @@ class ContractionOpToDotLowering
}
ContractionOpToDotLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+ vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -341,7 +341,7 @@ class ContractionOpToDotLowering
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};
@@ -370,11 +370,12 @@ class ContractionOpLowering
return success();
}
- ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, PatternBenefit benefit = 1,
- FilterConstraintType constraint = defaultFilter)
+ ContractionOpLowering(
+ vector::VectorContractLowering vectorContractLoweringOption,
+ MLIRContext *context, PatternBenefit benefit = 1,
+ FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions),
+ vectorContractLoweringOption(vectorContractLoweringOption),
filter(std::move(constraint)) {}
FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLoweringOption;
FilterConstraintType filter;
// Lower one parallel dimension.
FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
@@ -641,8 +642,7 @@ FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const {
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::OuterProduct)
+ if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
return failure();
if (failed(filter(op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
if (failed(filter(op)))
return failure();
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::Dot)
+ if (vectorContractLowering != vector::VectorContractLowering::Dot)
return failure();
auto iteratorTypes = op.getIteratorTypes().getValue();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
return success();
}
ContractOpToElementwise(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+ vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
if (failed(filter(contractOp)))
return failure();
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::ParallelArith)
+ if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
return failure();
ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};
@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
// TODO: implement benefits, cost models.
MLIRContext *ctx = op.getContext();
- ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
+ ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal1 =
pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal1))
return newVal1;
- ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
+ ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal2 =
pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal2))
return newVal2;
- ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
+ ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal3 =
pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal3))
return newVal3;
- ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+ ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal4 =
pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal4))
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
if (maskOp)
return failure();
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::Matmul)
+ if (vectorContractLowering != vector::VectorContractLowering::Matmul)
return failure();
if (failed(filter(op)))
return failure();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
} // namespace
void mlir::vector::populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions options,
- PatternBenefit benefit, bool disableOuterProductLowering) {
+ RewritePatternSet &patterns,
+ VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
+ bool disableOuterProductLowering) {
if (!disableOuterProductLowering)
patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
ContractionOpToOuterProductOpLowering>(
- options, patterns.getContext(), benefit);
+ vectorContractLoweringOption, patterns.getContext(), benefit);
}
void mlir::vector::populateVectorOuterProductLoweringPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index fb4dee33bc5f5..732e316c93381 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -304,10 +304,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
- TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+ TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions) {}
+ vectorTransposeLowering(vectorTransposeLowering) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
@@ -324,14 +324,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();
- if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
+ if (isShuffleLike(vectorTransposeLowering) &&
succeeded(isTranspose2DSlice(op)))
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
// Handle a true 2-D matrix transpose differently when requested.
- if (vectorTransformOptions.vectorTransposeLowering ==
- vector::VectorTransposeLowering::Flat &&
+ if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
Type flattenedType =
VectorType::get(resType.getNumElements(), resType.getElementType());
@@ -380,7 +379,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorTransposeLowering vectorTransposeLowering;
};
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
@@ -454,14 +453,14 @@ class TransposeOp2DToShuffleLowering
using OpRewritePattern::OpRewritePattern;
TransposeOp2DToShuffleLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
+ vector::VectorTransposeLowering vectorTransposeLowering,
MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions) {}
+ vectorTransposeLowering(vectorTransposeLowering) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
- if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
+ if (!isShuffleLike(vectorTransposeLowering))
return rewriter.notifyMatchFailure(
op, "not using vector shuffle based lowering");
@@ -487,8 +486,7 @@ class TransposeOp2DToShuffleLowering
op.getVector());
Value res;
- if (vectorTransformOptions.vectorTransposeLowering ==
- VectorTransposeLowering::Shuffle16x16 &&
+ if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
m == 16 && n == 16) {
reshInput =
rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
@@ -506,15 +504,15 @@ class TransposeOp2DToShuffleLowering
private:
/// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
+ vector::VectorTransposeLowering vectorTransposeLowering;
};
} // namespace
void mlir::vector::populateVectorTransposeLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions options,
- PatternBenefit benefit) {
+ RewritePatternSet &patterns,
+ VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
- options, patterns.getContext(), benefit);
+ vectorTransposeLowering, patterns.getContext(), benefit);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir
new file mode 100644
index 0000000000000..d641c715ad74e
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s
+
+// Simple regression test that ensures ConvertVectorToLLVMPass options remain
+// serialisable. We don't need to actually parse any IR to print the pass
+// options; providing --dump-pass-pipeline is enough.
+
+// CHECK: builtin.module(
+// CHECK-SAME: convert-vector-to-llvm{
+// CHECK-SAME: enable-amx={{[a-zA-Z0-9]+}}
+// CHECK-SAME: enable-arm-neon={{[a-zA-Z0-9]+}}
+// CHECK-SAME: enable-arm-sve={{[a-zA-Z0-9]+}}
+// CHECK-SAME: enable-x86vector={{[a-zA-Z0-9]+}}
+// CHECK-SAME: force-32bit-vector-indices={{[a-zA-Z0-9]+}}
+// CHECK-SAME: reassociate-fp-reductions={{[a-zA-Z0-9]+}}
+// CHECK-SAME: vector-contract-lowering={{[a-zA-Z0-9]+}}
+// CHECK-SAME: vector-transpose-lowering={{[a-zA-Z0-9]+}}})
More information about the Mlir-commits
mailing list