[Mlir-commits] [mlir] f3dcc0f - [mlir] Refactor ConvertVectorToLLVMPass options (#128219)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 10 03:32:06 PDT 2025


Author: Artemiy Bulavin
Date: 2025-03-10T10:32:03Z
New Revision: f3dcc0fe228f6a1a69147ead0a76ce0fe02d316d

URL: https://github.com/llvm/llvm-project/commit/f3dcc0fe228f6a1a69147ead0a76ce0fe02d316d
DIFF: https://github.com/llvm/llvm-project/commit/f3dcc0fe228f6a1a69147ead0a76ce0fe02d316d.diff

LOG: [mlir] Refactor ConvertVectorToLLVMPass options (#128219)

The `VectorTransformsOptions` on the `ConvertVectorToLLVMPass` is
currently represented as a struct, which makes it not serialisable. This
means a pass pipeline that contains this pass cannot be represented in
textual form, which breaks reproducer generation and options such as
`--dump-pass-pipeline`.

This PR expands the `VectorTransformsOptions` struct into the two
options that are actually used by the pass's patterns:
`vector-contract-lowering` and `vector-transpose-lowering`. The other
options present in `VectorTransformsOptions` are not used by any patterns
in this pass.

Additionally, I have changed some interfaces to take only these specific
options instead of the full options struct since, again, the vector contract
and transpose lowering patterns each need only their respective option.

Finally, I have added a simple lit test that just prints the pass
pipeline using `--dump-pass-pipeline` to ensure the options on this pass
remain serialisable.

Fixes #129046

Added: 
    mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6074e0e8d822c..bbba495e613b2 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -10,7 +10,7 @@
 #define MLIR_CONVERSION_PASSES
 
 include "mlir/Pass/PassBase.td"
-
+include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 
 //===----------------------------------------------------------------------===//
 // ToLLVM
@@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
 	   "dialect.">,
-    Option<"vectorTransformsOptions", "vector-transform-options",
-           "vector::VectorTransformsOptions",
-           /*default=*/"vector::VectorTransformsOptions()",
-           "Options to lower some operations like contractions and transposes.">,
+    Option<"vectorContractLowering", "vector-contract-lowering",
+           "vector::VectorContractLowering",
+           /*default=*/"vector::VectorContractLowering::Dot",
+           VectorContractLoweringAttr.summary, [{::llvm::cl::values(
+           clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
+            "Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
+           clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
+            "Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
+           clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
+            "Lower to `vector.outerproduct`."),
+           clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
+            "Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
+	        )}]>,
+    Option<"vectorTransposeLowering", "vector-transpose-lowering",
+           "vector::VectorTransposeLowering",
+           /*default=*/"vector::VectorTransposeLowering::EltWise",
+           VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
+           clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
+            "Lower transpose into element-wise extract and inserts (default)"),
+           clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
+            "Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
+           clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
+            "Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
+           clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
+            "Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
+          )}]>,
   ];
 }
 

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 6aeae30a0a6c0..601a65333d026 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
 
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
 namespace mlir {
@@ -47,7 +48,8 @@ namespace vector {
 /// Progressively lower a `vector.contract` with row-major matmul semantics to
 /// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
 void populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
+    RewritePatternSet &patterns,
+    VectorContractLowering vectorContractLoweringOption,
     PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
 
 /// Populate the pattern set with the following patterns:
@@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
 ///
 /// [TransposeOp2DToShuffleLowering]
 ///
-void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
-                                             VectorTransformsOptions options,
-                                             PatternBenefit benefit = 1);
+void populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransposeLowering vectorTransposeLowering,
+    PatternBenefit benefit = 1);
 
 /// Populate the pattern set with the following patterns:
 ///

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a81bd20212d..eb1555df5d574 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorToVectorCanonicalizationPatterns(patterns);
     populateVectorBitCastLoweringPatterns(patterns);
     populateVectorBroadcastLoweringPatterns(patterns);
-    populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
+    populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
     populateVectorInterleaveLoweringPatterns(patterns);
-    populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
+    populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c56dbcca2175d..a60410d01ac57 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
   // further transformations to canonicalize/cancel.
   {
     RewritePatternSet patterns(context);
-    auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
-        vector::VectorTransposeLowering::EltWise);
-    vector::populateVectorTransposeLoweringPatterns(patterns, options);
+    vector::populateVectorTransposeLoweringPatterns(
+        patterns, vector::VectorTransposeLowering::EltWise);
     vector::populateVectorShapeCastLoweringPatterns(patterns);
     if (failed(applyPatternsGreedily(op, std::move(patterns))))
       return failure();

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 241e83e234d62..20c577273d786 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
 
 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::VectorTransformsOptions vectorTransformOptions;
-  vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
-  populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
+  populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
                                          /*benefit=*/1,
                                          /*disableOuterProductLowering=*/true);
 }
@@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(
 
 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorTransposeLoweringPatterns(
-      patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
-                    getLoweringStrategy()));
+  vector::populateVectorTransposeLoweringPatterns(patterns,
+                                                  getLoweringStrategy());
   if (getAvx2LoweringStrategy()) {
     auto avx2LoweringOptions =
         x86vector::avx2::LoweringOptions().setTransposeOptions(

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648..c74d0622b3828 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -215,13 +215,13 @@ namespace {
 /// ```
 ///    %flattened_a = vector.shape_cast %a
 ///    %flattened_b = vector.shape_cast %b
-///    %flattened_d = vector.matmul %flattened_a, %flattened_b
+///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
 ///    %d = vector.shape_cast %%flattened_d
 ///    %e = add %c, %d
 /// ```
-/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
+/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
 //
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// This only kicks in when vectorContractLowering is set to Matmul and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToMatmulOpLowering
     : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
   }
 
   ContractionOpToMatmulOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
+        vectorContractLowering(vectorContractLowering),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -266,7 +266,7 @@ class ContractionOpToMatmulOpLowering
 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
 /// ```
 ///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// This only kicks in when vectorContractLowering is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToOuterProductOpLowering
     : public MaskableOpRewritePattern<vector::ContractionOp> {
@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
   }
 
   ContractionOpToOuterProductOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
+        vectorContractLowering(vectorContractLowering),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -329,11 +329,11 @@ class ContractionOpToDotLowering
   }
 
   ContractionOpToDotLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
 
   FailureOr<Value>
   matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -341,7 +341,7 @@ class ContractionOpToDotLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -370,11 +370,12 @@ class ContractionOpLowering
     return success();
   }
 
-  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
-                        MLIRContext *context, PatternBenefit benefit = 1,
-                        FilterConstraintType constraint = defaultFilter)
+  ContractionOpLowering(
+      vector::VectorContractLowering vectorContractLoweringOption,
+      MLIRContext *context, PatternBenefit benefit = 1,
+      FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
+        vectorContractLoweringOption(vectorContractLoweringOption),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLoweringOption;
   FilterConstraintType filter;
   // Lower one parallel dimension.
   FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
@@ -635,14 +636,13 @@ struct UnrolledOuterProductGenerator
 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
 /// ```
 ///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
+/// This only kicks in when vectorContractLowering is set to OuterProduct but
 /// otherwise supports any layout permutation of the matrix-multiply.
 FailureOr<Value>
 ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
     vector::ContractionOp op, MaskingOpInterface maskOp,
     PatternRewriter &rewriter) const {
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::OuterProduct)
+  if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
     return failure();
 
   if (failed(filter(op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
   if (failed(filter(op)))
     return failure();
 
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::Dot)
+  if (vectorContractLowering != vector::VectorContractLowering::Dot)
     return failure();
 
   auto iteratorTypes = op.getIteratorTypes().getValue();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
     return success();
   }
   ContractOpToElementwise(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
 
   FailureOr<Value>
   matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
     if (failed(filter(contractOp)))
       return failure();
 
-    if (vectorTransformOptions.vectorContractLowering !=
-        vector::VectorContractLowering::ParallelArith)
+    if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
       return failure();
 
     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -913,7 +911,7 @@ struct ContractOpToElementwise
 /// until a pure contraction is reached (no free/batch dimensions),
 /// which is replaced by a dot-product.
 ///
-/// This only kicks in when either VectorTransformsOptions is set
+/// This only kicks in when either vectorContractLoweringOption is set
 /// to DOT or when other contraction patterns fail.
 //
 // TODO: break down into transpose/reshape/cast ops
@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
 
-  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
+  ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal1 =
       pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal1))
     return newVal1;
 
-  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
+  ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal2 =
       pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal2))
     return newVal2;
 
-  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
+  ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal3 =
       pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal3))
     return newVal3;
 
-  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+  ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal4 =
       pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal4))
@@ -1273,14 +1271,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 ///    %mtb = maybe_transpose
 ///    %flattened_a = vector.shape_cast %mta
 ///    %flattened_b = vector.shape_cast %mtb
-///    %flattened_d = vector.matmul %flattened_a, %flattened_b
+///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
 ///    %mtd = vector.shape_cast %flattened_d
 ///    %d = maybe_untranspose %mtd
 ///    %e = add %c, %d
 /// ```
-/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
+/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
 //
-/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
+/// This only kicks in when vectorContractLowering is set to `Matmul`.
 /// vector.transpose operations are inserted if the vector.contract op is not a
 /// row-major matrix multiply.
 ///
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
   if (maskOp)
     return failure();
 
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::Matmul)
+  if (vectorContractLowering != vector::VectorContractLowering::Matmul)
     return failure();
   if (failed(filter(op)))
     return failure();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
 } // namespace
 
 void mlir::vector::populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
-    PatternBenefit benefit, bool disableOuterProductLowering) {
+    RewritePatternSet &patterns,
+    VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
+    bool disableOuterProductLowering) {
   if (!disableOuterProductLowering)
     patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
                ContractionOpToOuterProductOpLowering>(
-      options, patterns.getContext(), benefit);
+      vectorContractLoweringOption, patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorOuterProductLoweringPatterns(

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index fb4dee33bc5f5..732e316c93381 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -304,10 +304,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+  TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
                       MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern<vector::TransposeOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions) {}
+        vectorTransposeLowering(vectorTransposeLowering) {}
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
@@ -324,14 +324,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     // Set up convenience transposition table.
     ArrayRef<int64_t> transp = op.getPermutation();
 
-    if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
+    if (isShuffleLike(vectorTransposeLowering) &&
         succeeded(isTranspose2DSlice(op)))
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
     // Handle a true 2-D matrix transpose differently when requested.
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            vector::VectorTransposeLowering::Flat &&
+    if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
       Type flattenedType =
           VectorType::get(resType.getNumElements(), resType.getElementType());
@@ -380,7 +379,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
 /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
@@ -454,14 +453,14 @@ class TransposeOp2DToShuffleLowering
   using OpRewritePattern::OpRewritePattern;
 
   TransposeOp2DToShuffleLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorTransposeLowering vectorTransposeLowering,
       MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern<vector::TransposeOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions) {}
+        vectorTransposeLowering(vectorTransposeLowering) {}
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
+    if (!isShuffleLike(vectorTransposeLowering))
       return rewriter.notifyMatchFailure(
           op, "not using vector shuffle based lowering");
 
@@ -487,8 +486,7 @@ class TransposeOp2DToShuffleLowering
                                                           op.getVector());
 
     Value res;
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            VectorTransposeLowering::Shuffle16x16 &&
+    if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
         m == 16 && n == 16) {
       reshInput =
           rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
@@ -506,15 +504,15 @@ class TransposeOp2DToShuffleLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorTransposeLowering vectorTransposeLowering;
 };
 } // namespace
 
 void mlir::vector::populateVectorTransposeLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
-    PatternBenefit benefit) {
+    RewritePatternSet &patterns,
+    VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
   patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
                                                   benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
-      options, patterns.getContext(), benefit);
+      vectorTransposeLowering, patterns.getContext(), benefit);
 }

diff  --git a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
new file mode 100644
index 0000000000000..ebf06c57a1b3b
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
@@ -0,0 +1,30 @@
+// Ensure that ConvertVectorToLLVMPass options remain serialisable.
+
+// This test also allows us to exercise these options (to some extent) even if we
+// don't use them in other Vector to LLVM conversion tests. This is quite relevant
+// for the `Vector` Dialect (and `--convert-vector-to-llvm` pass) as in many cases
+// we use the Transform Dialect (TD) rather than `--convert-vector-to-llvm` for
+// testing. So here we don't check the correctness of the passes, as they're
+// covered by other tests that use TD, but we still provide some test coverage of
+// these pass options.
+
+// We don't need to actually parse any IR to print the pass options. We just need
+// to provide --dump-pass-pipeline
+
+// RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=DEFAULT
+
+// RUN: mlir-opt --convert-vector-to-llvm='vector-contract-lowering=matmul vector-transpose-lowering=flat' \
+// RUN:          --dump-pass-pipeline 2>&1 | FileCheck %s --check-prefix=NON-DEFAULT
+
+// CHECK: builtin.module(
+// CHECK-SAME: convert-vector-to-llvm{
+// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}}
+// CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}}
+// CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}}
+// CHECK-SAME: enable-x86vector={{[aA-zZ0-9]+}}
+// CHECK-SAME: force-32bit-vector-indices={{[aA-zZ0-9]+}}
+// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}}
+// DEFAULT: vector-contract-lowering=dot
+// DEFAULT: vector-transpose-lowering=eltwise
+// NON-DEFAULT: vector-contract-lowering=matmul
+// NON-DEFAULT: vector-transpose-lowering=flat


        


More information about the Mlir-commits mailing list