[Mlir-commits] [mlir] d054b80 - [mlir][Vector] NFC - Add option to hook vector.transpose lowering to strategies.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Oct 25 05:27:42 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-25T12:26:33Z
New Revision: d054b80bd3ab1a78d1a870f941024429273d2a83
URL: https://github.com/llvm/llvm-project/commit/d054b80bd3ab1a78d1a870f941024429273d2a83
DIFF: https://github.com/llvm/llvm-project/commit/d054b80bd3ab1a78d1a870f941024429273d2a83.diff
LOG: [mlir][Vector] NFC - Add option to hook vector.transpose lowering to strategies.
This revision also moves some code around to improve overall structure.
Differential Revision: https://reviews.llvm.org/D112437
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 640e1221aeb53..cfa38d71c2ba3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -15,7 +15,7 @@
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Bufferize.h"
@@ -846,6 +846,9 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
: LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
};
+//===----------------------------------------------------------------------===//
+// Transformation and lowering options exposed as auxiliary structs.
+//===----------------------------------------------------------------------===//
/// Options to control the application of enabling transformations.
/// Hoisting transformations are always deemed beneficial and must be disabled
/// explicitly.
@@ -887,10 +890,16 @@ struct LinalgVectorLoweringOptions {
transferLowering = val;
return *this;
}
- /// Trigger full / partial vector.transfer splits.
- bool transferPartialRewrite = false;
- LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
- transferPartialRewrite = val;
+ /// Enable lowering of vector.transpose.
+ bool transposeLowering = false;
+ LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
+ transposeLowering = val;
+ return *this;
+ }
+ /// Enable lowering of vector.multi_reduction.
+ bool multiReductionLowering = false;
+ LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
+ multiReductionLowering = val;
return *this;
}
/// Enable lowering of vector.contract.
@@ -899,10 +908,10 @@ struct LinalgVectorLoweringOptions {
contractionLowering = val;
return *this;
}
- /// Enable lowering of vector.multi_reduce.
- bool multiReductionLowering = false;
- LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
- multiReductionLowering = val;
+ /// Trigger full / partial vector.transfer splits.
+ bool transferPartialRewrite = false;
+ LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
+ transferPartialRewrite = val;
return *this;
}
/// Enable lowering of vector.transfer to scf.
@@ -911,13 +920,6 @@ struct LinalgVectorLoweringOptions {
transferToSCFConversion = val;
return *this;
}
- /// Configure late vector transformations.
- vector::VectorTransformsOptions vectorTransformOptions;
- LinalgVectorLoweringOptions &
- setVectorTransformsOptions(vector::VectorTransformsOptions options) {
- vectorTransformOptions = options;
- return *this;
- }
/// Configure the post staged-patterns late vector.transfer to scf
/// conversion.
VectorTransferToSCFOptions vectorTransferToSCFOptions;
@@ -926,8 +928,18 @@ struct LinalgVectorLoweringOptions {
vectorTransferToSCFOptions = options;
return *this;
}
+ /// Configure late vector transformations.
+ vector::VectorTransformsOptions vectorTransformOptions;
+ LinalgVectorLoweringOptions &
+ setVectorTransformsOptions(vector::VectorTransformsOptions options) {
+ vectorTransformOptions = options;
+ return *this;
+ }
};
+//===----------------------------------------------------------------------===//
+// Transformations exposed as rewrite patterns.
+//===----------------------------------------------------------------------===//
/// Trait to check if T provides a `getOperationName` method.
template <typename T, typename... Args>
using has_get_operation_name = decltype(T::getOperationName());
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index dd56cd1ea1926..c6f4ba4bc0e59 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -40,76 +40,6 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail
-/// Enum to control the lowering of `vector.contract` operations.
-enum class VectorContractLowering {
- /// Progressively lower to finer grained `vector.contract` and dot-products.
- Dot = 0,
- /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
- Matmul = 1,
- /// Lower to `vector.outerproduct`.
- OuterProduct = 2,
-};
-/// Enum to control the lowering of `vector.multi_reduction` operations.
-enum class VectorMultiReductionLowering {
- /// Lower multi_reduction into outer-reduction and inner-parallel ops.
- InnerParallel = 0,
- /// Lower multi_reduction into outer-parallel and inner-reduction ops.
- InnerReduction = 1,
-};
-/// Enum to control the lowering of `vector.transpose` operations.
-enum class VectorTransposeLowering {
- /// Lower transpose into element-wise extract and inserts.
- EltWise = 0,
- /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
- /// intrinsics.
- Flat = 1,
-};
-/// Enum to control the splitting of `vector.transfer` operations into
-/// in-bounds and out-of-bounds variants.
-enum class VectorTransferSplit {
- /// Do not split vector transfer operations.
- None = 0,
- /// Split using in-bounds + out-of-bounds vector.transfer operations.
- VectorTransfer = 1,
- /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
- /// operations.
- LinalgCopy = 2,
- /// Do not split vector transfer operation but instead mark it as "in-bounds".
- ForceInBounds = 3
-};
-/// Structure to control the behavior of vector transform patterns.
-struct VectorTransformsOptions {
- /// Option to control the lowering of vector.contract.
- VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
- VectorTransformsOptions &
- setVectorTransformsOptions(VectorContractLowering opt) {
- vectorContractLowering = opt;
- return *this;
- }
- /// Option to control the lowering of vector.multi_reduction.
- VectorMultiReductionLowering vectorMultiReductionLowering =
- VectorMultiReductionLowering::InnerParallel;
- VectorTransformsOptions &
- setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
- vectorMultiReductionLowering = opt;
- return *this;
- }
- /// Option to control the lowering of vector.transpose.
- VectorTransposeLowering vectorTransposeLowering =
- VectorTransposeLowering::EltWise;
- VectorTransformsOptions &
- setVectorTransposeLowering(VectorTransposeLowering opt) {
- vectorTransposeLowering = opt;
- return *this;
- }
- /// Option to control the splitting of vector transfers.
- VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
- VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
- vectorTransferSplit = opt;
- return *this;
- }
-};
-
/// Return whether `srcType` can be broadcast to `dstVectorType` under the
/// semantics of the `vector.broadcast` op.
enum class BroadcastableToResult {
@@ -161,33 +91,6 @@ void populateVectorTransferPermutationMapLoweringPatterns(
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);
-/// Collect a set of patterns to convert vector.multi_reduction op into
-/// a sequence of vector.reduction ops. The patterns comprise:
-/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
-/// that all reduction dimensions are either innermost or outermost, by adding
-/// the proper vector.transpose operations.
-/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
-/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
-/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
-/// back.
-/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
-/// form, with an **outermost** reduction dimension, unroll the outer dimension
-/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
-/// tree-reduction (in the future).
-/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
-/// with an **innermost** reduction dimension, unroll the outer dimension to
-/// obtain a sequence of extract + vector.reduction + insert. This can further
-/// lower to horizontal reduction ops.
-/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
-/// reduction (and are thus missing either a parallel or a reduction), we lift
-/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
-/// the other patterns can kick in, thus fully exiting out of the
-/// vector.multi_reduction abstraction.
-void populateVectorMultiReductionLoweringPatterns(
- RewritePatternSet &patterns,
- VectorMultiReductionLowering options =
- vector::VectorMultiReductionLowering::InnerParallel);
-
/// Collect a set of patterns to propagate insert_map/extract_map in the ssa
/// chain.
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
@@ -212,12 +115,6 @@ class CombiningKindAttr
/// vectors to low-D vector ops.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
-/// Collects patterns to progressively lower vector contraction ops on high-D
-/// into low-D reduction and product ops.
-void populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns,
- VectorTransformsOptions options = VectorTransformsOptions());
-
/// Collects patterns to progressively lower vector mask ops into elementary
/// selection and insertion ops.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
@@ -227,15 +124,6 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
/// ops.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
-/// Insert TransposeLowering patterns into extraction/insertion.
-void populateVectorTransposeLoweringPatterns(
- RewritePatternSet &patterns,
- VectorTransformsOptions options = VectorTransformsOptions());
-
-/// Collect patterns to convert reduction op to vector.contract and fold
-/// transpose/broadcast ops into the contract.
-void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);
-
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
index 13b310713f7b5..47375c56673f8 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -9,11 +9,173 @@
#ifndef DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
#define DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
namespace mlir {
class RewritePatternSet;
namespace vector {
+//===----------------------------------------------------------------------===//
+// Vector transformation options exposed as auxiliary structs.
+//===----------------------------------------------------------------------===//
+/// Enum to control the lowering of `vector.transpose` operations.
+enum class VectorTransposeLowering {
+  /// Lower transpose into element-wise extracts and inserts.
+  EltWise = 0,
+  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
+  /// intrinsics.
+  Flat = 1,
+};
+/// Enum to control the lowering of `vector.multi_reduction` operations.
+enum class VectorMultiReductionLowering {
+  /// Lower multi_reduction into outer-reduction and inner-parallel ops.
+  InnerParallel = 0,
+  /// Lower multi_reduction into outer-parallel and inner-reduction ops.
+  InnerReduction = 1,
+};
+/// Enum to control the lowering of `vector.contract` operations.
+enum class VectorContractLowering {
+  /// Progressively lower to finer grained `vector.contract` and dot-products.
+  Dot = 0,
+  /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
+  Matmul = 1,
+  /// Lower to `vector.outerproduct`.
+  OuterProduct = 2,
+};
+/// Enum to control the splitting of `vector.transfer` operations into
+/// in-bounds and out-of-bounds variants.
+enum class VectorTransferSplit {
+  /// Do not split vector transfer operations.
+  None = 0,
+  /// Split using in-bounds + out-of-bounds vector.transfer operations.
+  VectorTransfer = 1,
+  /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
+  /// operations.
+  LinalgCopy = 2,
+  /// Do not split vector transfer operation but instead mark it as "in-bounds".
+  ForceInBounds = 3,
+};
+/// Structure to control the behavior of vector transform patterns. Every
+/// setter returns `*this` so that calls can be chained.
+struct VectorTransformsOptions {
+  /// Option to control the lowering of vector.contract.
+  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
+  VectorTransformsOptions &
+  setVectorContractLowering(VectorContractLowering opt) {
+    vectorContractLowering = opt;
+    return *this;
+  }
+  /// Historical, misleadingly named spelling of `setVectorContractLowering`:
+  /// it only sets `vectorContractLowering`. Kept so existing callers compile;
+  /// prefer `setVectorContractLowering` in new code.
+  VectorTransformsOptions &
+  setVectorTransformsOptions(VectorContractLowering opt) {
+    return setVectorContractLowering(opt);
+  }
+  /// Option to control the lowering of vector.multi_reduction.
+  VectorMultiReductionLowering vectorMultiReductionLowering =
+      VectorMultiReductionLowering::InnerParallel;
+  VectorTransformsOptions &
+  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
+    vectorMultiReductionLowering = opt;
+    return *this;
+  }
+  /// Option to control the lowering of vector.transpose.
+  VectorTransposeLowering vectorTransposeLowering =
+      VectorTransposeLowering::EltWise;
+  VectorTransformsOptions &
+  setVectorTransposeLowering(VectorTransposeLowering opt) {
+    vectorTransposeLowering = opt;
+    return *this;
+  }
+  /// Option to control the splitting of vector transfers.
+  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
+  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+    vectorTransferSplit = opt;
+    return *this;
+  }
+};
+
+/// Options that control the vector unrolling. Setters return `*this` so that
+/// they can be chained.
+struct UnrollVectorOptions {
+  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+  /// Predicate consulted to decide whether vector unrolling should be
+  /// attempted on a given operation; unset (`nullptr`) by default.
+  FilterConstraintFnType filterConstraint = nullptr;
+  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
+    filterConstraint = constraint;
+    return *this;
+  }
+
+  using NativeShapeFnType =
+      std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
+  /// Callback returning the shape a given operation should be unrolled to.
+  /// Unrolling is aborted when it returns `llvm::None`.
+  NativeShapeFnType nativeShape = nullptr;
+  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
+    nativeShape = fn;
+    return *this;
+  }
+
+  /// Convenience setter: report the same fixed `shape` for every operation.
+  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
+    SmallVector<int64_t, 4> fixedShape(shape.begin(), shape.end());
+    return setNativeShapeFn(
+        [fixedShape](Operation *) -> Optional<SmallVector<int64_t, 4>> {
+          return fixedShape;
+        });
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Vector transformations exposed as populate functions over rewrite patterns.
+//===----------------------------------------------------------------------===//
+
+/// Collect patterns that lower `vector.transpose` ops into finer-grained
+/// vector ops. `options.vectorTransposeLowering` is expected to select the
+/// strategy (`VectorTransposeLowering::EltWise`, i.e. element-wise
+/// extraction/insertion, by default).
+void populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransformsOptions options = VectorTransformsOptions());
+
+/// Collect a set of patterns to convert vector.multi_reduction op into
+/// a sequence of vector.reduction ops. The patterns comprise:
+/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
+/// that all reduction dimensions are either innermost or outermost, by adding
+/// the proper vector.transpose operations.
+/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
+/// form, with an **outermost** reduction dimension, unroll the outer dimension
+/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
+/// tree-reduction (in the future).
+/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
+/// with an **innermost** reduction dimension, unroll the outer dimension to
+/// obtain a sequence of extract + vector.reduction + insert. This can further
+/// lower to horizontal reduction ops.
+/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
+/// reduction (and are thus missing either a parallel or a reduction
+/// dimension), we lift them back up to 2-D with a simple vector.shape_cast to
+/// vector<1xk> so that the other patterns can kick in, thus fully exiting out
+/// of the vector.multi_reduction abstraction.
+void populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorMultiReductionLowering options =
+        VectorMultiReductionLowering::InnerParallel);
+
+/// Collect patterns to progressively lower high-D vector contraction ops into
+/// low-D reduction and product ops. `options` selects the lowering strategy
+/// (see `VectorContractLowering`).
+void populateVectorContractLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransformsOptions options = VectorTransformsOptions());
+
+/// Collect patterns to convert reduction op to vector.contract and fold
+/// transpose/broadcast ops into the contract.
+void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
+
+/// Collect a set of patterns to reduce the rank of the operands of vector
+/// transfer ops to operate on the largest contiguous vector.
+/// These patterns are useful when lowering to dialects with 1-D vector types,
+/// such as LLVM, and can result in fewer memory reads.
+void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
+    RewritePatternSet &patterns);
+
/// Populate `patterns` with the following patterns.
///
/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
@@ -52,6 +214,235 @@ namespace vector {
void populateVectorInsertExtractStridedSliceTransforms(
RewritePatternSet &patterns);
+/// Collect a set of patterns to unroll vector operations to smaller shapes.
+/// The `options` structure controls which operations are unrolled and the
+/// target shape.
+/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
+///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
+///      `numUnrolledInstances` are computed from the `targetShape`. For now it
+///      is assumed the unrolling factors divide the vector sizes.
+///   2. ExtractStridedSlice ops are created to break up the vector operands.
+///   3. the original op is cloned `numUnrolledInstances` times, once for each
+///      unrolled instance.
+///   4. InsertStridedSlice ops are inserted to re-assemble the slices into the
+///      original vector shape.
+///
+/// Example:
+///
+///    opA(operand0, operand1)  // numUnrolledInstances = 3
+///
+///            operand0                   operand1
+///               |                          |
+///             fork                       fork
+///        <----------gather all fork ops --------->
+///              /|\                        /|\
+///          f00 f01 f02                f10 f11 f12
+///        <---------- clone op 3 times --------->
+///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
+///                 \            |            /
+///      <-------------------- join ------------------------->
+///
+/// Other local patterns then kick in iteratively (including DCE) and compose
+/// to combine the ExtractStridedSlice/InsertStridedSlice.
+void populateVectorUnrollPatterns(RewritePatternSet &patterns,
+                                  const UnrollVectorOptions &options);
+
+//===----------------------------------------------------------------------===//
+// Finer-grained patterns exposed for more control over individual lowerings.
+//===----------------------------------------------------------------------===//
+/// Pattern wrapper that applies `splitFullAndPartialTransfer` selectively. It
+/// is registered with `MatchAnyOpTypeTag`, so it is offered every op; the
+/// optional `filter` allows selection at a finer granularity and by default
+/// accepts everything.
+struct VectorTransferFullPartialRewriter : public RewritePattern {
+  using FilterConstraintType =
+      std::function<LogicalResult(VectorTransferOpInterface op)>;
+
+  explicit VectorTransferFullPartialRewriter(
+      MLIRContext *context,
+      VectorTransformsOptions options = VectorTransformsOptions(),
+      FilterConstraintType filter =
+          [](VectorTransferOpInterface) -> LogicalResult { return success(); },
+      PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+        options(options), filter(filter) {}
+
+  /// Performs the rewrite (defined out of line).
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options controlling the rewrite.
+  VectorTransformsOptions options;
+  /// Client-provided selection predicate.
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+///    %flattened_a = vector.shape_cast %a
+///    %flattened_b = vector.shape_cast %b
+///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
+///    %d = vector.shape_cast %flattened_d
+///    %e = add %c, %d
+/// ```
+/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
+///
+/// This only kicks in when VectorTransformsOptions is set to Matmul and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToMatmulOpLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Default filter: accept every contraction op.
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  ContractionOpToMatmulOpLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  /// Client-provided selection predicate (`defaultFilter` unless overridden).
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics into a sequence unrolled along the reduction dimension:
+/// ```
+///    %at = vector.transpose %a, [1, 0]
+///    %bRow0 = vector.extract %b[0]
+///    %atRow0 = vector.extract %at[0]
+///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
+///    ...
+///    %bRowK = vector.extract %b[K]
+///    %atRowK = vector.extract %at[K]
+///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// Applies only when VectorTransformsOptions is set to OuterProduct and the
+/// vector.contract op is a row-major matrix multiply.
+class ContractionOpToOuterProductOpLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Filter used when none is supplied: every contraction op is accepted.
+  static LogicalResult defaultFilter(vector::ContractionOp) {
+    return success();
+  }
+
+  ContractionOpToOuterProductOpLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions),
+        filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Lowering configuration shared with the other vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  /// Client-provided selection predicate.
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to an output-size-unrolled sequence:
+/// ```
+///    %out = arith.constant ... : vector<MxNxelt_type>
+///    %bt = vector.transpose %b, [1, 0]
+///    %aRow0 = vector.extract %a[0]
+///    %btRow0 = vector.extract %bt[0]
+///    %c00 = vector.reduce %aRow0, %btRow0
+///    %out00 = vector.insert %c00, %out[0, 0]
+///    ...
+///    %aRowLast = vector.extract %a[M-1]
+///    %btRowLast = vector.extract %bt[N-1]
+///    %cLastLast = vector.reduce %aRowLast, %btRowLast
+///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to Dot and
+/// the vector.contract op is a row-major matmul or matvec.
+class ContractionOpToDotLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Default filter: accept every contraction op.
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  /// Store the caller-supplied `constraint` (previously it was silently
+  /// replaced by `defaultFilter`, so custom filters had no effect).
+  ContractionOpToDotLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  /// Client-provided selection predicate (`defaultFilter` unless overridden).
+  FilterConstraintType filter;
+};
+
+/// Progressive lowering of ContractionOp, one dimension at a time:
+///
+///   %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+///   %a = vector.contract with one less free/batch dimension
+///   %b = vector.contract with one less free/batch dimension
+///   ..
+///   %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions), which is
+/// then replaced by a dot-product.
+///
+/// Applies when VectorTransformsOptions is set to Dot, or as a fallback when
+/// the other contraction patterns fail.
+class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Filter used when none is supplied: every contraction op is accepted.
+  static LogicalResult defaultFilter(vector::ContractionOp) {
+    return success();
+  }
+
+  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+                        MLIRContext *context,
+                        FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions),
+        filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  // Lower one parallel dimension.
+  Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
+                      int64_t rhsIndex, PatternRewriter &rewriter) const;
+  // Lower one reduction dimension.
+  Value lowerReduction(vector::ContractionOp op,
+                       PatternRewriter &rewriter) const;
+
+  /// Lowering configuration shared with the other vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  /// Client-provided selection predicate.
+  FilterConstraintType filter;
+};
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index d26636c132ac4..811d72192910e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -9,10 +9,8 @@
#ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_
-#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/PatternMatch.h"
namespace mlir {
class MLIRContext;
@@ -26,77 +24,9 @@ class IfOp;
namespace vector {
-/// Options that control the vector unrolling.
-struct UnrollVectorOptions {
- using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
- /// Callback function that indicates whether vector unrolling should be
- /// attempted on the operation.
- FilterConstraintFnType filterConstraint = nullptr;
- UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
- filterConstraint = constraint;
- return *this;
- }
-
- using NativeShapeFnType =
- std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
- /// Function that returns the shape of the vector to unroll to for a given
- /// operation. The unrolling is aborted if the function returns `llvm::None`.
- NativeShapeFnType nativeShape = nullptr;
- UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
- nativeShape = fn;
- return *this;
- }
-
- /// Set the native shape to use for unrolling.
- UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
- SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
- nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
- return tsShape;
- };
- return *this;
- }
-};
-
-/// Collect a set of pattern to unroll vector operations to a smaller shapes.
-/// `options` structure controls which operations are unrolled and the target
-/// shape.
-/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
-/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
-/// `numUnrolledInstances` are computed from the `targetShape`. For now it is
-/// assumed the unrolling factors divide the vector sizes.
-/// 2. ExtractStridedSlice are created to break-up the vector operands.
-/// 3. the original op is cloned `numUnrolledInstances` times, once for each
-/// result.
-/// 4. InsertStridedSlice are inserted to re-assemble the slices into the
-/// original vectore shape.
-///
-/// Example:
-///
-/// opA(operand0, operand1) // numUnrolledInstances = 3
-///
-/// operand0 operand1
-/// | |
-/// fork fork
-/// <----------gather all fork ops --------->
-/// /|\ /|\
-/// f00 f01 f02 f10 f11 f12
-/// <---------- clone op 3 times --------->
-/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
-/// \ | /
-/// <-------------------- join ------------------------->
-///
-/// Other local patterns then kick in iteratively (including DCE) and compose
-/// to combine the ExtractStridedSlice/InsertStridedSlice.
-void populateVectorUnrollPatterns(RewritePatternSet &patterns,
- const UnrollVectorOptions &options);
-
-/// Collect a set of patterns to reduce the rank of the operands of vector
-/// transfer ops to operate on the largest contigious vector.
-/// These patterns are useful when lowering to dialects with 1d vector type
-/// such as llvm and it will result fewer memory reads.
-void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
- RewritePatternSet &patterns);
-
+//===----------------------------------------------------------------------===//
+// Standalone transformations and helpers.
+//===----------------------------------------------------------------------===//
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
@@ -130,37 +60,11 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
-LogicalResult
-splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
LogicalResult splitFullAndPartialTransfer(
OpBuilder &b, VectorTransferOpInterface xferOp,
VectorTransformsOptions options = VectorTransformsOptions(),
scf::IfOp *ifOp = nullptr);
-/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
-/// may take an extra filter to perform selection at a finer granularity.
-struct VectorTransferFullPartialRewriter : public RewritePattern {
- using FilterConstraintType =
- std::function<LogicalResult(VectorTransferOpInterface op)>;
-
- explicit VectorTransferFullPartialRewriter(
- MLIRContext *context,
- VectorTransformsOptions options = VectorTransformsOptions(),
- FilterConstraintType filter =
- [](VectorTransferOpInterface op) { return success(); },
- PatternBenefit benefit = 1)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
- filter(filter) {}
-
- /// Performs the rewrite.
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
-
-private:
- VectorTransformsOptions options;
- FilterConstraintType filter;
-};
-
struct DistributeOps {
ExtractMapOp extract;
InsertMapOp insert;
@@ -188,180 +92,6 @@ distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
void transferOpflowOpt(FuncOp func);
} // namespace vector
-
-//===----------------------------------------------------------------------===//
-// Finer-grained patterns exposed for more control over individual lowerings.
-//===----------------------------------------------------------------------===//
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
-/// ```
-/// %flattened_a = vector.shape_cast %a
-/// %flattened_b = vector.shape_cast %b
-/// %flattened_d = vector.matmul %flattened_a, %flattened_b
-/// %d = vector.shape_cast %%flattened_d
-/// %e = add %c, %d
-/// ```
-/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
-//
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-class ContractionOpToMatmulOpLowering
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
- ContractionOpToMatmulOpLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
- FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to a reduction_size-unrolled sequence:
-/// ```
-/// %at = vector.transpose %a, [1, 0]
-/// %bRow0 = vector.extract %b[0]
-/// %atRow0 = vector.extract %at[0]
-/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
-/// ...
-/// %bRowK = vector.extract %b[K]
-/// %atRowK = vector.extract %at[K]
-/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-class ContractionOpToOuterProductOpLowering
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
- ContractionOpToOuterProductOpLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
- FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to an output-size-unrolled sequence:
-/// ```
-/// %out = arith.constant ... : vector<MxNxelt_type>
-/// %bt = vector.transpose %b, [1, 0]
-/// %aRow0 = vector.extract %a[0]
-/// %btRow0 = vector.extract %bt[0]
-/// %c00 = vector.reduce %atRow0, %bRow0
-/// %out00 = vector.insert %c00, %out[0, 0]
-/// ...
-/// %aRowLast = vector.extract %at[M-1]
-/// %btRowLast = vector.extract %b[N-1]
-/// %cLastLast = vector.reduce %atRowLast, %bRowLast
-/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to Dot and
-/// the vector.contract op is a row-major matmul or matvec.
-class ContractionOpToDotLowering
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
- ContractionOpToDotLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
- FilterConstraintType filter;
-};
-
-/// Progressive lowering of ContractionOp.
-///
-/// One:
-/// %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-/// %a = vector.contract with one less free/batch dimension
-/// %b = vector.contract with one less free/batch dimension
-/// ..
-/// %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product.
-///
-/// This only kicks in when either VectorTransformsOptions is set
-/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
- ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context,
- FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
- FilterConstraintType filter;
- // Lower one parallel dimension.
- Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
- int64_t rhsIndex, PatternRewriter &rewriter) const;
- // Lower one reduction dimension.
- Value lowerReduction(vector::ContractionOp op,
- PatternRewriter &rewriter) const;
-};
-
} // namespace mlir
#endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 77d2a46977172..1046c37588f69 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -14,8 +14,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 583ba4a13eb08..7d75f11d0e3d1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -21,7 +21,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index b7506eb91aa9a..97831866ffd08 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -22,7 +22,6 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -32,6 +31,7 @@
#include "mlir/Transforms/Utils.h"
using namespace mlir;
+using namespace mlir::vector;
using namespace linalg;
namespace {
@@ -191,7 +191,7 @@ struct LinalgStrategyVectorizePass
}
vector::populateVectorTransferPermutationMapLoweringPatterns(
vectorizationPatterns);
- vector::populateVetorReductionToContractPatterns(vectorizationPatterns);
+ vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(
funcOp.getContext(), /*benefit=*/2);
@@ -268,9 +268,14 @@ struct LinalgStrategyLowerVectorsPass
vector::populateVectorTransferLoweringPatterns(patterns,
options.maxTransferRank);
}
- if (options.transferPartialRewrite) {
- patterns.add<vector::VectorTransferFullPartialRewriter>(
- context, options.vectorTransformOptions);
+ if (options.transposeLowering) {
+ vector::populateVectorTransposeLoweringPatterns(
+ patterns, options.vectorTransformOptions);
+ }
+ if (options.multiReductionLowering) {
+ vector::populateVectorMultiReductionLoweringPatterns(
+ patterns,
+ options.vectorTransformOptions.vectorMultiReductionLowering);
}
if (options.contractionLowering) {
patterns.add<ContractionOpToOuterProductOpLowering,
@@ -278,15 +283,15 @@ struct LinalgStrategyLowerVectorsPass
options.vectorTransformOptions, context);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
- if (options.multiReductionLowering) {
- vector::populateVectorMultiReductionLoweringPatterns(
- patterns,
- options.vectorTransformOptions.vectorMultiReductionLowering);
+ if (options.transferPartialRewrite) {
+ patterns.add<vector::VectorTransferFullPartialRewriter>(
+ context, options.vectorTransformOptions);
}
if (options.transferToSCFConversion) {
populateVectorToSCFConversionPatterns(patterns,
options.vectorTransferToSCFOptions);
}
+ vector::populateVectorShapeCastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
index 67d0db4d2cd45..637c8729f06f6 100644
--- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -10,14 +10,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index d98fa705dbf62..4c7ef516fd927 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -21,21 +21,10 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
-#include "mlir/Dialect/Vector/VectorUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
-#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Types.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/DenseSet.h"
@@ -48,6 +37,7 @@
#define DEBUG_TYPE "vector-to-vector"
using namespace mlir;
+using namespace mlir::vector;
// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
@@ -1978,9 +1968,41 @@ static Value createInBoundsCond(OpBuilder &b,
});
return inBoundsCond;
}
-
-LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
- VectorTransferOpInterface xferOp) {
+/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
+/// masking) fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+/// At this time, only vector.transfer_read case is implemented.
+///
+/// Example (a 2-D vector.transfer_read):
+/// ```
+/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// // fastpath, direct cast
+/// memref.cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view : compatibleMemRefType, index, index
+/// } else {
+/// // slowpath, not in-bounds vector.transfer or linalg.copy.
+/// memref.cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4 : compatibleMemRefType, index, index
+/// }
+/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// ```
+/// where `alloc` is a top of the function alloca'ed buffer of one vector.
+///
+/// Preconditions:
+/// 1. `xferOp.permutation_map()` must be a minor identity map
+/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+/// must be equal. This will be relaxed in the future but requires
+/// rank-reducing subviews.
+static LogicalResult
+splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
// TODO: expand support to these 2 cases.
if (!xferOp.permutation_map().isMinorIdentity())
return failure();
@@ -3863,7 +3885,7 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
patterns.add<TransposeOpLowering>(options, patterns.getContext());
}
-void mlir::vector::populateVetorReductionToContractPatterns(
+void mlir::vector::populateVectorReductionToContractPatterns(
RewritePatternSet &patterns) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
CombineContractTranspose>(patterns.getContext());
diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
index f73cd6c0a1eb0..9c8f138743dec 100644
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -98,9 +98,8 @@ void TestConvVectorization::runOnOperation() {
VectorTransposeLowering::EltWise};
RewritePatternSet vectorTransferPatterns(context);
- // Pattern is not applied because rank-reducing vector transfer is not yet
- // supported as can be seen in splitFullAndPartialTransferPrecondition,
- // VectorTransforms.cpp
+ // Pattern is not applied: rank-reducing vector transfer is not yet supported
+ // (see: splitFullAndPartialTransferPrecondition in VectorTransforms.cpp).
vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
context, vectorTransformOptions);
(void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index b1468083f52df..1e8620aee75ec 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -536,7 +536,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
RewritePatternSet canonicalizationPatterns(funcOp.getContext());
vector::populateVectorTransferPermutationMapLoweringPatterns(
canonicalizationPatterns);
- vector::populateVetorReductionToContractPatterns(canonicalizationPatterns);
+ vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
stage1Patterns.push_back(std::move(canonicalizationPatterns));
}
SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index b95c45d21633a..e7d520bcdb173 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -14,13 +14,13 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
+
namespace {
struct TestVectorToVectorConversion
@@ -511,7 +511,7 @@ struct TestVectorReduceToContractPatternsPatterns
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
- populateVetorReductionToContractPatterns(patterns);
+ populateVectorReductionToContractPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
More information about the Mlir-commits
mailing list