[Mlir-commits] [mlir] Refactor ConvertVectorToLLVMPass options (PR #128219)

Artemiy Bulavin llvmlistbot at llvm.org
Fri Feb 21 11:17:20 PST 2025


https://github.com/abulavin created https://github.com/llvm/llvm-project/pull/128219

The `VectorTransformsOptions` on the `ConvertVectorToLLVMPass` is currently represented as a struct, which makes it not serialisable. This means a pass pipeline that contains this pass cannot be represented in textual form, which breaks reproducer generation and options such as `--dump-pass-pipeline`.

This PR expands the `VectorTransformsOptions` struct into the two options that are actually used by the pass's patterns: `vector-contract-lowering` and `vector-transpose-lowering`. The other options present in `VectorTransformsOptions` are not used by any patterns in this pass.

Additionally, I have changed some interfaces to take only these specific options rather than the full options struct, since many of the conversion patterns need only one of the options.

Finally, I have added a simple lit test that just prints the pass pipeline using `--dump-pass-pipeline` to ensure the options on this pass remain serialisable.

>From 69100604c8883af26d64eb47495aa4401fdef86f Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin <artemiyb at graphcore.ai>
Date: Fri, 14 Feb 2025 15:02:11 +0000
Subject: [PATCH] Explicitly specify all vector transform options on
 ConvertVectorToLLVMPass

Refactor ConvertVectorToLLVMPass options
---
 mlir/include/mlir/Conversion/Passes.td        | 32 ++++++++--
 .../Vector/Transforms/LoweringPatterns.h      | 11 ++--
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  4 +-
 .../SPIRV/Transforms/SPIRVConversion.cpp      |  5 +-
 .../TransformOps/VectorTransformOps.cpp       |  9 +--
 .../Vector/Transforms/LowerVectorContract.cpp | 64 +++++++++----------
 .../Transforms/LowerVectorTranspose.cpp       | 28 ++++----
 .../VectorToLLVM/test-serialisable.mlir       | 16 +++++
 8 files changed, 101 insertions(+), 68 deletions(-)
 create mode 100644 mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..606a38f7d98eb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -10,7 +10,7 @@
 #define MLIR_CONVERSION_PASSES
 
 include "mlir/Pass/PassBase.td"
-
+include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 
 //===----------------------------------------------------------------------===//
 // ToLLVM
@@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
 	   "dialect.">,
-    Option<"vectorTransformsOptions", "vector-transform-options",
-           "vector::VectorTransformsOptions",
-           /*default=*/"vector::VectorTransformsOptions()",
-           "Options to lower some operations like contractions and transposes.">,
+    Option<"vectorContractLowering", "vector-contract-lowering",
+           "vector::VectorContractLowering",
+           /*default=*/"vector::VectorContractLowering::Dot",
+           VectorContractLoweringAttr.summary, [{::llvm::cl::values(
+           clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot", 
+            "Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
+           clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
+            "Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
+           clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
+            "Lower to `vector.outerproduct`."),
+           clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
+            "Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
+	        )}]>,
+    Option<"vectorTransposeLowering", "vector-transpose-lowering",
+           "vector::VectorTransposeLowering",
+           /*default=*/"vector::VectorTransposeLowering::EltWise",
+           VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
+           clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
+            "Lower transpose into element-wise extract and inserts (default)"),
+           clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
+            "Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
+           clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
+            "Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
+           clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
+            "Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
+          )}]>,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 6aeae30a0a6c0..601a65333d026 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
 
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
 namespace mlir {
@@ -47,7 +48,8 @@ namespace vector {
 /// Progressively lower a `vector.contract` with row-major matmul semantics to
 /// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
 void populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
+    RewritePatternSet &patterns,
+    VectorContractLowering vectorContractLoweringOption,
     PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
 
 /// Populate the pattern set with the following patterns:
@@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
 ///
 /// [TransposeOp2DToShuffleLowering]
 ///
-void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
-                                             VectorTransformsOptions options,
-                                             PatternBenefit benefit = 1);
+void populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransposeLowering vectorTransposeLowering,
+    PatternBenefit benefit = 1);
 
 /// Populate the pattern set with the following patterns:
 ///
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a81bd20212d..eb1555df5d574 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorToVectorCanonicalizationPatterns(patterns);
     populateVectorBitCastLoweringPatterns(patterns);
     populateVectorBroadcastLoweringPatterns(patterns);
-    populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
+    populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
     populateVectorInterleaveLoweringPatterns(patterns);
-    populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
+    populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c56dbcca2175d..a60410d01ac57 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
   // further transformations to canonicalize/cancel.
   {
     RewritePatternSet patterns(context);
-    auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
-        vector::VectorTransposeLowering::EltWise);
-    vector::populateVectorTransposeLoweringPatterns(patterns, options);
+    vector::populateVectorTransposeLoweringPatterns(
+        patterns, vector::VectorTransposeLowering::EltWise);
     vector::populateVectorShapeCastLoweringPatterns(patterns);
     if (failed(applyPatternsGreedily(op, std::move(patterns))))
       return failure();
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 241e83e234d62..20c577273d786 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
 
 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::VectorTransformsOptions vectorTransformOptions;
-  vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
-  populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
+  populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
                                          /*benefit=*/1,
                                          /*disableOuterProductLowering=*/true);
 }
@@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(
 
 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorTransposeLoweringPatterns(
-      patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
-                    getLoweringStrategy()));
+  vector::populateVectorTransposeLoweringPatterns(patterns,
+                                                  getLoweringStrategy());
   if (getAvx2LoweringStrategy()) {
     auto avx2LoweringOptions =
         x86vector::avx2::LoweringOptions().setTransposeOptions(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648..d2f60a55fb4a6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -221,7 +221,7 @@ namespace {
 /// ```
 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
 //
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// This only kicks in when vectorContractLowering is set to Matmul and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToMatmulOpLowering
     : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
   }
 
   ContractionOpToMatmulOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
+        vectorContractLowering(vectorContractLowering),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
   }
 
   ContractionOpToOuterProductOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
+        vectorContractLowering(vectorContractLowering),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -329,11 +329,11 @@ class ContractionOpToDotLowering
   }
 
   ContractionOpToDotLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
 
   FailureOr<Value>
   matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -341,7 +341,7 @@ class ContractionOpToDotLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -370,11 +370,12 @@ class ContractionOpLowering
     return success();
   }
 
-  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
-                        MLIRContext *context, PatternBenefit benefit = 1,
-                        FilterConstraintType constraint = defaultFilter)
+  ContractionOpLowering(
+      vector::VectorContractLowering vectorContractLoweringOption,
+      MLIRContext *context, PatternBenefit benefit = 1,
+      FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
+        vectorContractLoweringOption(vectorContractLoweringOption),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLoweringOption;
   FilterConstraintType filter;
   // Lower one parallel dimension.
   FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
@@ -641,8 +642,7 @@ FailureOr<Value>
 ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
     vector::ContractionOp op, MaskingOpInterface maskOp,
     PatternRewriter &rewriter) const {
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::OuterProduct)
+  if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
     return failure();
 
   if (failed(filter(op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
   if (failed(filter(op)))
     return failure();
 
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::Dot)
+  if (vectorContractLowering != vector::VectorContractLowering::Dot)
     return failure();
 
   auto iteratorTypes = op.getIteratorTypes().getValue();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
     return success();
   }
   ContractOpToElementwise(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
 
   FailureOr<Value>
   matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
     if (failed(filter(contractOp)))
       return failure();
 
-    if (vectorTransformOptions.vectorContractLowering !=
-        vector::VectorContractLowering::ParallelArith)
+    if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
       return failure();
 
     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
 
-  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
+  ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal1 =
       pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal1))
     return newVal1;
 
-  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
+  ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal2 =
       pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal2))
     return newVal2;
 
-  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
+  ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal3 =
       pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal3))
     return newVal3;
 
-  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+  ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal4 =
       pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal4))
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
   if (maskOp)
     return failure();
 
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::Matmul)
+  if (vectorContractLowering != vector::VectorContractLowering::Matmul)
     return failure();
   if (failed(filter(op)))
     return failure();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
 } // namespace
 
 void mlir::vector::populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
-    PatternBenefit benefit, bool disableOuterProductLowering) {
+    RewritePatternSet &patterns,
+    VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
+    bool disableOuterProductLowering) {
   if (!disableOuterProductLowering)
     patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
                ContractionOpToOuterProductOpLowering>(
-      options, patterns.getContext(), benefit);
+      vectorContractLoweringOption, patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorOuterProductLoweringPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index fb4dee33bc5f5..732e316c93381 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -304,10 +304,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+  TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
                       MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern<vector::TransposeOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions) {}
+        vectorTransposeLowering(vectorTransposeLowering) {}
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
@@ -324,14 +324,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     // Set up convenience transposition table.
     ArrayRef<int64_t> transp = op.getPermutation();
 
-    if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
+    if (isShuffleLike(vectorTransposeLowering) &&
         succeeded(isTranspose2DSlice(op)))
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
     // Handle a true 2-D matrix transpose differently when requested.
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            vector::VectorTransposeLowering::Flat &&
+    if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
       Type flattenedType =
           VectorType::get(resType.getNumElements(), resType.getElementType());
@@ -380,7 +379,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
 /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
@@ -454,14 +453,14 @@ class TransposeOp2DToShuffleLowering
   using OpRewritePattern::OpRewritePattern;
 
   TransposeOp2DToShuffleLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorTransposeLowering vectorTransposeLowering,
       MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern<vector::TransposeOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions) {}
+        vectorTransposeLowering(vectorTransposeLowering) {}
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
+    if (!isShuffleLike(vectorTransposeLowering))
       return rewriter.notifyMatchFailure(
           op, "not using vector shuffle based lowering");
 
@@ -487,8 +486,7 @@ class TransposeOp2DToShuffleLowering
                                                           op.getVector());
 
     Value res;
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            VectorTransposeLowering::Shuffle16x16 &&
+    if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
         m == 16 && n == 16) {
       reshInput =
           rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
@@ -506,15 +504,15 @@ class TransposeOp2DToShuffleLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorTransposeLowering vectorTransposeLowering;
 };
 } // namespace
 
 void mlir::vector::populateVectorTransposeLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
-    PatternBenefit benefit) {
+    RewritePatternSet &patterns,
+    VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
   patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
                                                   benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
-      options, patterns.getContext(), benefit);
+      vectorTransposeLowering, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir
new file mode 100644
index 0000000000000..d641c715ad74e
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s
+
+// Simple regression test that ensures ConvertVectorToLLVMPass options remain
+// serialisable. We don't need to actually parse any IR to print the pass
+// options. We just need to provide --dump-pass-pipeline.
+
+// CHECK: builtin.module(
+// CHECK-SAME: convert-vector-to-llvm{
+// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}}
+// CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}}
+// CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}}
+// CHECK-SAME: enable-x86vector={{[aA-zZ0-9]+}}
+// CHECK-SAME: force-32bit-vector-indices={{[aA-zZ0-9]+}}
+// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}}
+// CHECK-SAME: vector-contract-lowering={{[aA-zZ0-9]+}}
+// CHECK-SAME: vector-transpose-lowering={{[aA-zZ0-9]+}}})



More information about the Mlir-commits mailing list