[Mlir-commits] [mlir] d054b80 - [mlir][Vector] NFC - Add option to hook vector.transpose lowering to strategies.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Oct 25 05:27:42 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-25T12:26:33Z
New Revision: d054b80bd3ab1a78d1a870f941024429273d2a83

URL: https://github.com/llvm/llvm-project/commit/d054b80bd3ab1a78d1a870f941024429273d2a83
DIFF: https://github.com/llvm/llvm-project/commit/d054b80bd3ab1a78d1a870f941024429273d2a83.diff

LOG: [mlir][Vector] NFC - Add option to hook vector.transpose lowering to strategies.

This revision also moves some code around to improve overall structure.

Differential Revision: https://reviews.llvm.org/D112437

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 640e1221aeb53..cfa38d71c2ba3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -15,7 +15,7 @@
 #include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/Bufferize.h"
@@ -846,6 +846,9 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
       : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
 };
 
+//===----------------------------------------------------------------------===//
+// Transformation and lowering options exposed as auxiliary structs.
+//===----------------------------------------------------------------------===//
 /// Options to control the application of enabling transformations.
 /// Hoisting transformations are always deemed beneficial and must be disabled
 /// explicitly.
@@ -887,10 +890,16 @@ struct LinalgVectorLoweringOptions {
     transferLowering = val;
     return *this;
   }
-  /// Trigger full / partial vector.transfer splits.
-  bool transferPartialRewrite = false;
-  LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
-    transferPartialRewrite = val;
+  /// Enable lowering of vector.transpose.
+  bool transposeLowering = false;
+  LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
+    transposeLowering = val;
+    return *this;
+  }
+  /// Enable lowering of vector.multi_reduction.
+  bool multiReductionLowering = false;
+  LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
+    multiReductionLowering = val;
     return *this;
   }
   /// Enable lowering of vector.contract.
@@ -899,10 +908,10 @@ struct LinalgVectorLoweringOptions {
     contractionLowering = val;
     return *this;
   }
-  /// Enable lowering of vector.multi_reduce.
-  bool multiReductionLowering = false;
-  LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
-    multiReductionLowering = val;
+  /// Enable rewriting vector.transfer ops into full / partial splits.
+  bool transferPartialRewrite = false;
+  LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
+    transferPartialRewrite = val;
     return *this;
   }
   /// Enable lowering of vector.transfer to scf.
@@ -911,13 +920,6 @@ struct LinalgVectorLoweringOptions {
     transferToSCFConversion = val;
     return *this;
   }
-  /// Configure late vector transformations.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  LinalgVectorLoweringOptions &
-  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
-    vectorTransformOptions = options;
-    return *this;
-  }
   /// Configure the post staged-patterns late vector.transfer to scf
   /// conversion.
   VectorTransferToSCFOptions vectorTransferToSCFOptions;
@@ -926,8 +928,18 @@ struct LinalgVectorLoweringOptions {
     vectorTransferToSCFOptions = options;
     return *this;
   }
+  /// Configure late vector transformations (vector::VectorTransformsOptions).
+  vector::VectorTransformsOptions vectorTransformOptions;
+  LinalgVectorLoweringOptions &
+  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
+    vectorTransformOptions = options;
+    return *this;
+  }
 };
 
+//===----------------------------------------------------------------------===//
+// Transformations exposed as rewrite patterns.
+//===----------------------------------------------------------------------===//
 /// Trait to check if T provides a `getOperationName` method.
 template <typename T, typename... Args>
 using has_get_operation_name = decltype(T::getOperationName());

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index dd56cd1ea1926..c6f4ba4bc0e59 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -40,76 +40,6 @@ namespace detail {
 struct BitmaskEnumStorage;
 } // namespace detail
 
-/// Enum to control the lowering of `vector.contract` operations.
-enum class VectorContractLowering {
-  /// Progressively lower to finer grained `vector.contract` and dot-products.
-  Dot = 0,
-  /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
-  Matmul = 1,
-  /// Lower to `vector.outerproduct`.
-  OuterProduct = 2,
-};
-/// Enum to control the lowering of `vector.multi_reduction` operations.
-enum class VectorMultiReductionLowering {
-  /// Lower multi_reduction into outer-reduction and inner-parallel ops.
-  InnerParallel = 0,
-  /// Lower multi_reduction into outer-parallel and inner-reduction ops.
-  InnerReduction = 1,
-};
-/// Enum to control the lowering of `vector.transpose` operations.
-enum class VectorTransposeLowering {
-  /// Lower transpose into element-wise extract and inserts.
-  EltWise = 0,
-  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
-  /// intrinsics.
-  Flat = 1,
-};
-/// Enum to control the splitting of `vector.transfer` operations into
-/// in-bounds and out-of-bounds variants.
-enum class VectorTransferSplit {
-  /// Do not split vector transfer operations.
-  None = 0,
-  /// Split using in-bounds + out-of-bounds vector.transfer operations.
-  VectorTransfer = 1,
-  /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
-  /// operations.
-  LinalgCopy = 2,
-  /// Do not split vector transfer operation but instead mark it as "in-bounds".
-  ForceInBounds = 3
-};
-/// Structure to control the behavior of vector transform patterns.
-struct VectorTransformsOptions {
-  /// Option to control the lowering of vector.contract.
-  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
-  VectorTransformsOptions &
-  setVectorTransformsOptions(VectorContractLowering opt) {
-    vectorContractLowering = opt;
-    return *this;
-  }
-  /// Option to control the lowering of vector.multi_reduction.
-  VectorMultiReductionLowering vectorMultiReductionLowering =
-      VectorMultiReductionLowering::InnerParallel;
-  VectorTransformsOptions &
-  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
-    vectorMultiReductionLowering = opt;
-    return *this;
-  }
-  /// Option to control the lowering of vector.transpose.
-  VectorTransposeLowering vectorTransposeLowering =
-      VectorTransposeLowering::EltWise;
-  VectorTransformsOptions &
-  setVectorTransposeLowering(VectorTransposeLowering opt) {
-    vectorTransposeLowering = opt;
-    return *this;
-  }
-  /// Option to control the splitting of vector transfers.
-  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
-  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
-    vectorTransferSplit = opt;
-    return *this;
-  }
-};
-
 /// Return whether `srcType` can be broadcast to `dstVectorType` under the
 /// semantics of the `vector.broadcast` op.
 enum class BroadcastableToResult {
@@ -161,33 +91,6 @@ void populateVectorTransferPermutationMapLoweringPatterns(
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                bool enableIndexOptimizations);
 
-/// Collect a set of patterns to convert vector.multi_reduction op into
-/// a sequence of vector.reduction ops. The patterns comprise:
-/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
-/// that all reduction dimensions are either innermost or outermost, by adding
-/// the proper vector.transpose operations.
-/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
-/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
-/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
-/// back.
-/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
-/// form, with an **outermost** reduction dimension, unroll the outer dimension
-/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
-/// tree-reduction (in the future).
-/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
-/// with an **innermost** reduction dimension, unroll the outer dimension to
-/// obtain a sequence of extract + vector.reduction + insert. This can further
-/// lower to horizontal reduction ops.
-/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
-/// reduction (and are thus missing either a parallel or a reduction), we lift
-/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
-/// the other patterns can kick in, thus fully exiting out of the
-/// vector.multi_reduction abstraction.
-void populateVectorMultiReductionLoweringPatterns(
-    RewritePatternSet &patterns,
-    VectorMultiReductionLowering options =
-        vector::VectorMultiReductionLowering::InnerParallel);
-
 /// Collect a set of patterns to propagate insert_map/extract_map in the ssa
 /// chain.
 void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
@@ -212,12 +115,6 @@ class CombiningKindAttr
 /// vectors to low-D vector ops.
 void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
 
-/// Collects patterns to progressively lower vector contraction ops on high-D
-/// into low-D reduction and product ops.
-void populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns,
-    VectorTransformsOptions options = VectorTransformsOptions());
-
 /// Collects patterns to progressively lower vector mask ops into elementary
 /// selection and insertion ops.
 void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
@@ -227,15 +124,6 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
 /// ops.
 void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
 
-/// Insert TransposeLowering patterns into extraction/insertion.
-void populateVectorTransposeLoweringPatterns(
-    RewritePatternSet &patterns,
-    VectorTransformsOptions options = VectorTransformsOptions());
-
-/// Collect patterns to convert reduction op to vector.contract and fold
-/// transpose/broadcast ops into the contract.
-void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);
-
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
index 13b310713f7b5..47375c56673f8 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -9,11 +9,173 @@
 #ifndef DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
 #define DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
 
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
 namespace mlir {
 class RewritePatternSet;
 
 namespace vector {
 
+//===----------------------------------------------------------------------===//
+// Vector transformation options exposed as auxiliary structs.
+//===----------------------------------------------------------------------===//
+/// Enum to control the lowering of `vector.transpose` operations.
+enum class VectorTransposeLowering {
+  /// Lower transpose into element-wise extract and insert ops.
+  EltWise = 0,
+  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
+  /// intrinsics.
+  Flat = 1,
+};
+/// Enum to control the lowering of `vector.multi_reduction` operations.
+enum class VectorMultiReductionLowering {
+  /// Lower multi_reduction into outer-reduction and inner-parallel ops.
+  InnerParallel = 0,
+  /// Lower multi_reduction into outer-parallel and inner-reduction ops.
+  InnerReduction = 1,
+};
+/// Enum to control the lowering of `vector.contract` operations.
+enum class VectorContractLowering {
+  /// Progressively lower to finer-grained `vector.contract` and dot-products.
+  Dot = 0,
+  /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
+  Matmul = 1,
+  /// Lower to `vector.outerproduct`.
+  OuterProduct = 2,
+};
+/// Enum to control the splitting of `vector.transfer` operations into
+/// in-bounds and out-of-bounds variants.
+enum class VectorTransferSplit {
+  /// Do not split vector transfer operations.
+  None = 0,
+  /// Split using in-bounds + out-of-bounds vector.transfer operations.
+  VectorTransfer = 1,
+  /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
+  /// operations.
+  LinalgCopy = 2,
+  /// Do not split vector transfer operations; instead mark them "in-bounds".
+  ForceInBounds = 3
+};
+/// Bundle of options steering the behavior of vector transform patterns.
+struct VectorTransformsOptions {
+  /// vector.contract lowering; note the setter is setVectorTransformsOptions.
+  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
+  VectorTransformsOptions &
+  setVectorTransformsOptions(VectorContractLowering lowering) {
+    vectorContractLowering = lowering;
+    return *this;
+  }
+  /// vector.multi_reduction lowering strategy.
+  VectorMultiReductionLowering vectorMultiReductionLowering =
+      VectorMultiReductionLowering::InnerParallel;
+  VectorTransformsOptions &
+  setVectorMultiReductionLowering(VectorMultiReductionLowering lowering) {
+    vectorMultiReductionLowering = lowering;
+    return *this;
+  }
+  /// vector.transpose lowering strategy.
+  VectorTransposeLowering vectorTransposeLowering =
+      VectorTransposeLowering::EltWise;
+  VectorTransformsOptions &
+  setVectorTransposeLowering(VectorTransposeLowering lowering) {
+    vectorTransposeLowering = lowering;
+    return *this;
+  }
+  /// Whether and how vector transfers are split.
+  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
+  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit split) {
+    vectorTransferSplit = split;
+    return *this;
+  }
+};
+
+/// Configuration for the vector unrolling patterns.
+struct UnrollVectorOptions {
+  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+  /// Per-operation predicate deciding whether unrolling should be attempted.
+  /// Defaults to null (unset).
+  FilterConstraintFnType filterConstraint = nullptr;
+  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
+    filterConstraint = std::move(constraint);
+    return *this;
+  }
+
+  using NativeShapeFnType =
+      std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
+  /// Callback yielding the target (native) vector shape for an operation;
+  /// returning `llvm::None` aborts unrolling of that operation.
+  NativeShapeFnType nativeShape = nullptr;
+  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
+    nativeShape = std::move(fn);
+    return *this;
+  }
+
+  /// Use one fixed native shape, identical for every operation.
+  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
+    SmallVector<int64_t, 4> dims(shape.begin(), shape.end());
+    nativeShape = [dims](Operation *) -> Optional<SmallVector<int64_t, 4>> {
+      return dims;
+    };
+    return *this;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Vector transformation exposed as populate functions over rewrite patterns.
+//===----------------------------------------------------------------------===//
+
+/// Collect patterns lowering vector.transpose ops, e.g. into extract/insert.
+void populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransformsOptions options = VectorTransformsOptions());
+
+/// Collect a set of patterns to convert vector.multi_reduction op into
+/// a sequence of vector.reduction ops. The patterns comprise:
+/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
+/// that all reduction dimensions are either innermost or outermost, by adding
+/// the proper vector.transpose operations.
+/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
+/// form, with an **outermost** reduction dimension, unroll the outer dimension
+/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
+/// tree-reduction (in the future).
+/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
+/// with an **innermost** reduction dimension, unroll the outer dimension to
+/// obtain a sequence of extract + vector.reduction + insert. This can further
+/// lower to horizontal reduction ops.
+/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
+/// reduction (and are thus missing either a parallel or a reduction), we lift
+/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
+/// the other patterns can kick in, thus fully exiting out of the
+/// vector.multi_reduction abstraction.
+void populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorMultiReductionLowering options =
+        VectorMultiReductionLowering::InnerParallel);
+
+/// Collects patterns to progressively lower vector contraction ops on high-D
+/// into low-D reduction and product ops.
+void populateVectorContractLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransformsOptions options = VectorTransformsOptions());
+
+/// Collect patterns to convert reduction op to vector.contract and fold
+/// transpose/broadcast ops into the contract.
+void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
+
+/// Collect a set of patterns to reduce the rank of the operands of vector
+/// transfer ops to operate on the largest contiguous vector.
+/// These patterns are useful when lowering to dialects with 1-D vector types,
+/// such as LLVM, and result in fewer memory reads.
+void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
+    RewritePatternSet &patterns);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
@@ -52,6 +214,235 @@ namespace vector {
 void populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns);
 
+/// Collect a set of patterns to unroll vector operations to smaller shapes.
+/// `options` structure controls which operations are unrolled and the target
+/// shape.
+/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
+///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
+///   `numUnrolledInstances` are computed from the `targetShape`. For now it is
+///   assumed the unrolling factors divide the vector sizes.
+///   2. ExtractStridedSlice are created to break-up the vector operands.
+///   3. the original op is cloned `numUnrolledInstances` times, once for each
+///   result.
+///   4. InsertStridedSlice are inserted to re-assemble the slices into the
+///   original vector shape.
+///
+/// Example:
+///
+///    opA(operand0, operand1)  // numUnrolledInstances = 3
+///
+///            operand0                   operand1
+///               |                          |
+///             fork                       fork
+///        <----------gather all fork ops --------->
+///              /|\                        /|\      (3 slices each)
+///          f00 f01 f02                f10 f11 f12
+///        <---------- clone op 3 times --------->
+///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
+///                 \            |            /
+///      <-------------------- join ------------------------->
+///
+/// Other local patterns then kick in iteratively (including DCE) and compose
+/// to combine the ExtractStridedSlice/InsertStridedSlice.
+void populateVectorUnrollPatterns(RewritePatternSet &patterns,
+                                  const UnrollVectorOptions &options);
+
+//===----------------------------------------------------------------------===//
+// Finer-grained patterns exposed for more control over individual lowerings.
+//===----------------------------------------------------------------------===//
+/// Pattern wrapper around `splitFullAndPartialTransfer`. The optional
+/// `filter` callback refines the selection; the default accepts every op.
+struct VectorTransferFullPartialRewriter : public RewritePattern {
+  using FilterConstraintType =
+      std::function<LogicalResult(VectorTransferOpInterface op)>;
+
+  explicit VectorTransferFullPartialRewriter(
+      MLIRContext *context,
+      VectorTransformsOptions options = VectorTransformsOptions(),
+      FilterConstraintType filter =
+          [](VectorTransferOpInterface) -> LogicalResult { return success(); },
+      PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+        options(options), filter(std::move(filter)) {}
+
+  /// Performs the rewrite; defined out of line.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  VectorTransformsOptions options;
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %flattened_a = vector.shape_cast %a
+///    %flattened_b = vector.shape_cast %b
+///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
+///    %d = vector.shape_cast %flattened_d
+///    %e = add %c, %d
+/// ```
+/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
+///
+/// This only kicks in when VectorTransformsOptions is set to Matmul and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToMatmulOpLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  ContractionOpToMatmulOpLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
+};
+
+/// Lowers a `vector.contract %a, %b, %c` that has row-major matmul semantics
+/// into a sequence unrolled along the reduction dimension:
+/// ```
+///    %at = vector.transpose %a, [1, 0]
+///    %bRow0 = vector.extract %b[0]
+///    %atRow0 = vector.extract %at[0]
+///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
+///    ...
+///    %bRowK = vector.extract %b[K]
+///    %atRowK = vector.extract %at[K]
+///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// Applies only when VectorTransformsOptions selects OuterProduct and the
+/// vector.contract op is a row-major matrix multiply.
+class ContractionOpToOuterProductOpLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+  /// Filter used when the caller supplies none: accepts every contraction.
+  static LogicalResult defaultFilter(vector::ContractionOp /*op*/) {
+    return success();
+  }
+
+  ContractionOpToOuterProductOpLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Lowering options this pattern was configured with.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to an output-size-unrolled sequence:
+/// ```
+///    %out = arith.constant ... : vector<MxNxelt_type>
+///    %bt = vector.transpose %b, [1, 0]
+///    %aRow0 = vector.extract %a[0]
+///    %btRow0 = vector.extract %bt[0]
+///    %c00 = vector.reduce %aRow0, %btRow0
+///    %out00 = vector.insert %c00, %out[0, 0]
+///    ...
+///    %aRowLast = vector.extract %a[M-1]
+///    %btRowLast = vector.extract %bt[N-1]
+///    %cLastLast = vector.reduce %aRowLast, %btRowLast
+///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to Dot and
+/// the vector.contract op is a row-major matmul or matvec.
+class ContractionOpToDotLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+  /// Keeps the caller's `constraint` as `filter`, as the sibling patterns do.
+  ContractionOpToDotLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of ContractionOp.
+///
+/// Each step rewrites one
+///   %x = vector.contract with at least one free/batch dimension
+/// into:
+///   %a = vector.contract with one less free/batch dimension
+///   %b = vector.contract with one less free/batch dimension
+///   ..
+///   %x = combine %a %b ..
+/// Steps repeat until a pure contraction is reached (no free/batch
+/// dimensions), which is then replaced by a dot-product.
+///
+/// Applies when VectorTransformsOptions selects Dot, or when the other
+/// contraction patterns fail.
+class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+  /// Filter used when the caller supplies none: accepts every contraction.
+  static LogicalResult defaultFilter(vector::ContractionOp /*op*/) {
+    return success();
+  }
+
+  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+                        MLIRContext *context,
+                        FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Lowering options this pattern was configured with.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
+  /// Lowers one parallel dimension.
+  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
+                      int64_t rhsIndex, PatternRewriter &rewriter) const;
+  /// Lowers one reduction dimension.
+  Value lowerReduction(vector::ContractionOp op,
+                       PatternRewriter &rewriter) const;
+};
+
 } // namespace vector
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index d26636c132ac4..811d72192910e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -9,10 +9,8 @@
 #ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
 #define DIALECT_VECTOR_VECTORTRANSFORMS_H_
 
-#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 class MLIRContext;
@@ -26,77 +24,9 @@ class IfOp;
 
 namespace vector {
 
-/// Options that control the vector unrolling.
-struct UnrollVectorOptions {
-  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
-  /// Callback function that indicates whether vector unrolling should be
-  /// attempted on the operation.
-  FilterConstraintFnType filterConstraint = nullptr;
-  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
-    filterConstraint = constraint;
-    return *this;
-  }
-
-  using NativeShapeFnType =
-      std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
-  /// Function that returns the shape of the vector to unroll to for a given
-  /// operation. The unrolling is aborted if the function returns `llvm::None`.
-  NativeShapeFnType nativeShape = nullptr;
-  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
-    nativeShape = fn;
-    return *this;
-  }
-
-  /// Set the native shape to use for unrolling.
-  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
-    SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
-    nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
-      return tsShape;
-    };
-    return *this;
-  }
-};
-
-/// Collect a set of pattern to unroll vector operations to a smaller shapes.
-/// `options` structure controls which operations are unrolled and the target
-/// shape.
-/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
-///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
-///   `numUnrolledInstances` are computed from the `targetShape`. For now it is
-///   assumed the unrolling factors divide the vector sizes.
-///   2. ExtractStridedSlice are created to break-up the vector operands.
-///   3. the original op is cloned `numUnrolledInstances` times, once for each
-///   result.
-///   4. InsertStridedSlice are inserted to re-assemble the slices into the
-///   original vectore shape.
-///
-/// Example:
-///
-///    opA(operand0, operand1)  // numUnrolledInstances = 3
-///
-///            operand0                   operand1
-///               |                          |
-///             fork                       fork
-///        <----------gather all fork ops --------->
-///              /|\                        /|\
-///          f00 f01 f02                f10 f11 f12
-///        <---------- clone op 3 times --------->
-///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
-///                 \            |            /
-///      <-------------------- join ------------------------->
-///
-/// Other local patterns then kick in iteratively (including DCE) and compose
-/// to combine the ExtractStridedSlice/InsertStridedSlice.
-void populateVectorUnrollPatterns(RewritePatternSet &patterns,
-                                  const UnrollVectorOptions &options);
-
-/// Collect a set of patterns to reduce the rank of the operands of vector
-/// transfer ops to operate on the largest contigious vector.
-/// These patterns are useful when lowering to dialects with 1d vector type
-/// such as llvm and it will result fewer memory reads.
-void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
-    RewritePatternSet &patterns);
-
+//===----------------------------------------------------------------------===//
+// Standalone transformations and helpers.
+//===----------------------------------------------------------------------===//
 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
 /// masking) fastpath and a slowpath.
 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
@@ -130,37 +60,11 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 ///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
 ///  must be equal. This will be relaxed in the future but requires
 ///  rank-reducing subviews.
-LogicalResult
-splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
 LogicalResult splitFullAndPartialTransfer(
     OpBuilder &b, VectorTransferOpInterface xferOp,
     VectorTransformsOptions options = VectorTransformsOptions(),
     scf::IfOp *ifOp = nullptr);
 
-/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
-/// may take an extra filter to perform selection at a finer granularity.
-struct VectorTransferFullPartialRewriter : public RewritePattern {
-  using FilterConstraintType =
-      std::function<LogicalResult(VectorTransferOpInterface op)>;
-
-  explicit VectorTransferFullPartialRewriter(
-      MLIRContext *context,
-      VectorTransformsOptions options = VectorTransformsOptions(),
-      FilterConstraintType filter =
-          [](VectorTransferOpInterface op) { return success(); },
-      PatternBenefit benefit = 1)
-      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
-        filter(filter) {}
-
-  /// Performs the rewrite.
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  VectorTransformsOptions options;
-  FilterConstraintType filter;
-};
-
 struct DistributeOps {
   ExtractMapOp extract;
   InsertMapOp insert;
@@ -188,180 +92,6 @@ distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
 void transferOpflowOpt(FuncOp func);
 
 } // namespace vector
-
-//===----------------------------------------------------------------------===//
-// Finer-grained patterns exposed for more control over individual lowerings.
-//===----------------------------------------------------------------------===//
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
-/// ```
-///    %flattened_a = vector.shape_cast %a
-///    %flattened_b = vector.shape_cast %b
-///    %flattened_d = vector.matmul %flattened_a, %flattened_b
-///    %d = vector.shape_cast %%flattened_d
-///    %e = add %c, %d
-/// ```
-/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
-//
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-class ContractionOpToMatmulOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
-  ContractionOpToMatmulOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to a reduction_size-unrolled sequence:
-/// ```
-///    %at = vector.transpose %a, [1, 0]
-///    %bRow0 = vector.extract %b[0]
-///    %atRow0 = vector.extract %at[0]
-///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
-///    ...
-///    %bRowK = vector.extract %b[K]
-///    %atRowK = vector.extract %at[K]
-///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-class ContractionOpToOuterProductOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
-  ContractionOpToOuterProductOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to an output-size-unrolled sequence:
-/// ```
-///    %out = arith.constant ... : vector<MxNxelt_type>
-///    %bt = vector.transpose %b, [1, 0]
-///    %aRow0 = vector.extract %a[0]
-///    %btRow0 = vector.extract %bt[0]
-///    %c00 = vector.reduce %atRow0, %bRow0
-///    %out00 = vector.insert %c00, %out[0, 0]
-///    ...
-///    %aRowLast = vector.extract %at[M-1]
-///    %btRowLast = vector.extract %b[N-1]
-///    %cLastLast = vector.reduce %atRowLast, %bRowLast
-///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to Dot and
-/// the vector.contract op is a row-major matmul or matvec.
-class ContractionOpToDotLowering
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
-  ContractionOpToDotLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
-      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  FilterConstraintType filter;
-};
-
-/// Progressive lowering of ContractionOp.
-///
-/// One:
-///   %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-///   %a = vector.contract with one less free/batch dimension
-///   %b = vector.contract with one less free/batch dimension
-///   ..
-///   %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product.
-///
-/// This only kicks in when either VectorTransformsOptions is set
-/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
-  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
-                        MLIRContext *context,
-                        FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  FilterConstraintType filter;
-  // Lower one parallel dimension.
-  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
-                      int64_t rhsIndex, PatternRewriter &rewriter) const;
-  // Lower one reduction dimension.
-  Value lowerReduction(vector::ContractionOp op,
-                       PatternRewriter &rewriter) const;
-};
-
 } // namespace mlir
 
 #endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 77d2a46977172..1046c37588f69 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -14,8 +14,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 583ba4a13eb08..7d75f11d0e3d1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -21,7 +21,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index b7506eb91aa9a..97831866ffd08 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -22,7 +22,6 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -32,6 +31,7 @@
 #include "mlir/Transforms/Utils.h"
 
 using namespace mlir;
+using namespace mlir::vector;
 using namespace linalg;
 
 namespace {
@@ -191,7 +191,7 @@ struct LinalgStrategyVectorizePass
     }
     vector::populateVectorTransferPermutationMapLoweringPatterns(
         vectorizationPatterns);
-    vector::populateVetorReductionToContractPatterns(vectorizationPatterns);
+    vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
                               linalg::LinalgCopyVTWForwardingPattern>(
         funcOp.getContext(), /*benefit=*/2);
@@ -268,9 +268,14 @@ struct LinalgStrategyLowerVectorsPass
       vector::populateVectorTransferLoweringPatterns(patterns,
                                                      options.maxTransferRank);
     }
-    if (options.transferPartialRewrite) {
-      patterns.add<vector::VectorTransferFullPartialRewriter>(
-          context, options.vectorTransformOptions);
+    if (options.transposeLowering) {
+      vector::populateVectorTransposeLoweringPatterns(
+          patterns, options.vectorTransformOptions);
+    }
+    if (options.multiReductionLowering) {
+      vector::populateVectorMultiReductionLoweringPatterns(
+          patterns,
+          options.vectorTransformOptions.vectorMultiReductionLowering);
     }
     if (options.contractionLowering) {
       patterns.add<ContractionOpToOuterProductOpLowering,
@@ -278,15 +283,15 @@ struct LinalgStrategyLowerVectorsPass
           options.vectorTransformOptions, context);
       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
     }
-    if (options.multiReductionLowering) {
-      vector::populateVectorMultiReductionLoweringPatterns(
-          patterns,
-          options.vectorTransformOptions.vectorMultiReductionLowering);
+    if (options.transferPartialRewrite) {
+      patterns.add<vector::VectorTransferFullPartialRewriter>(
+          context, options.vectorTransformOptions);
     }
     if (options.transferToSCFConversion) {
       populateVectorToSCFConversionPatterns(patterns,
                                             options.vectorTransferToSCFOptions);
     }
+    vector::populateVectorShapeCastLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   }
 

diff  --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
index 67d0db4d2cd45..637c8729f06f6 100644
--- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -10,14 +10,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index d98fa705dbf62..4c7ef516fd927 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -21,21 +21,10 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 
-#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
-#include "mlir/Dialect/Vector/VectorUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
-#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Types.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 
 #include "llvm/ADT/DenseSet.h"
@@ -48,6 +37,7 @@
 #define DEBUG_TYPE "vector-to-vector"
 
 using namespace mlir;
+using namespace mlir::vector;
 
 // Helper to find an index in an affine map.
 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
@@ -1978,9 +1968,41 @@ static Value createInBoundsCond(OpBuilder &b,
   });
   return inBoundsCond;
 }
-
-LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
-    VectorTransferOpInterface xferOp) {
+/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
+/// masking) fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+/// At this time, only vector.transfer_read case is implemented.
+///
+/// Example (a 2-D vector.transfer_read):
+/// ```
+///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      // fastpath, direct cast
+///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      // slowpath, not in-bounds vector.transfer or linalg.copy.
+///      memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
+///    }
+///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// ```
+/// where `alloc` is a buffer of one vector, alloca'ed at the top of the function.
+///
+/// Preconditions:
+///  1. `xferOp.permutation_map()` must be a minor identity map
+///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+///  must be equal. This will be relaxed in the future but requires
+///  rank-reducing subviews.
+static LogicalResult
+splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
   // TODO: expand support to these 2 cases.
   if (!xferOp.permutation_map().isMinorIdentity())
     return failure();
@@ -3863,7 +3885,7 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
   patterns.add<TransposeOpLowering>(options, patterns.getContext());
 }
 
-void mlir::vector::populateVetorReductionToContractPatterns(
+void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
                CombineContractTranspose>(patterns.getContext());

diff  --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
index f73cd6c0a1eb0..9c8f138743dec 100644
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -98,9 +98,8 @@ void TestConvVectorization::runOnOperation() {
       VectorTransposeLowering::EltWise};
 
   RewritePatternSet vectorTransferPatterns(context);
-  // Pattern is not applied because rank-reducing vector transfer is not yet
-  // supported as can be seen in splitFullAndPartialTransferPrecondition,
-  // VectorTransforms.cpp
+  // Pattern is not applied: rank-reducing vector transfer is not yet supported
+  // (see: splitFullAndPartialTransferPrecondition in VectorTransforms.cpp).
   vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
       context, vectorTransformOptions);
   (void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index b1468083f52df..1e8620aee75ec 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -536,7 +536,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
     RewritePatternSet canonicalizationPatterns(funcOp.getContext());
     vector::populateVectorTransferPermutationMapLoweringPatterns(
         canonicalizationPatterns);
-    vector::populateVetorReductionToContractPatterns(canonicalizationPatterns);
+    vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
     stage1Patterns.push_back(std::move(canonicalizationPatterns));
   }
   SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index b95c45d21633a..e7d520bcdb173 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -14,13 +14,13 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::vector;
+
 namespace {
 
 struct TestVectorToVectorConversion
@@ -511,7 +511,7 @@ struct TestVectorReduceToContractPatternsPatterns
   }
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
-    populateVetorReductionToContractPatterns(patterns);
+    populateVectorReductionToContractPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };


        


More information about the Mlir-commits mailing list