[Mlir-commits] [mlir] 2bc4c3e - [mlir][Vector] NFC - Reorganize vector patterns
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Mar 23 11:30:33 PDT 2023
Author: Nicolas Vasilache
Date: 2023-03-23T11:30:25-07:00
New Revision: 2bc4c3e920ee078ef2879b00c40440e0867f0b9e
URL: https://github.com/llvm/llvm-project/commit/2bc4c3e920ee078ef2879b00c40440e0867f0b9e
DIFF: https://github.com/llvm/llvm-project/commit/2bc4c3e920ee078ef2879b00c40440e0867f0b9e.diff
LOG: [mlir][Vector] NFC - Reorganize vector patterns
Vector dialect patterns have grown enormously over the past year, to the point where they are now impenetrable.
Start reorganizing them towards finer-grained control.
Differential Revision: https://reviews.llvm.org/D146736
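
Note (illustrative, not part of the patch): with the new LoweringPatterns.h header, a downstream pass can register only the lowerings it needs instead of a monolithic pattern set. The sketch below uses the populate functions and options struct introduced in the diff; the pass boilerplate and function name around them are hypothetical. This mirrors what ConvertVectorToLLVMPass.cpp and VectorTransformOps.cpp do after this change.

```
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Register only the contract and transpose lowerings (instead of one
// monolithic pattern set) and run them with the greedy rewrite driver.
static void runSelectedVectorLowerings(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::VectorTransformsOptions options;
  vector::populateVectorContractLoweringPatterns(patterns, options);
  vector::populateVectorTransposeLoweringPatterns(patterns, options);
  (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```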
Added:
mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 56f8b4bf22d21..4763b6525b934 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -110,43 +110,11 @@ void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
-/// Collect a set of transfer read/write lowering patterns.
-///
-/// These patterns lower transfer ops to simpler ops like `vector.load`,
-/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
-/// of a most `maxTransferRank` are lowered. This is useful when combined with
-/// VectorToSCF, which reduces the rank of vector transfer ops.
-void populateVectorTransferLoweringPatterns(
- RewritePatternSet &patterns,
- std::optional<unsigned> maxTransferRank = std::nullopt,
- PatternBenefit benefit = 1);
-
/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool force32BitVectorIndices,
PatternBenefit benefit = 1);
-/// Collects patterns to progressively lower vector.broadcast ops on high-D
-/// vectors to low-D vector ops.
-void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
-/// Collects patterns to progressively lower vector mask ops into elementary
-/// selection and insertion ops.
-void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
-/// Collects patterns to progressively lower vector.shape_cast ops on high-D
-/// vectors into 1-D/2-D vector ops by generating data movement extract/insert
-/// ops.
-void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
-/// Collects patterns that lower scalar vector transfer ops to memref loads and
-/// stores when beneficial.
-void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
@@ -214,8 +182,8 @@ void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns the
/// maskable operation itself.
-Operation *maskOperation(OpBuilder &builder, Operation *maskableOp,
- Value mask, Value passthru = Value());
+Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask,
+ Value passthru = Value());
/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is used
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
new file mode 100644
index 0000000000000..dfadffba3883b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -0,0 +1,248 @@
+//===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
+
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace vector {
+
+//===----------------------------------------------------------------------===//
+// Lowering pattern populate functions
+//===----------------------------------------------------------------------===//
+
+/// Populate the pattern set with the following patterns:
+///
+/// [OuterProductOpLowering]
+/// Progressively lower a `vector.outerproduct` to linearized
+/// `vector.extract` + `vector.fma` + `vector.insert`.
+///
+/// [ContractionOpLowering]
+/// Progressive lowering of ContractionOp.
+/// One:
+/// %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+/// %a = vector.contract with one less free/batch dimension
+/// %b = vector.contract with one less free/batch dimension
+///
+/// [ContractionOpToMatmulOpLowering]
+/// Progressively lower a `vector.contract` with row-major matmul semantics to
+/// linearized `vector.shape_cast` + `vector.matmul` on the way to
+/// `llvm.matrix.multiply`.
+///
+/// [ContractionOpToDotLowering]
+/// Progressively lower a `vector.contract` with row-major matmul semantics to
+/// linearized `vector.extract` + `vector.reduce` + `vector.insert`.
+///
+/// [ContractionOpToOuterProductOpLowering]
+/// Progressively lower a `vector.contract` with row-major matmul semantics to
+/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
+void populateVectorContractLoweringPatterns(
+ RewritePatternSet &patterns, VectorTransformsOptions options,
+ PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
+
+/// Collect a set of patterns to convert vector.multi_reduction op into
+/// a sequence of vector.reduction ops. The patterns comprise:
+///
+/// [InnerOuterDimReductionConversion]
+/// Rewrites vector.multi_reduction such that all reduction dimensions are
+/// either innermost or outermost, by adding the proper vector.transpose
+/// operations.
+///
+/// [ReduceMultiDimReductionRank]
+/// Once in innermost or outermost reduction
+/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
+/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
+/// back.
+///
+/// [TwoDimMultiReductionToElementWise]
+/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
+/// ops. This also has an opportunity for tree-reduction (in the future).
+///
+/// [TwoDimMultiReductionToReduction]
+/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
+/// dimension, unroll the outer dimension to obtain a sequence of extract +
+/// vector.reduction + insert. This can further lower to horizontal reduction
+/// ops.
+///
+/// [OneDimMultiReductionToTwoDim]
+/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
+/// either a parallel or a reduction), we lift them back up to 2-D with a simple
+/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
+/// thus fully exiting out of the vector.multi_reduction abstraction.
+void populateVectorMultiReductionLoweringPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [BroadcastOpLowering]
+/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
+/// BroadcastOp until dim 1.
+void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [CreateMaskOp]
+/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
+///
+/// [ConstantMaskOp]
+/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
+/// dim 1.
+void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Collects patterns that lower scalar vector transfer ops to memref loads and
+/// stores when beneficial.
+void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [ShapeCastOp2DDownCastRewritePattern]
+/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
+/// vectors progressively.
+///
+/// [ShapeCastOp2DUpCastRewritePattern]
+/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
+/// vectors progressively.
+///
+/// [ShapeCastOpRewritePattern]
+/// Reference lowering to fully unrolled sequences of single element ExtractOp +
+/// InsertOp. Note that applying this pattern can almost always be considered a
+/// performance bug.
+void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [TransposeOpLowering]
+///
+/// [TransposeOp2DToShuffleLowering]
+///
+void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
+ VectorTransformsOptions options,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [TransferReadToVectorLoadLowering]
+/// Progressive lowering of transfer_read. This pattern supports lowering of
+/// `vector.transfer_read` to a combination of `vector.load` and
+/// `vector.broadcast`.
+///
+/// [TransferWriteToVectorStoreLowering]
+/// Progressive lowering of transfer_write. This pattern supports lowering of
+/// `vector.transfer_write` to `vector.store`.
+///
+/// [VectorLoadToMemrefLoadLowering]
+/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
+///
+/// [VectorStoreToMemrefStoreLowering]
+/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
+///
+/// These patterns lower transfer ops to simpler ops like `vector.load`,
+/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
+/// of at most `maxTransferRank` are lowered. This is useful when combined with
+/// VectorToSCF, which reduces the rank of vector transfer ops.
+void populateVectorTransferLoweringPatterns(
+ RewritePatternSet &patterns,
+ std::optional<unsigned> maxTransferRank = std::nullopt,
+ PatternBenefit benefit = 1);
+
+/// Collect a set of transfer read/write lowering patterns that simplify the
+/// permutation map (e.g., converting it to a minor identity map) by inserting
+/// broadcasts and transposes. More specifically:
+///
+/// [TransferReadPermutationLowering]
+/// Lower transfer_read op with permutation into a transfer_read with a
+/// permutation map composed of leading zeros followed by a minor identity +
+/// vector.transpose op.
+/// Ex:
+/// vector.transfer_read ...
+/// permutation_map: (d0, d1, d2) -> (0, d1)
+/// into:
+/// %v = vector.transfer_read ...
+/// permutation_map: (d0, d1, d2) -> (d1, 0)
+/// vector.transpose %v, [1, 0]
+///
+/// vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
+/// into:
+/// %v = vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
+/// vector.transpose %v, [0, 1, 3, 2, 4]
+/// Note that an alternative is to transform it to linalg.transpose +
+/// vector.transfer_read to do the transpose in memory instead.
+///
+/// [TransferWritePermutationLowering]
+/// Lower transfer_write op with permutation into a transfer_write with a
+/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
+/// Ex:
+/// vector.transfer_write %v ...
+/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
+/// into:
+/// %tmp = vector.transpose %v, [2, 0, 1]
+/// vector.transfer_write %tmp ...
+/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+///
+/// vector.transfer_write %v ...
+/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
+/// into:
+/// %tmp = vector.transpose %v, [1, 0]
+/// %v = vector.transfer_write %tmp ...
+/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+///
+/// [TransferOpReduceRank]
+/// Lower transfer_read op with broadcast in the leading dimensions into
+/// transfer_read of lower rank + vector.broadcast.
+/// Ex: vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
+/// into:
+/// %v = vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
+/// vector.broadcast %v
+void populateVectorTransferPermutationMapLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [ScanToArithOps]
+/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
+/// vector.extract_strided_slice.
+void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenGather]
+/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// outermost dimension. For example:
+///
+/// [Gather1DToConditionalLoads]
+/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
+/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
+/// loads/extracts are made conditional using `scf.if` ops.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populates instances of `MaskOpRewritePattern` to lower masked operations
+/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
+/// not its nested `MaskableOpInterface`.
+void populateVectorMaskLoweringPatternsForSideEffectingOps(
+ RewritePatternSet &patterns);
+
+} // namespace vector
+} // namespace mlir
+#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
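
Usage note (illustrative, not part of the patch): the VectorTransformsOptions struct, now declared in VectorTransforms.h, keeps its fluent setters, so a caller can pick a specific contract and multi-reduction lowering strategy and register only those pattern sets. A minimal sketch, assuming an existing RewritePatternSet; the helper name is hypothetical:

```
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Select the outer-product contract lowering and the inner-parallel
// multi-reduction lowering, then add only those two pattern sets.
static void addContractAndMultiReductionLowerings(RewritePatternSet &patterns) {
  vector::VectorTransformsOptions options;
  options
      .setVectorTransformsOptions(vector::VectorContractLowering::OuterProduct)
      .setVectorMultiReductionLowering(
          vector::VectorMultiReductionLowering::InnerParallel);
  vector::populateVectorContractLoweringPatterns(patterns, options);
  vector::populateVectorMultiReductionLoweringPatterns(
      patterns, options.vectorMultiReductionLowering);
}
```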
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index d0c06f69930d2..bf89b01e2b60c 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -22,12 +22,6 @@ std::unique_ptr<Pass> createVectorBufferizePass();
/// Creates an instance of the `vector.mask` lowering pass.
std::unique_ptr<Pass> createLowerVectorMaskPass();
-/// Populates instances of `MaskOpRewritePattern` to lower masked operations
-/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
-/// not its nested `MaskableOpInterface`.
-void populateVectorMaskLoweringPatternsForSideEffectingOps(
- RewritePatternSet &patterns);
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index af68de7e0051e..a79bbd0be0975 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -9,8 +9,8 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
-#include <utility>
#include <optional>
+#include <utility>
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
@@ -23,42 +23,7 @@ namespace mlir {
class RewritePatternSet;
namespace vector {
-
-//===----------------------------------------------------------------------===//
-// Vector transformation options exposed as auxiliary structs.
-//===----------------------------------------------------------------------===//
-/// Structure to control the behavior of vector transform patterns.
-struct VectorTransformsOptions {
- /// Option to control the lowering of vector.contract.
- VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
- VectorTransformsOptions &
- setVectorTransformsOptions(VectorContractLowering opt) {
- vectorContractLowering = opt;
- return *this;
- }
- /// Option to control the lowering of vector.multi_reduction.
- VectorMultiReductionLowering vectorMultiReductionLowering =
- VectorMultiReductionLowering::InnerParallel;
- VectorTransformsOptions &
- setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
- vectorMultiReductionLowering = opt;
- return *this;
- }
- /// Option to control the lowering of vector.transpose.
- VectorTransposeLowering vectorTransposeLowering =
- VectorTransposeLowering::EltWise;
- VectorTransformsOptions &
- setVectorTransposeLowering(VectorTransposeLowering opt) {
- vectorTransposeLowering = opt;
- return *this;
- }
- /// Option to control the splitting of vector transfers.
- VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
- VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
- vectorTransferSplit = opt;
- return *this;
- }
-};
+struct VectorTransformsOptions;
/// Options that control the vector unrolling.
struct UnrollVectorOptions {
@@ -109,45 +74,6 @@ struct UnrollVectorOptions {
// Vector transformation exposed as populate functions over rewrite patterns.
//===----------------------------------------------------------------------===//
-/// Insert TransposeLowering patterns into extraction/insertion.
-void populateVectorTransposeLoweringPatterns(
- RewritePatternSet &patterns,
- VectorTransformsOptions options = VectorTransformsOptions(),
- PatternBenefit benefit = 1);
-
-/// Collect a set of patterns to convert vector.multi_reduction op into
-/// a sequence of vector.reduction ops. The patterns comprise:
-/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
-/// that all reduction dimensions are either innermost or outermost, by adding
-/// the proper vector.transpose operations.
-/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
-/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
-/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
-/// back.
-/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
-/// form, with an **outermost** reduction dimension, unroll the outer dimension
-/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
-/// tree-reduction (in the future).
-/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
-/// with an **innermost** reduction dimension, unroll the outer dimension to
-/// obtain a sequence of extract + vector.reduction + insert. This can further
-/// lower to horizontal reduction ops.
-/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
-/// reduction (and are thus missing either a parallel or a reduction), we lift
-/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
-/// the other patterns can kick in, thus fully exiting out of the
-/// vector.multi_reduction abstraction.
-void populateVectorMultiReductionLoweringPatterns(
- RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
-
-/// Collects patterns to progressively lower vector contraction ops on high-D
-/// into low-D reduction and product ops.
-void populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns,
- VectorTransformsOptions options = VectorTransformsOptions(),
- PatternBenefit benefit = 1);
-
/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction with MMT semantics (matrix matrix multiplication
/// with the RHS transposed). This specific form is meant to have the vector
@@ -174,67 +100,43 @@ void populateVectorContractCanonicalizeMatmulToMMT(
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
-/// Collect patterns to convert scan op
-void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
-//===----------------------------------------------------------------------===//
-// Vector.transfer patterns.
-//===----------------------------------------------------------------------===//
-/// Collect a set of transfer read/write lowering patterns that simplify the
-/// permutation map (e.g., converting it to a minor identity map) by inserting
-/// broadcasts and transposes. More specifically:
-///
-/// [TransferReadPermutationLowering]
-/// Lower transfer_read op with permutation into a transfer_read with a
-/// permutation map composed of leading zeros followed by a minor identity +
-/// vector.transpose op.
-/// Ex:
-/// vector.transfer_read ...
-/// permutation_map: (d0, d1, d2) -> (0, d1)
-/// into:
-/// %v = vector.transfer_read ...
-/// permutation_map: (d0, d1, d2) -> (d1, 0)
-/// vector.transpose %v, [1, 0]
+/// Populate `patterns` with the following patterns.
///
-/// vector.transfer_read ...
-/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
-/// into:
-/// %v = vector.transfer_read ...
-/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
-/// vector.transpose %v, [0, 1, 3, 2, 4]
-/// Note that an alternative is to transform it to linalg.transpose +
-/// vector.transfer_read to do the transpose in memory instead.
+/// - VectorTransferFullPartialRewriter
///
-/// [TransferWritePermutationLowering]
-/// Lower transfer_write op with permutation into a transfer_write with a
-/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
-/// Ex:
-/// vector.transfer_write %v ...
-/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
-/// into:
-/// %tmp = vector.transpose %v, [2, 0, 1]
-/// vector.transfer_write %tmp ...
-/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
+/// masking) fast path and a slow path.
///
-/// vector.transfer_write %v ...
-/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
-/// into:
-/// %tmp = vector.transpose %v, [1, 0]
-/// %v = vector.transfer_write %tmp ...
-/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+/// Example (a 2-D vector.transfer_read):
+/// ```
+/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+/// %1:3 = scf.if (%inBounds) {
+/// // fast path, direct cast
+/// memref.cast %A: memref<A...> to compatibleMemRefType
+/// scf.yield %view : compatibleMemRefType, index, index
+/// } else {
+/// // slow path, not in-bounds vector.transfer or linalg.copy.
+/// memref.cast %alloc: memref<B...> to compatibleMemRefType
+/// scf.yield %4 : compatibleMemRefType, index, index
+/// }
+/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// ```
+/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
-/// [TransferOpReduceRank]
-/// Lower transfer_read op with broadcast in the leading dimensions into
-/// transfer_read of lower rank + vector.broadcast.
-/// Ex: vector.transfer_read ...
-/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
-/// into:
-/// %v = vector.transfer_read ...
-/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
-/// vector.broadcast %v
-void populateVectorTransferPermutationMapLoweringPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Preconditions:
+/// 1. `xferOp.permutation_map()` must be a minor identity map
+/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+/// must be equal. This will be relaxed in the future but requires
+/// rank-reducing subviews.
+void populateVectorTransferFullPartialPatterns(
+ RewritePatternSet &patterns, const VectorTransformsOptions &options);
+
+//===----------------------------------------------------------------------===//
+// Vector.transfer patterns.
+//===----------------------------------------------------------------------===//
/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops to operate on the largest contiguous vector.
@@ -334,220 +236,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options,
PatternBenefit benefit = 1);
-/// Expands `vector.gather` ops into a series of conditional scalar loads
-/// (`vector.load` for memrefs or `tensor.extract` for tensors). These loads are
-/// conditional to avoid out-of-bounds memory accesses and guarded with `scf.if`
-/// ops. This lowering path is intended for targets that do not feature
-/// dedicated gather ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
-//===----------------------------------------------------------------------===//
-// Finer-grained patterns exposed for more control over individual lowerings.
-//===----------------------------------------------------------------------===//
-/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
-/// may take an extra filter to perform selection at a finer granularity.
-struct VectorTransferFullPartialRewriter : public RewritePattern {
- using FilterConstraintType =
- std::function<LogicalResult(VectorTransferOpInterface op)>;
-
- explicit VectorTransferFullPartialRewriter(
- MLIRContext *context,
- VectorTransformsOptions options = VectorTransformsOptions(),
- FilterConstraintType filter =
- [](VectorTransferOpInterface op) { return success(); },
- PatternBenefit benefit = 1)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
- filter(std::move(filter)) {}
-
- /// Performs the rewrite.
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override;
-
-private:
- VectorTransformsOptions options;
- FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
-/// ```
-/// %flattened_a = vector.shape_cast %a
-/// %flattened_b = vector.shape_cast %b
-/// %flattened_d = vector.matmul %flattened_a, %flattened_b
-/// %d = vector.shape_cast %%flattened_d
-/// %e = add %c, %d
-/// ```
-/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
-//
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-class ContractionOpToMatmulOpLowering
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
- ContractionOpToMatmulOpLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, PatternBenefit benefit = 1,
- FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions),
- filter(std::move(constraint)) {}
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
- FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to a reduction_size-unrolled sequence:
-/// ```
-/// %at = vector.transpose %a, [1, 0]
-/// %bRow0 = vector.extract %b[0]
-/// %atRow0 = vector.extract %at[0]
-/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
-/// ...
-/// %bRowK = vector.extract %b[K]
-/// %atRowK = vector.extract %at[K]
-/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
-/// the vector.contract op is a row-major matrix multiply.
-class ContractionOpToOuterProductOpLowering
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
- ContractionOpToOuterProductOpLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, PatternBenefit benefit = 1,
- FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions),
- filter(std::move(constraint)) {}
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
- FilterConstraintType filter;
-};
-
-/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to an output-size-unrolled sequence:
-/// ```
-/// %out = arith.constant ... : vector<MxNxelt_type>
-/// %bt = vector.transpose %b, [1, 0]
-/// %aRow0 = vector.extract %a[0]
-/// %btRow0 = vector.extract %bt[0]
-/// %c00 = vector.reduce %atRow0, %bRow0
-/// %out00 = vector.insert %c00, %out[0, 0]
-/// ...
-/// %aRowLast = vector.extract %at[M-1]
-/// %btRowLast = vector.extract %b[N-1]
-/// %cLastLast = vector.reduce %atRowLast, %bRowLast
-/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to Dot and
-/// the vector.contract op is a row-major matmul or matvec.
-class ContractionOpToDotLowering
- : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
- ContractionOpToDotLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, PatternBenefit benefit = 1,
- const FilterConstraintType &constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
- FilterConstraintType filter;
-};
-
-/// Progressive lowering of ContractionOp.
-///
-/// One:
-/// %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-/// %a = vector.contract with one less free/batch dimension
-/// %b = vector.contract with one less free/batch dimension
-/// ..
-/// %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product.
-///
-/// This only kicks in when either VectorTransformsOptions is set
-/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
- ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, PatternBenefit benefit = 1,
- FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions),
- filter(std::move(constraint)) {}
-
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
- FilterConstraintType filter;
- // Lower one parallel dimension.
- FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
- vector::ContractionOp op, int64_t lhsIndex,
- int64_t rhsIndex, Value mask) const;
- // Lower one reduction dimension.
- FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
- vector::ContractionOp op, Value mask) const;
-};
-
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 947911f9a3841..52a4c9cc368d8 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -24,17 +24,53 @@ class IfOp;
namespace vector {
+//===----------------------------------------------------------------------===//
+// Vector transformation options exposed as auxiliary structs.
+//===----------------------------------------------------------------------===//
+/// Structure to control the behavior of vector transform patterns.
+struct VectorTransformsOptions {
+ /// Option to control the lowering of vector.contract.
+ VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
+ VectorTransformsOptions &
+ setVectorTransformsOptions(VectorContractLowering opt) {
+ vectorContractLowering = opt;
+ return *this;
+ }
+ /// Option to control the lowering of vector.multi_reduction.
+ VectorMultiReductionLowering vectorMultiReductionLowering =
+ VectorMultiReductionLowering::InnerParallel;
+ VectorTransformsOptions &
+ setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
+ vectorMultiReductionLowering = opt;
+ return *this;
+ }
+ /// Option to control the lowering of vector.transpose.
+ VectorTransposeLowering vectorTransposeLowering =
+ VectorTransposeLowering::EltWise;
+ VectorTransformsOptions &
+ setVectorTransposeLowering(VectorTransposeLowering opt) {
+ vectorTransposeLowering = opt;
+ return *this;
+ }
+ /// Option to control the splitting of vector transfers.
+ VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
+ VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+ vectorTransferSplit = opt;
+ return *this;
+ }
+};
+
//===----------------------------------------------------------------------===//
// Standalone transformations and helpers.
//===----------------------------------------------------------------------===//
-/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
-/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
-/// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
-/// scf.if op returns a view and values of type index.
-/// At this time, only vector.transfer_read case is implemented.
+/// Split a vector.transfer operation into an in-bounds (i.e., no
+/// out-of-bounds masking) fastpath and a slowpath. If `ifOp` is not null and
+/// the result is `success`, the `ifOp` points to the newly created conditional
+/// upon function return. To accommodate the fact that the original
+/// vector.transfer indexing may be arbitrary and the slow path indexes
+/// @[0...0] in the temporary buffer, the scf.if op returns a view and values
+/// of type index. At this time, only the vector.transfer_read case is
+/// implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
@@ -51,15 +87,16 @@ namespace vector {
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
/// scf.yield %4 : compatibleMemRefType, index, index
// }
-/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ...
+/// true]}
/// ```
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
-/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
-/// must be equal. This will be relaxed in the future but requires
-/// rank-reducing subviews.
+/// 2. the rank of the `xferOp.memref()` and the rank of the
+/// `xferOp.vector()` must be equal. This will be relaxed in the future but
+/// requires rank-reducing subviews.
LogicalResult splitFullAndPartialTransfer(
RewriterBase &b, VectorTransferOpInterface xferOp,
VectorTransformsOptions options = VectorTransformsOptions(),
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c56d03f6f31d7..05def0f45d7fb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index fb544df18324b..3f1b107f6f8e0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
@@ -64,10 +65,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
- populateVectorContractLoweringPatterns(patterns);
+ populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
- populateVectorTransposeLoweringPatterns(patterns);
+ populateVectorTransposeLoweringPatterns(patterns,
+ VectorTransformsOptions());
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index d8070b34a761d..ec2e2aa4c0624 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -10,8 +10,8 @@
//
//===----------------------------------------------------------------------===//
-#include <type_traits>
#include <optional>
+#include <type_traits>
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index eb97c6e168e5c..b7d9812ada0b1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -20,5 +20,5 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRTransformDialectUtils
- MLIRVectorDialect
+ MLIRVectorTransforms
)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d98eb3b781fc5..e3c1429ade54a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -26,6 +26,7 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 60996b9add614..136d234742b8d 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -7,13 +7,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
-
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Parser/Parser.h"
@@ -82,10 +83,9 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
// In the future we may want to more finely select particular stages.
// Stage 1: contraction lowerings.
- patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
- mlir::vector::ContractionOpToMatmulOpLowering,
- mlir::vector::ContractionOpLowering>(vectorTransformOptions,
- ctx);
+ populateVectorContractLoweringPatterns(
+ patterns, vectorTransformOptions, /*benefit=*/1,
+ /*disableOuterProductLowering*/ true);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
// Stage 2: multi-reduction lowerings.
@@ -93,8 +93,7 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
// Stage 3: Rewrite vector.transfer into full and partial parts.
- patterns.add<vector::VectorTransferFullPartialRewriter>(
- ctx, vectorTransformOptions);
+ populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
// Stage 4: Lower vector transfers.
vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);
@@ -107,8 +106,8 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
vector::populateVectorShapeCastLoweringPatterns(patterns);
// Stage 7: Lower vector.transpose.
- vector::populateVectorTransposeLoweringPatterns(patterns,
- vectorTransformOptions);
+ vector::populateVectorTransposeLoweringPatterns(
+ patterns, vectorTransformOptions, /*benefit=*/1);
if (getTransposeAvx2Lowering())
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
patterns, avx2LoweringOptions, /*benefit=*/10);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 6fb1b8c18a122..f17208e193b3c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,14 +1,20 @@
add_mlir_dialect_library(MLIRVectorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ LowerVectorBroadcast.cpp
+ LowerVectorContract.cpp
+ LowerVectorGather.cpp
LowerVectorMask.cpp
+ LowerVectorMultiReduction.cpp
+ LowerVectorScan.cpp
+ LowerVectorShapeCast.cpp
+ LowerVectorTransfer.cpp
+ LowerVectorTranspose.cpp
VectorDistribute.cpp
VectorDropLeadUnitDim.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
- VectorMultiDimReductionTransforms.cpp
VectorTransferOpTransforms.cpp
VectorTransferSplitRewritePatterns.cpp
- VectorTransferPermutationMapRewritePatterns.cpp
VectorTransforms.cpp
VectorUnroll.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
new file mode 100644
index 0000000000000..ad538fe4a6828
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -0,0 +1,156 @@
+//===- LowerVectorBroadcast.cpp - Lower 'vector.broadcast' operation ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.broadcast' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-broadcast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// Progressive lowering of BroadcastOp.
+class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BroadcastOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ VectorType dstType = op.getResultVectorType();
+ VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
+ Type eltType = dstType.getElementType();
+
+ // Scalar to any vector can use splat.
+ if (!srcType) {
+ rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
+ return success();
+ }
+
+ // Determine rank of source and destination.
+ int64_t srcRank = srcType.getRank();
+ int64_t dstRank = dstType.getRank();
+
+ // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
+ if (srcRank <= 1 && dstRank == 1) {
+ Value ext;
+ if (srcRank == 0)
+ ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+ else
+ ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
+ rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+ return success();
+ }
+
+ // Duplicate this rank.
+ // For example:
+ // %x = broadcast %y : k-D to n-D, k < n
+ // becomes:
+ // %b = broadcast %y : k-D to (n-1)-D
+ // %x = [%b,%b,%b,%b] : n-D
+ // becomes:
+ // %b = [%y,%y] : (n-1)-D
+ // %x = [%b,%b,%b,%b] : n-D
+ if (srcRank < dstRank) {
+ // Duplication.
+ VectorType resType =
+ VectorType::get(dstType.getShape().drop_front(), eltType);
+ Value bcst =
+ rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, dstType, rewriter.getZeroAttr(dstType));
+ for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
+ result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+ // Find non-matching dimension, if any.
+ assert(srcRank == dstRank);
+ int64_t m = -1;
+ for (int64_t r = 0; r < dstRank; r++)
+ if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
+ m = r;
+ break;
+ }
+
+ // All trailing dimensions are the same. Simply pass through.
+ if (m == -1) {
+ rewriter.replaceOp(op, op.getSource());
+ return success();
+ }
+
+ // Any non-matching dimension forces a stretch along this rank.
+ // For example:
+ // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
+ // becomes:
+ // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
+ // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
+ // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
+ // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
+ // %x = [%a,%b,%c,%d]
+ // becomes:
+ // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
+ // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
+ // %a = [%u, %v]
+ // ..
+ // %x = [%a,%b,%c,%d]
+ VectorType resType =
+ VectorType::get(dstType.getShape().drop_front(), eltType);
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, dstType, rewriter.getZeroAttr(dstType));
+ if (m == 0) {
+ // Stretch at start.
+ Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
+ Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
+ for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
+ result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ } else {
+ // Stretch not at start.
+ for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
+ Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
+ Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
+ result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ }
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorBroadcastLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
new file mode 100644
index 0000000000000..1280cfef0b645
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -0,0 +1,1329 @@
+//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.contract' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-contract-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+// Helper to find an index in an affine map.
+static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ int64_t idx = map.getDimPosition(i);
+ if (idx == index)
+ return i;
+ }
+ return std::nullopt;
+}
+
+// Helper to construct iterator types with one index removed.
+static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
+ int64_t index) {
+ SmallVector<Attribute> results;
+ for (const auto &it : llvm::enumerate(iteratorTypes)) {
+ int64_t idx = it.index();
+ if (idx == index)
+ continue;
+ results.push_back(it.value());
+ }
+ return results;
+}
+
+// Helper to construct an affine map with one index removed.
+static AffineMap adjustMap(AffineMap map, int64_t index,
+ PatternRewriter &rewriter) {
+ auto *ctx = rewriter.getContext();
+ SmallVector<AffineExpr> results;
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ int64_t idx = map.getDimPosition(i);
+ if (idx == index)
+ continue;
+ // Re-insert remaining indices, but renamed when occurring
+ // after the removed index.
+ auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
+ results.push_back(targetExpr);
+ }
+ return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
+}
+
+// Helper method to possibly drop a dimension in a load.
+// TODO
+static Value reshapeLoad(Location loc, Value val, VectorType type,
+ int64_t index, int64_t pos,
+ PatternRewriter &rewriter) {
+ if (index == -1)
+ return val;
+ Type lowType = VectorType::Builder(type).dropDim(0);
+ // At extraction dimension?
+ if (index == 0) {
+ auto posAttr = rewriter.getI64ArrayAttr(pos);
+ return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
+ }
+ // Unroll leading dimensions.
+ VectorType vType = lowType.cast<VectorType>();
+ Type resType = VectorType::Builder(type).dropDim(index);
+ auto resVectorType = resType.cast<VectorType>();
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resVectorType, rewriter.getZeroAttr(resVectorType));
+ for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
+ auto posAttr = rewriter.getI64ArrayAttr(d);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
+ Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
+ result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
+ posAttr);
+ }
+ return result;
+}
+
+// Helper method to possibly drop a dimension in a store.
+// TODO
+static Value reshapeStore(Location loc, Value val, Value result,
+ VectorType type, int64_t index, int64_t pos,
+ PatternRewriter &rewriter) {
+ // Unmodified?
+ if (index == -1)
+ return val;
+ // At insertion dimension?
+ if (index == 0) {
+ auto posAttr = rewriter.getI64ArrayAttr(pos);
+ return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
+ }
+ // Unroll leading dimensions.
+ Type lowType = VectorType::Builder(type).dropDim(0);
+ VectorType vType = lowType.cast<VectorType>();
+ Type insType = VectorType::Builder(vType).dropDim(0);
+ for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
+ auto posAttr = rewriter.getI64ArrayAttr(d);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
+ Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
+ Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
+ result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
+ }
+ return result;
+}
+
+/// Helper to create arithmetic operation associated with a kind of contraction.
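+/// A rough sketch of the behavior: for floating-point operands with an ADD
+/// kind and a vector accumulator, a single `vector.fma` is emitted (with a
+/// passthru select when masked); otherwise a multiply (`arith.muli` or
+/// `arith.mulf`) is created and combined with `acc` via `makeArithReduction`.
+/// `std::nullopt` is returned when `kind` does not apply to the element type
+/// (e.g. MINF/MAXF on integers).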
+static std::optional<Value>
+createContractArithOp(Location loc, Value x, Value y, Value acc,
+ vector::CombiningKind kind, PatternRewriter &rewriter,
+ bool isInt, Value mask = Value()) {
+ using vector::CombiningKind;
+ Value mul;
+
+ if (isInt) {
+ if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
+ // Only valid for floating point types.
+ return std::nullopt;
+ mul = rewriter.create<arith::MulIOp>(loc, x, y);
+ } else {
+ // Float case.
+ if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
+ kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
+ kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
+ kind == CombiningKind::XOR)
+ // Only valid for integer types.
+ return std::nullopt;
+ // Special case for fused multiply-add.
+ if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
+ Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+ if (mask)
+ // The fma op doesn't need explicit masking. However, fma ops used in
+ // reductions must preserve previous 'acc' values for masked-out lanes.
+ fma = selectPassthru(rewriter, mask, fma, acc);
+ return fma;
+ }
+ mul = rewriter.create<arith::MulFOp>(loc, x, y);
+ }
+
+ if (!acc)
+ return std::optional<Value>(mul);
+
+ return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
+}
+
+/// Return the positions of the reductions in the given map.
+static SmallVector<int64_t> getReductionIndex(AffineMap map,
+ ArrayAttr iteratorTypes) {
+ SmallVector<int64_t> dimsIdx;
+ for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+ if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
+ dimsIdx.push_back(i);
+ }
+ return dimsIdx;
+}
+
+/// Look for a given dimension in an affine map and return its position. Return
+/// std::nullopt if the dimension is not in the map results.
+static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
+ for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+ if (map.getDimPosition(i) == dim)
+ return i;
+ }
+ return std::nullopt;
+}
+
+/// Creates an arith::AddIOp if `isInt` is true, otherwise creates an
+/// arith::AddFOp, using operands `x` and `y`.
+static Value createAdd(Location loc, Value x, Value y, bool isInt,
+ PatternRewriter &rewriter) {
+ if (isInt)
+ return rewriter.create<arith::AddIOp>(loc, x, y);
+ return rewriter.create<arith::AddFOp>(loc, x, y);
+}
+
+/// Creates an arith::MulIOp if `isInt` is true, otherwise creates an
+/// arith::MulFOp, using operands `x` and `y`.
+static Value createMul(Location loc, Value x, Value y, bool isInt,
+ PatternRewriter &rewriter) {
+ if (isInt)
+ return rewriter.create<arith::MulIOp>(loc, x, y);
+ return rewriter.create<arith::MulFOp>(loc, x, y);
+}
+
+namespace {
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+/// %flattened_a = vector.shape_cast %a
+/// %flattened_b = vector.shape_cast %b
+/// %flattened_d = vector.matmul %flattened_a, %flattened_b
+/// %d = vector.shape_cast %flattened_d
+/// %e = add %c, %d
+/// ```
+/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
+///
+/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
+/// vector.transpose operations are inserted if the vector.contract op is not a
+/// row-major matrix multiply.
+class ContractionOpToMatmulOpLowering
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
+
+ ContractionOpToMatmulOpLowering(
+ vector::VectorTransformsOptions vectorTransformOptions,
+ MLIRContext *context, PatternBenefit benefit = 1,
+ FilterConstraintType constraint = defaultFilter)
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ vectorTransformOptions(vectorTransformOptions),
+ filter(std::move(constraint)) {}
+
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// Options to control the vector patterns.
+ vector::VectorTransformsOptions vectorTransformOptions;
+ FilterConstraintType filter;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a reduction_size-unrolled sequence:
+/// ```
+/// %at = vector.transpose %a, [1, 0]
+/// %bRow0 = vector.extract %b[0]
+/// %atRow0 = vector.extract %at[0]
+/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
+/// ...
+/// %bRowK = vector.extract %b[K]
+/// %atRowK = vector.extract %at[K]
+/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// the vector.contract op is a row-major matrix multiply.
+class ContractionOpToOuterProductOpLowering
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
+
+ ContractionOpToOuterProductOpLowering(
+ vector::VectorTransformsOptions vectorTransformOptions,
+ MLIRContext *context, PatternBenefit benefit = 1,
+ FilterConstraintType constraint = defaultFilter)
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ vectorTransformOptions(vectorTransformOptions),
+ filter(std::move(constraint)) {}
+
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// Options to control the vector patterns.
+ vector::VectorTransformsOptions vectorTransformOptions;
+ FilterConstraintType filter;
+};
+
+/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to an output-size-unrolled sequence:
+/// ```
+/// %out = arith.constant ... : vector<MxNxelt_type>
+/// %bt = vector.transpose %b, [1, 0]
+/// %aRow0 = vector.extract %a[0]
+/// %btRow0 = vector.extract %bt[0]
+/// %c00 = vector.reduce %aRow0, %btRow0
+/// %out00 = vector.insert %c00, %out[0, 0]
+/// ...
+/// %aRowLast = vector.extract %a[M-1]
+/// %btRowLast = vector.extract %bt[N-1]
+/// %cLastLast = vector.reduce %aRowLast, %btRowLast
+/// %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to Dot and
+/// the vector.contract op is a row-major matmul or matvec.
+class ContractionOpToDotLowering
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
+
+ ContractionOpToDotLowering(
+ vector::VectorTransformsOptions vectorTransformOptions,
+ MLIRContext *context, PatternBenefit benefit = 1,
+ const FilterConstraintType &constraint = defaultFilter)
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// Options to control the vector patterns.
+ vector::VectorTransformsOptions vectorTransformOptions;
+ FilterConstraintType filter;
+};
+
+/// Progressive lowering of ContractionOp.
+///
+/// One:
+/// %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+/// %a = vector.contract with one less free/batch dimension
+/// %b = vector.contract with one less free/batch dimension
+/// ..
+/// %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by a dot-product.
+///
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
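+///
+/// As an illustrative sketch only (the exact ops depend on the indexing maps
+/// and on masking), a matvec-like contraction
+/// ```
+/// %x = vector.contract %a, %b, %c
+///        : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
+/// ```
+/// is unrolled along its free dimension of size 2 into two rank-1
+/// contractions, which then lower to `vector.reduction` ops whose scalar
+/// results are inserted back into a vector<2xf32>.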
+class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
+
+ ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+ MLIRContext *context, PatternBenefit benefit = 1,
+ FilterConstraintType constraint = defaultFilter)
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ vectorTransformOptions(vectorTransformOptions),
+ filter(std::move(constraint)) {}
+
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// Options to control the vector patterns.
+ vector::VectorTransformsOptions vectorTransformOptions;
+ FilterConstraintType filter;
+ // Lower one parallel dimension.
+ FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
+ vector::ContractionOp op, int64_t lhsIndex,
+ int64_t rhsIndex, Value mask) const;
+ // Lower one reduction dimension.
+ FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
+ vector::ContractionOp op, Value mask) const;
+};
+
+/// Generate a vector implementation for matmat, matvec and tmatvec.
+/// This unrolls outer-products along the reduction dimension.
+struct UnrolledOuterProductGenerator
+ : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
+ UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
+ : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
+ kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
+ res(op.getAcc()), lhsType(op.getLhsType()) {
+ auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+ if (maskableOp.isMasked())
+ mask = maskableOp.getMaskingOp().getMask();
+ }
+
+ Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
+ if (!v)
+ return v;
+ return rewriter.create<vector::TransposeOp>(loc, v, perm);
+ }
+
+ Value promote(Value v, Type dstElementType) {
+ Type elementType = v.getType();
+ auto vecType = elementType.dyn_cast<VectorType>();
+ if (vecType)
+ elementType = vecType.getElementType();
+ if (elementType == dstElementType)
+ return v;
+ Type promotedType = dstElementType;
+ if (vecType)
+ promotedType = VectorType::get(vecType.getShape(), promotedType);
+ if (dstElementType.isa<FloatType>())
+ return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+ return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
+ }
+
+ FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+ std::optional<Value> maybeMask = std::nullopt) {
+ assert(reductionSize > 0);
+ // Incremental support for masking.
+ if (mask && !maybeMask.has_value())
+ return failure();
+
+ Type resElementType = res.getType().cast<VectorType>().getElementType();
+ for (int64_t k = 0; k < reductionSize; ++k) {
+ Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
+ Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+ extractA = promote(extractA, resElementType);
+ extractB = promote(extractB, resElementType);
+ Value extractMask;
+ if (maybeMask.has_value() && maybeMask.value())
+ extractMask =
+ rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
+
+ Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
+ loc, res.getType(), extractA, extractB, res, kind);
+ res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
+ }
+ return res;
+ }
+
+ /// Two outer parallel, one inner reduction (matmat flavor).
+ FailureOr<Value> matmat() {
+ if (!iters({Par(), Par(), Red()}))
+ return failure();
+ // Set up the parallel/reduction structure in the right form.
+ AffineExpr m, n, k;
+ bindDims(rewriter.getContext(), m, n, k);
+ // Classical row-major matmul: Just permute the lhs.
+ if (layout({{m, k}, {k, n}, {m, n}}))
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
+ t(mask, {2, 0, 1}));
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+ if (layout({{m, k}, {n, k}, {m, n}})) {
+ Value tlhs = t(lhs);
+ return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
+ }
+ // No need to permute anything.
+ if (layout({{k, m}, {k, n}, {m, n}}))
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+ // Just permute the rhs.
+ if (layout({{k, m}, {n, k}, {m, n}}))
+ return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
+ // Transposed output: swap RHS and LHS.
+ // Classical row-major matmul: permute the lhs.
+ if (layout({{m, k}, {k, n}, {n, m}}))
+ return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
+ // TODO: may be better to fail and use some vector<k> -> scalar reduction.
+ if (layout({{m, k}, {n, k}, {n, m}})) {
+ Value trhs = t(rhs);
+ return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
+ }
+ if (layout({{k, m}, {k, n}, {n, m}}))
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+ if (layout({{k, m}, {n, k}, {n, m}}))
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+ return failure();
+ }
+
+ /// One outer parallel, one inner reduction (matvec flavor)
+ FailureOr<Value> matvec() {
+ if (!iters({Par(), Red()}))
+ return failure();
+ AffineExpr m, k;
+ bindDims(rewriter.getContext(), m, k);
+
+ // Case mat-vec: transpose.
+ if (layout({{m, k}, {k}, {m}}))
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
+ // Case mat-trans-vec: ready to go.
+ if (layout({{k, m}, {k}, {m}}))
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+ // Case vec-mat: swap and transpose.
+ if (layout({{k}, {m, k}, {m}}))
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+ // Case vec-mat-trans: swap and ready to go.
+ if (layout({{k}, {k, m}, {m}}))
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+ return failure();
+ }
+
+ //
+ // One outer reduction, one inner parallel (tmatvec flavor)
+ //
+ FailureOr<Value> tmatvec() {
+ if (!iters({Red(), Par()}))
+ return failure();
+ AffineExpr k, m;
+ bindDims(rewriter.getContext(), k, m);
+
+ // Case mat-vec: transpose.
+ if (layout({{m, k}, {k}, {m}}))
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+ // Case mat-trans-vec: ready to go.
+ if (layout({{k, m}, {k}, {m}}))
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+ // Case vec-mat: swap and transpose.
+ if (layout({{k}, {m, k}, {m}}))
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+ // Case vec-mat-trans: swap and ready to go.
+ if (layout({{k}, {k, m}, {m}}))
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+ return failure();
+ }
+
+private:
+ vector::CombiningKind kind;
+ Value lhs, rhs, res, mask;
+ VectorType lhsType;
+};
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a reduction_size-unrolled sequence:
+/// ```
+/// %at = vector.transpose %a, [1, 0]
+/// %bRow0 = vector.extract %b[0]
+/// %atRow0 = vector.extract %at[0]
+/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
+/// ...
+/// %bRowK = vector.extract %b[K]
+/// %atRowK = vector.extract %at[K]
+/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
+/// ```
+///
+/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
+/// otherwise supports any layout permutation of the matrix-multiply.
+LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
+ vector::ContractionOp op, PatternRewriter &rewriter) const {
+ // TODO: Remove native masks from contraction op?
+ if (!op.getMasks().empty())
+ return failure();
+
+ if (vectorTransformOptions.vectorContractLowering !=
+ vector::VectorContractLowering::OuterProduct)
+ return failure();
+
+ if (failed(filter(op)))
+ return failure();
+
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+ Operation *rootOp;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ } else {
+ rootOp = op;
+ }
+
+ UnrolledOuterProductGenerator e(rewriter, op);
+ FailureOr<Value> matmatRes = e.matmat();
+ if (succeeded(matmatRes)) {
+ rewriter.replaceOp(rootOp, *matmatRes);
+ return success();
+ }
+ FailureOr<Value> matvecRes = e.matvec();
+ if (succeeded(matvecRes)) {
+ rewriter.replaceOp(rootOp, *matvecRes);
+ return success();
+ }
+ FailureOr<Value> tmatvecRes = e.tmatvec();
+ if (succeeded(tmatvecRes)) {
+ rewriter.replaceOp(rootOp, *tmatvecRes);
+ return success();
+ }
+
+ return failure();
+}
+
+LogicalResult
+ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const {
+ // TODO: Support vector.mask.
+ auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
+ // TODO: Remove native masks from contraction op?
+ if (!op.getMasks().empty())
+ return failure();
+
+ if (failed(filter(op)))
+ return failure();
+
+ if (vectorTransformOptions.vectorContractLowering !=
+ vector::VectorContractLowering::Dot)
+ return failure();
+
+ auto iteratorTypes = op.getIteratorTypes().getValue();
+ static constexpr std::array<int64_t, 2> perm = {1, 0};
+ Location loc = op.getLoc();
+ Value lhs = op.getLhs(), rhs = op.getRhs();
+
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+ AffineExpr m, n, k;
+ bindDims(rewriter.getContext(), m, n, k);
+ SmallVector<AffineMap> maps = op.getIndexingMapsArray();
+ //
+ // In the following we wish to make the reduction dimension innermost so we
+ // can load vectors and just fmul + reduce into a scalar.
+ //
+ if (isParallelIterator(iteratorTypes[0]) &&
+ isParallelIterator(iteratorTypes[1]) &&
+ isReductionIterator(iteratorTypes[2])) {
+ //
+ // Two outer parallel, one inner reduction (matmat flavor).
+ //
+ if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+ rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+ // No need to permute anything.
+ } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      // Transposed output: use the permuted rhs as the new lhs and the
+      // original lhs as the new rhs.
+ Value tmp = lhs;
+ lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ rhs = tmp;
+ } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+ std::swap(lhs, rhs);
+ } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+ Value tmp = lhs;
+ lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+ } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+ Value tmp = rhs;
+ rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = tmp;
+ } else {
+ return failure();
+ }
+ } else if (isParallelIterator(iteratorTypes[0]) &&
+ isReductionIterator(iteratorTypes[1])) {
+ //
+ // One outer parallel, one inner reduction (matvec flavor)
+ //
+ if (maps == infer({{m, n}, {n}, {m}})) {
+ // No need to permute anything.
+ } else if (maps == infer({{n, m}, {n}, {m}})) {
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ } else if (maps == infer({{n}, {m, n}, {m}})) {
+ std::swap(lhs, rhs);
+ } else if (maps == infer({{n}, {n, m}, {m}})) {
+ std::swap(lhs, rhs);
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ } else {
+ return failure();
+ }
+ } else {
+ return failure();
+ }
+
+ VectorType dstType = op.getResultType().cast<VectorType>();
+ assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
+ "Expected dst type of rank 1 or 2");
+
+ unsigned rank = dstType.getRank();
+ unsigned dstRows = dstType.getShape()[0];
+ unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
+
+ // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
+ Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
+ rewriter.getZeroAttr(dstType));
+ bool isInt = dstType.getElementType().isa<IntegerType>();
+ for (unsigned r = 0; r < dstRows; ++r) {
+ Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
+ for (unsigned c = 0; c < dstColumns; ++c) {
+ Value b = rank == 1
+ ? rhs
+ : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+ Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
+ Value reduced = rewriter.create<vector::ReductionOp>(
+ op.getLoc(), vector::CombiningKind::ADD, m);
+
+ SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
+ : SmallVector<int64_t, 2>{r, c};
+ res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
+ }
+ }
+ if (auto acc = op.getAcc())
+ res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
+ rewriter.replaceOp(op, res);
+ return success();
+}
+
+/// Lower vector.contract with all size one reduction dimensions to
+/// elementwise ops when possible.
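+///
+/// An illustrative sketch (the pattern may first insert `vector.broadcast`
+/// and `vector.transpose` ops to align the parallel dimensions): once every
+/// reduction dimension has size 1, each operand is reduced to its parallel
+/// shape with a `vector.extract` at offset 0 of the unit reduction
+/// dimensions, and the contraction becomes an elementwise multiply combined
+/// with the accumulator (e.g. `arith.mulf` + `arith.addf`, or a single
+/// `vector.fma`).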
+struct ContractOpToElementwise
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern::OpRewritePattern;
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
+ ContractOpToElementwise(
+ vector::VectorTransformsOptions vectorTransformOptions,
+ MLIRContext *context, PatternBenefit benefit = 1,
+ const FilterConstraintType &constraint = defaultFilter)
+ : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ // TODO: Support vector.mask.
+ auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
+ // TODO: Remove native masks from contraction op?
+ if (!contractOp.getMasks().empty())
+ return failure();
+
+ if (failed(filter(contractOp)))
+ return failure();
+
+ if (vectorTransformOptions.vectorContractLowering !=
+ vector::VectorContractLowering::ParallelArith)
+ return failure();
+
+ ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
+ ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
+ AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
+ AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
+ SmallVector<int64_t> lhsReductionDims =
+ getReductionIndex(lhsMap, contractOp.getIteratorTypes());
+ SmallVector<int64_t> rhsReductionDims =
+ getReductionIndex(rhsMap, contractOp.getIteratorTypes());
+ // All the reduction dimensions must be a size 1.
+ for (int64_t dim : lhsReductionDims) {
+ if (lhsShape[dim] != 1)
+ return failure();
+ }
+ for (int64_t dim : rhsReductionDims) {
+ if (rhsShape[dim] != 1)
+ return failure();
+ }
+ AffineMap accMap = contractOp.getIndexingMapsArray()[2];
+ unsigned numParallelDims = accMap.getNumResults();
+ unsigned numLhsDimToBroadcast =
+ numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
+ unsigned numRhsDimToBroadcast =
+ numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
+ SmallVector<int64_t> lhsDims;
+ SmallVector<int64_t> lhsTranspose;
+ SmallVector<int64_t> rhsDims;
+ SmallVector<int64_t> rhsTranspose;
+ for (int64_t dim : lhsReductionDims)
+ lhsTranspose.push_back(numLhsDimToBroadcast + dim);
+ for (int64_t dim : rhsReductionDims)
+ rhsTranspose.push_back(numRhsDimToBroadcast + dim);
+ // Loop through the parallel dimensions to calculate the dimensions to
+ // broadcast and to permute in order to extract only parallel dimensions.
+ for (unsigned i = 0; i < numParallelDims; i++) {
+ std::optional<unsigned> lhsDim =
+ getDimPosition(lhsMap, accMap.getDimPosition(i));
+ if (lhsDim) {
+ lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
+ } else {
+ // If the parallel dimension doesn't exist we will have to broadcast it.
+ lhsDims.push_back(
+ contractOp.getResultType().cast<VectorType>().getDimSize(i));
+ lhsTranspose.push_back(lhsDims.size() - 1);
+ }
+ std::optional<unsigned> rhsDim =
+ getDimPosition(rhsMap, accMap.getDimPosition(i));
+ if (rhsDim) {
+ rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
+ } else {
+ // If the parallel dimension doesn't exist we will have to broadcast it.
+ rhsDims.push_back(
+ contractOp.getResultType().cast<VectorType>().getDimSize(i));
+ rhsTranspose.push_back(rhsDims.size() - 1);
+ }
+ }
+ Value newLhs = contractOp.getLhs();
+ Value newRhs = contractOp.getRhs();
+ Location loc = contractOp.getLoc();
+ if (!lhsDims.empty()) {
+ lhsDims.append(lhsShape.begin(), lhsShape.end());
+ auto expandedType =
+ VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
+ newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
+ }
+ if (!rhsDims.empty()) {
+ rhsDims.append(rhsShape.begin(), rhsShape.end());
+ auto expandedType =
+ VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
+ newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
+ }
+ bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
+ newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
+ newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
+ SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
+ SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
+ newLhs = rewriter.create<vector::ExtractOp>(
+ loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
+ newRhs = rewriter.create<vector::ExtractOp>(
+ loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
+ std::optional<Value> result =
+ createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
+ contractOp.getKind(), rewriter, isInt);
+ rewriter.replaceOp(contractOp, {*result});
+ return success();
+ }
+
+private:
+ /// Options to control the vector patterns.
+ vector::VectorTransformsOptions vectorTransformOptions;
+ FilterConstraintType filter;
+};
+
+/// Progressive lowering of ContractionOp.
+/// One:
+/// %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+/// %a = vector.contract with one less free/batch dimension
+/// %b = vector.contract with one less free/batch dimension
+/// ..
+/// %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by a dot-product.
+///
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
+//
+// TODO: break down into transpose/reshape/cast ops
+// when they become available to avoid code dup
+// TODO: investigate lowering order impact on performance
+LogicalResult
+ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const {
+ // TODO: Remove native masks from contraction op?
+ if (!op.getMasks().empty())
+ return failure();
+
+ if (failed(filter(op)))
+ return failure();
+
+ // TODO: support mixed mode contract lowering.
+ if (op.getLhsType().getElementType() !=
+ getElementTypeOrSelf(op.getAccType()) ||
+ op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
+ return failure();
+
+ // TODO: the code below assumes the default contraction, make sure it supports
+ // other kinds before enabling this lowering.
+ if (op.getKind() != vector::CombiningKind::ADD) {
+ return rewriter.notifyMatchFailure(
+ op, "contractions other than 'add' not supported");
+ }
+
+ // TODO: implement benefits, cost models.
+ MLIRContext *ctx = op.getContext();
+ ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
+ if (succeeded(pat1.matchAndRewrite(op, rewriter)))
+ return success();
+ ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
+ if (succeeded(pat2.matchAndRewrite(op, rewriter)))
+ return success();
+ ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
+ if (succeeded(pat3.matchAndRewrite(op, rewriter)))
+ return success();
+ ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+ if (succeeded(pat4.matchAndRewrite(op, rewriter)))
+ return success();
+
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ Operation *rootOp = op;
+ Value mask;
+ if (op.isMasked()) {
+ rewriter.setInsertionPoint(op.getMaskingOp());
+ rootOp = op.getMaskingOp();
+ mask = op.getMaskingOp().getMask();
+ }
+
+ // Find first batch dimension in LHS/RHS, and lower when found.
+ std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
+ if (!batchDimMap.empty()) {
+ int64_t lhsIndex = batchDimMap[0].first;
+ int64_t rhsIndex = batchDimMap[0].second;
+ auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
+ if (failed(newOp))
+ return failure();
+ rewriter.replaceOp(rootOp, *newOp);
+ return success();
+ }
+
+ // Collect contracting dimensions.
+ std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
+ op.getContractingDimMap();
+ DenseSet<int64_t> lhsContractingDimSet;
+ DenseSet<int64_t> rhsContractingDimSet;
+ for (auto &dimPair : contractingDimMap) {
+ lhsContractingDimSet.insert(dimPair.first);
+ rhsContractingDimSet.insert(dimPair.second);
+ }
+
+ // Find first free dimension in LHS, and lower when found.
+ VectorType lhsType = op.getLhsType();
+ for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
+ if (lhsContractingDimSet.count(lhsIndex) == 0) {
+ auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
+ if (failed(newOp))
+ return failure();
+ rewriter.replaceOp(rootOp, *newOp);
+ return success();
+ }
+ }
+
+ // Find first free dimension in RHS, and lower when found.
+ VectorType rhsType = op.getRhsType();
+ for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
+ if (rhsContractingDimSet.count(rhsIndex) == 0) {
+ auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
+ if (failed(newOp))
+ return failure();
+ rewriter.replaceOp(rootOp, *newOp);
+ return success();
+ }
+ }
+
+ // Lower the first remaining reduction dimension.
+ if (!contractingDimMap.empty()) {
+ auto newOp = lowerReduction(rewriter, op, mask);
+ if (failed(newOp))
+ return failure();
+ rewriter.replaceOp(rootOp, *newOp);
+ return success();
+ }
+
+ return failure();
+}
+
+// Lower one parallel dimension.
+// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
+// TODO: consider reusing existing contract unrolling
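+// Sketch of the approach: for each position d along the selected dimension,
+// reshapeLoad slices the LHS/RHS/ACC (and the mask, if present), a lower-rank
+// vector.contract is emitted on the slices, and reshapeStore writes its
+// result back into the aggregate result vector.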
+FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
+ vector::ContractionOp op,
+ int64_t lhsIndex,
+ int64_t rhsIndex,
+ Value mask) const {
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ VectorType resType = op.getResultType().cast<VectorType>();
+ // Find the iterator type index and result index.
+ SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
+ int64_t iterIndex = -1;
+ int64_t dimSize = -1;
+ if (lhsIndex >= 0) {
+ iterIndex = iMap[0].getDimPosition(lhsIndex);
+ if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
+ << " to map to the same dimension";
+ });
+ dimSize = lhsType.getDimSize(lhsIndex);
+ } else if (rhsIndex >= 0) {
+ iterIndex = iMap[1].getDimPosition(rhsIndex);
+ dimSize = rhsType.getDimSize(rhsIndex);
+ }
+ if (iterIndex < 0)
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected either lhsIndex=" << lhsIndex
+ << " or rhsIndex=" << rhsIndex << " to be nonnegative";
+ });
+ // value_or(-1) means that we tolerate a dimension not appearing
+ // in the result map. That can't happen for actual parallel iterators, but
+ // the caller ContractionOpLowering::matchAndRewrite is currently calling
+ // lowerParallel also for the case of unit-size reduction dims appearing only
+ // on one of LHS or RHS, not both. At the moment, such cases are created by
+ // CastAwayContractionLeadingOneDim, so we need to either support that or
+ // modify that pattern.
+ int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
+ if (resIndex == -1 && dimSize != 1)
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected the dimension for iterIndex=" << iterIndex
+ << " to either appear in the result map, or to be a unit dimension";
+ });
+
+ // Construct new iterator types and affine map array attribute.
+ std::array<AffineMap, 3> lowIndexingMaps = {
+ adjustMap(iMap[0], iterIndex, rewriter),
+ adjustMap(iMap[1], iterIndex, rewriter),
+ adjustMap(iMap[2], iterIndex, rewriter)};
+ auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+ auto lowIter =
+ rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
+ // Unroll into a series of lower dimensional vector.contract ops.
+ Location loc = op.getLoc();
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resType, rewriter.getZeroAttr(resType));
+
+ for (int64_t d = 0; d < dimSize; ++d) {
+ auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
+ auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
+ auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
+
+ Value lowMask;
+ if (mask)
+ lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+ iterIndex, d, rewriter);
+
+ Operation *lowContract = rewriter.create<vector::ContractionOp>(
+ loc, lhs, rhs, acc, lowAffine, lowIter);
+ lowContract = maskOperation(rewriter, lowContract, lowMask);
+ result = reshapeStore(loc, lowContract->getResult(0), result, resType,
+ resIndex, d, rewriter);
+ }
+ return result;
+}
+
+// Lower one reduction dimension.
+FailureOr<Value> ContractionOpLowering::lowerReduction(
+ PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
+ auto loc = op.getLoc();
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ Type resType = op.getResultType();
+ if (resType.isa<VectorType>())
+ return rewriter.notifyMatchFailure(op,
+ "did not expect a VectorType result");
+ bool isInt = resType.isa<IntegerType>();
+ // Use iterator index 0.
+ int64_t iterIndex = 0;
+ SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
+ std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
+ std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
+ if (!lookupLhs.has_value())
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
+ });
+ if (!lookupRhs.has_value())
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
+ });
+ int64_t lhsIndex = *lookupLhs;
+ int64_t rhsIndex = *lookupRhs;
+ int64_t dimSize = lhsType.getDimSize(lhsIndex);
+ if (dimSize != rhsType.getDimSize(rhsIndex))
+ return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+ diag << "expect LHS dimension " << lhsIndex
+ << " to have the same size as RHS dimension " << rhsIndex;
+ });
+ // Base case.
+ if (lhsType.getRank() == 1) {
+ if (rhsType.getRank() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "When LHS has rank 1, expected also RHS to have rank 1");
+ Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
+ auto kind = vector::CombiningKind::ADD;
+
+ Value acc = op.getAcc();
+ Operation *reductionOp =
+ acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
+ : rewriter.create<vector::ReductionOp>(loc, kind, m);
+ return maskOperation(rewriter, reductionOp, mask)->getResult(0);
+ }
+ // Construct new iterator types and affine map array attribute.
+ std::array<AffineMap, 3> lowIndexingMaps = {
+ adjustMap(iMap[0], iterIndex, rewriter),
+ adjustMap(iMap[1], iterIndex, rewriter),
+ adjustMap(iMap[2], iterIndex, rewriter)};
+ auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+ auto lowIter =
+ rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
+ // Unroll into a series of lower dimensional vector.contract ops.
+ // By feeding the initial accumulator into the first contraction,
+ // and the result of each contraction into the next, eventually
+ // the sum of all reductions is computed.
+ Value result = op.getAcc();
+ for (int64_t d = 0; d < dimSize; ++d) {
+ auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
+ auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
+ Value newMask;
+ if (mask)
+ newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+ iterIndex, d, rewriter);
+
+ Operation *newContract = rewriter.create<vector::ContractionOp>(
+ loc, lhs, rhs, result, lowAffine, lowIter);
+ result = maskOperation(rewriter, newContract, newMask)->getResult(0);
+ }
+ return result;
+}
+
+/// Progressive lowering of OuterProductOp.
+/// One:
+/// %x = vector.outerproduct %lhs, %rhs, %acc
+/// is replaced by:
+/// %z = zero-result
+/// %0 = vector.extract %lhs[0]
+/// %1 = vector.broadcast %0
+/// %2 = vector.extract %acc[0]
+/// %3 = vector.fma %1, %rhs, %2
+/// %4 = vector.insert %3, %z[0]
+/// ..
+/// %x = vector.insert %.., %..[N-1]
+///
+class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::OuterProductOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+
+ VectorType lhsType = op.getOperandVectorTypeLHS();
+ VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
+ VectorType resType = op.getResultVectorType();
+ Type eltType = resType.getElementType();
+ bool isInt = eltType.isa<IntegerType, IndexType>();
+ Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
+ vector::CombiningKind kind = op.getKind();
+
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+ Operation *rootOp;
+ Value mask;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = op;
+ }
+
+ if (!rhsType) {
+ // Special case: AXPY operation.
+ Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
+ std::optional<Value> mult = createContractArithOp(
+ loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
+ if (!mult.has_value())
+ return failure();
+ rewriter.replaceOp(rootOp, *mult);
+ return success();
+ }
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resType, rewriter.getZeroAttr(resType));
+ for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
+ auto pos = rewriter.getI64ArrayAttr(d);
+ Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
+ Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
+ Value r = nullptr;
+ if (acc)
+ r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
+ Value extrMask;
+ if (mask)
+ extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
+
+ std::optional<Value> m = createContractArithOp(
+ loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
+ if (!m.has_value())
+ return failure();
+ result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
+ }
+
+ rewriter.replaceOp(rootOp, result);
+ return success();
+ }
+};
+
+/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to:
+/// ```
+/// %mta = maybe_transpose
+/// %mtb = maybe_transpose
+/// %flattened_a = vector.shape_cast %mta
+/// %flattened_b = vector.shape_cast %mtb
+/// %flattened_d = vector.matmul %flattened_a, %flattened_b
+/// %mtd = vector.shape_cast %flattened_d
+/// %d = maybe_untranspose %mtd
+/// %e = add %c, %d
+/// ```
+/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
+///
+/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
+/// vector.transpose operations are inserted if the vector.contract op is not a
+/// row-major matrix multiply.
+LogicalResult
+ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rew) const {
+ // TODO: Support vector.mask.
+ auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
+ // TODO: Remove native masks from contraction op?
+ if (!op.getMasks().empty())
+ return failure();
+ if (vectorTransformOptions.vectorContractLowering !=
+ vector::VectorContractLowering::Matmul)
+ return failure();
+ if (failed(filter(op)))
+ return failure();
+
+ auto iteratorTypes = op.getIteratorTypes().getValue();
+ if (!isParallelIterator(iteratorTypes[0]) ||
+ !isParallelIterator(iteratorTypes[1]) ||
+ !isReductionIterator(iteratorTypes[2]))
+ return failure();
+
+ Type elementType = op.getLhsType().getElementType();
+ if (!elementType.isIntOrFloat())
+ return failure();
+
+ Type dstElementType = op.getType();
+ if (auto vecType = dstElementType.dyn_cast<VectorType>())
+ dstElementType = vecType.getElementType();
+ if (elementType != dstElementType)
+ return failure();
+
+ // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
+ // Bail out if the contraction cannot be put in this form.
+ MLIRContext *ctx = op.getContext();
+ Location loc = op.getLoc();
+ AffineExpr m, n, k;
+ bindDims(rew.getContext(), m, n, k);
+ // LHS must be A(m, k) or A(k, m).
+ Value lhs = op.getLhs();
+ auto lhsMap = op.getIndexingMapsArray()[0];
+ if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
+ lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
+ else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
+ return failure();
+
+ // RHS must be B(k, n) or B(n, k).
+ Value rhs = op.getRhs();
+ auto rhsMap = op.getIndexingMapsArray()[1];
+ if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
+ rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
+ else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
+ return failure();
+
+ // At this point lhs and rhs are in row-major.
+ VectorType lhsType = lhs.getType().cast<VectorType>();
+ VectorType rhsType = rhs.getType().cast<VectorType>();
+ int64_t lhsRows = lhsType.getDimSize(0);
+ int64_t lhsColumns = lhsType.getDimSize(1);
+ int64_t rhsColumns = rhsType.getDimSize(1);
+
+ Type flattenedLHSType =
+ VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+ lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
+
+ Type flattenedRHSType =
+ VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+ rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
+
+ Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
+ rhsColumns);
+ mul = rew.create<vector::ShapeCastOp>(
+ loc,
+ VectorType::get({lhsRows, rhsColumns},
+ getElementTypeOrSelf(op.getAcc().getType())),
+ mul);
+
+ // ACC must be C(m, n) or C(n, m).
+ auto accMap = op.getIndexingMapsArray()[2];
+ if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
+ mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
+ else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
+ llvm_unreachable("invalid contraction semantics");
+
+ Value res =
+ elementType.isa<IntegerType>()
+ ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
+ : static_cast<Value>(
+ rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
+
+ rew.replaceOp(op, res);
+ return success();
+}
+} // namespace
+
+void mlir::vector::populateVectorContractLoweringPatterns(
+ RewritePatternSet &patterns, VectorTransformsOptions options,
+ PatternBenefit benefit, bool disableOuterProductLowering) {
+ if (!disableOuterProductLowering)
+ patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
+ patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
+ ContractionOpToOuterProductOpLowering>(
+ options, patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
new file mode 100644
index 0000000000000..dc10cb6278cb8
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -0,0 +1,173 @@
+//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.gather' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-broadcast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
+/// ... into vector<2x3xf32>
+///
+/// ==>
+///
+/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
+/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
+/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
+/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
+struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::GatherOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType resultTy = op.getType();
+ if (resultTy.getRank() < 2)
+ return rewriter.notifyMatchFailure(op, "already flat");
+
+ Location loc = op.getLoc();
+ Value indexVec = op.getIndexVec();
+ Value maskVec = op.getMask();
+ Value passThruVec = op.getPassThru();
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resultTy, rewriter.getZeroAttr(resultTy));
+
+ Type subTy = VectorType::get(resultTy.getShape().drop_front(),
+ resultTy.getElementType());
+
+ for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+ int64_t thisIdx[1] = {i};
+
+ Value indexSubVec =
+ rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
+ Value maskSubVec =
+ rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+ Value passThruSubVec =
+ rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
+ Value subGather = rewriter.create<vector::GatherOp>(
+ loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
+ passThruSubVec);
+ result =
+ rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
+/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
+/// loads/extracts are made conditional using `scf.if` ops.
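+///
+/// An illustrative sketch for a single lane of a memref-based gather (index
+/// arithmetic and most types elided; a vector<4xf32> gather result is assumed,
+/// and the tensor case uses `tensor.extract` instead of `vector.load`):
+/// ```
+/// %c0 = vector.extract %mask[0]
+/// %r0 = scf.if %c0 -> (vector<4xf32>) {
+///   %v = vector.load %base[%off0] : memref<?xf32>, vector<1xf32>
+///   %e = vector.extract %v[0]
+///   %new = vector.insert %e, %pass_thru [0]
+///   scf.yield %new
+/// } else {
+///   scf.yield %pass_thru
+/// }
+/// ```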
+struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::GatherOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType resultTy = op.getType();
+ if (resultTy.getRank() != 1)
+ return rewriter.notifyMatchFailure(op, "unsupported rank");
+
+ Location loc = op.getLoc();
+ Type elemTy = resultTy.getElementType();
+ // Vector type with a single element. Used to generate `vector.loads`.
+ VectorType elemVecTy = VectorType::get({1}, elemTy);
+
+ Value condMask = op.getMask();
+ Value base = op.getBase();
+ Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
+ loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
+ op.getIndexVec());
+ auto baseOffsets = llvm::to_vector(op.getIndices());
+ Value lastBaseOffset = baseOffsets.back();
+
+ Value result = op.getPassThru();
+
+ // Emit a conditional access for each vector element.
+ for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
+ int64_t thisIdx[1] = {i};
+ Value condition =
+ rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
+ Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
+ baseOffsets.back() =
+ rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
+
+ auto loadBuilder = [&](OpBuilder &b, Location loc) {
+ Value extracted;
+ if (isa<MemRefType>(base.getType())) {
+ // `vector.load` does not support scalar result; emit a vector load
+ // and extract the single result instead.
+ Value load =
+ b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
+ int64_t zeroIdx[1] = {0};
+ extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
+ } else {
+ extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
+ }
+
+ Value newResult =
+ b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
+ b.create<scf::YieldOp>(loc, newResult);
+ };
+ auto passThruBuilder = [result](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, result);
+ };
+
+ result =
+ rewriter
+ .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
+ /*elseBuilder=*/passThruBuilder)
+ .getResult(0);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorGatherLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
+ benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 7c66e65fdef8b..e318d4dc15915 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements target-independent rewrites and utilitites to lower the
+// This file implements target-independent rewrites and utilities to lower the
// 'vector.mask' operation.
//
//===----------------------------------------------------------------------===//
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -30,6 +31,147 @@ namespace vector {
using namespace mlir;
using namespace mlir::vector;
+//===----------------------------------------------------------------------===//
+// populateVectorMaskOpLoweringPatterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Progressive lowering of CreateMaskOp.
+/// One:
+/// %x = vector.create_mask %a, ... : vector<dx...>
+/// is replaced by:
+/// %l = vector.create_mask ... : vector<...> ; one lower rank
+/// %0 = arith.cmpi "slt", %ci, %a |
+/// %1 = select %0, %l, %zeroes |
+/// %r = vector.insert %1, %pr [i] | d-times
+/// %x = ....
+/// until a one-dimensional vector is reached.
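+///
+/// For example (sketch): `vector.create_mask %a, %b : vector<4x3xi1>` becomes
+/// a 1-D `vector.create_mask %b : vector<3xi1>`; then, for each of the 4 rows,
+/// an `arith.cmpi slt` of the constant row index against %a selects
+/// (`arith.select`) between that 1-D mask and an all-false row, and the
+/// selection is inserted into the result.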
+class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+ PatternRewriter &rewriter) const override {
+ auto dstType = op.getResult().getType().cast<VectorType>();
+ int64_t rank = dstType.getRank();
+ if (rank <= 1)
+ return rewriter.notifyMatchFailure(
+ op, "0-D and 1-D vectors are handled separately");
+
+ auto loc = op.getLoc();
+ auto eltType = dstType.getElementType();
+ int64_t dim = dstType.getDimSize(0);
+ Value idx = op.getOperand(0);
+
+ VectorType lowType =
+ VectorType::get(dstType.getShape().drop_front(), eltType);
+ Value trueVal = rewriter.create<vector::CreateMaskOp>(
+ loc, lowType, op.getOperands().drop_front());
+ Value falseVal = rewriter.create<arith::ConstantOp>(
+ loc, lowType, rewriter.getZeroAttr(lowType));
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, dstType, rewriter.getZeroAttr(dstType));
+ for (int64_t d = 0; d < dim; d++) {
+ Value bnd =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
+ Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+ bnd, idx);
+ Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
+ auto pos = rewriter.getI64ArrayAttr(d);
+ result =
+ rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/// Progressive lowering of ConstantMaskOp.
+/// One:
+/// %x = vector.constant_mask [a,b]
+/// is replaced by:
+/// %z = zero-result
+/// %l = vector.constant_mask [b]
+/// %4 = vector.insert %l, %z[0]
+/// ..
+/// %x = vector.insert %l, %..[a-1]
+/// until a one-dimensional vector is reached. All these operations
+/// will be folded at LLVM IR level.
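+///
+/// For example (sketch): `vector.constant_mask [2, 3] : vector<4x5xi1>`
+/// becomes a 1-D `vector.constant_mask [3] : vector<5xi1>` inserted at rows 0
+/// and 1 of an all-false vector<4x5xi1> constant.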
+class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto dstType = op.getType();
+ auto eltType = dstType.getElementType();
+ auto dimSizes = op.getMaskDimSizes();
+ int64_t rank = dstType.getRank();
+
+ if (rank == 0) {
+ assert(dimSizes.size() == 1 &&
+ "Expected exactly one dim size for a 0-D vector");
+ bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, dstType,
+ DenseIntElementsAttr::get(
+ VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
+ ArrayRef<bool>{value}));
+ return success();
+ }
+
+ // Scalable constant masks can only be lowered for the "none set" case.
+ if (dstType.cast<VectorType>().isScalable()) {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, DenseElementsAttr::get(dstType, false));
+ return success();
+ }
+
+ int64_t trueDim = std::min(dstType.getDimSize(0),
+ dimSizes[0].cast<IntegerAttr>().getInt());
+
+ if (rank == 1) {
+ // Express constant 1-D case in explicit vector form:
+ // [T,..,T,F,..,F].
+ SmallVector<bool> values(dstType.getDimSize(0));
+ for (int64_t d = 0; d < trueDim; d++)
+ values[d] = true;
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, dstType, rewriter.getBoolVectorAttr(values));
+ return success();
+ }
+
+ VectorType lowType =
+ VectorType::get(dstType.getShape().drop_front(), eltType);
+ SmallVector<int64_t> newDimSizes;
+ for (int64_t r = 1; r < rank; r++)
+ newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
+ Value trueVal = rewriter.create<vector::ConstantMaskOp>(
+ loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, dstType, rewriter.getZeroAttr(dstType));
+ for (int64_t d = 0; d < trueDim; d++) {
+ auto pos = rewriter.getI64ArrayAttr(d);
+ result =
+ rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorMaskOpLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
+ patterns.getContext(), benefit);
+}
+
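+// A minimal usage sketch (illustrative, assuming a `func::FuncOp` target and
+// the greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h):
+//
+//   RewritePatternSet patterns(funcOp.getContext());
+//   vector::populateVectorMaskOpLoweringPatterns(patterns);
+//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+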
+//===----------------------------------------------------------------------===//
+// populateVectorMaskLoweringPatternsForSideEffectingOps
+//===----------------------------------------------------------------------===//
+
namespace {
/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
similarity index 98%
rename from mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index b790d141415aa..1744c46db5886 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -1,4 +1,4 @@
-//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
+//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
/// Part of the LLVM Project, under the Apache License v2.0 with LLVM
/// Exceptions. See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
-/// This file implements target-independent rewrites of MultiDimReductionOp.
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
@@ -19,6 +20,7 @@
using namespace mlir;
+namespace {
/// This file implements the following transformations as composable atomic
/// patterns.
@@ -441,6 +443,7 @@ struct OneDimMultiReductionToTwoDim
return success();
}
};
+} // namespace
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
new file mode 100644
index 0000000000000..eb2deba7bc46b
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -0,0 +1,251 @@
+//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.scan' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-scan-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// This function constructs the appropriate integer or float
+/// operation given the vector combining kind and operands. The
+/// supported int operations are: add, mul, min (signed/unsigned),
+/// max (signed/unsigned), and, or, xor. The supported float
+/// operations are: add, mul, min, and max.
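+/// For example, `ADD` on integer operands yields `arith.addi`, while `ADD` on
+/// floating-point operands yields `arith.addf` (illustrative pairing; see the
+/// switch below).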
+static Value genOperator(Location loc, Value x, Value y,
+ vector::CombiningKind kind,
+ PatternRewriter &rewriter) {
+ using vector::CombiningKind;
+
+ auto elType = x.getType().cast<VectorType>().getElementType();
+ bool isInt = elType.isIntOrIndex();
+
+ Value combinedResult{nullptr};
+ switch (kind) {
+ case CombiningKind::ADD:
+ if (isInt)
+ combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
+ else
+ combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
+ break;
+ case CombiningKind::MUL:
+ if (isInt)
+ combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
+ else
+ combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
+ break;
+ case CombiningKind::MINUI:
+ combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
+ break;
+ case CombiningKind::MINSI:
+ combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
+ break;
+ case CombiningKind::MAXUI:
+ combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
+ break;
+ case CombiningKind::MAXSI:
+ combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
+ break;
+ case CombiningKind::AND:
+ combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
+ break;
+ case CombiningKind::OR:
+ combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
+ break;
+ case CombiningKind::XOR:
+ combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
+ break;
+ case CombiningKind::MINF:
+ combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
+ break;
+ case CombiningKind::MAXF:
+ combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
+ break;
+ }
+ return combinedResult;
+}
+
+/// This function checks to see if the vector combining kind
+/// is consistent with the integer or float element type.
+static bool isValidKind(bool isInt, vector::CombiningKind kind) {
+ using vector::CombiningKind;
+ enum class KindType { FLOAT, INT, INVALID };
+ KindType type{KindType::INVALID};
+ switch (kind) {
+ case CombiningKind::MINF:
+ case CombiningKind::MAXF:
+ type = KindType::FLOAT;
+ break;
+ case CombiningKind::MINUI:
+ case CombiningKind::MINSI:
+ case CombiningKind::MAXUI:
+ case CombiningKind::MAXSI:
+ case CombiningKind::AND:
+ case CombiningKind::OR:
+ case CombiningKind::XOR:
+ type = KindType::INT;
+ break;
+ case CombiningKind::ADD:
+ case CombiningKind::MUL:
+ type = isInt ? KindType::INT : KindType::FLOAT;
+ break;
+ }
+ bool isValidIntKind = (type == KindType::INT) && isInt;
+ bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
+ return (isValidIntKind || isValidFloatKind);
+}
+
+namespace {
+/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
+/// vector.extract_strided_slice.
+///
+/// Example:
+///
+/// ```
+///   %0:2 = vector.scan <mul>, %arg0, %arg1
+/// {inclusive = true, reduction_dim = 1} :
+/// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
+/// ```
+///
+/// is converted to:
+///
+/// ```
+/// %cst = arith.constant dense<0> : vector<2x3xi32>
+/// %0 = vector.extract_strided_slice %arg0
+/// {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
+/// : vector<2x3xi32> to vector<2x1xi32>
+/// %1 = vector.insert_strided_slice %0, %cst
+/// {offsets = [0, 0], strides = [1, 1]}
+/// : vector<2x1xi32> into vector<2x3xi32>
+/// %2 = vector.extract_strided_slice %arg0
+/// {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
+/// : vector<2x3xi32> to vector<2x1xi32>
+/// %3 = arith.muli %0, %2 : vector<2x1xi32>
+/// %4 = vector.insert_strided_slice %3, %1
+/// {offsets = [0, 1], strides = [1, 1]}
+/// : vector<2x1xi32> into vector<2x3xi32>
+/// %5 = vector.extract_strided_slice %arg0
+/// {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
+/// : vector<2x3xi32> to vector<2x1xi32>
+/// %6 = arith.muli %3, %5 : vector<2x1xi32>
+/// %7 = vector.insert_strided_slice %6, %4
+/// {offsets = [0, 2], strides = [1, 1]}
+/// : vector<2x1xi32> into vector<2x3xi32>
+/// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
+/// return %7, %8 : vector<2x3xi32>, vector<2xi32>
+/// ```
+struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ScanOp scanOp,
+ PatternRewriter &rewriter) const override {
+ auto loc = scanOp.getLoc();
+ VectorType destType = scanOp.getDestType();
+ ArrayRef<int64_t> destShape = destType.getShape();
+ auto elType = destType.getElementType();
+ bool isInt = elType.isIntOrIndex();
+ if (!isValidKind(isInt, scanOp.getKind()))
+ return failure();
+
+ VectorType resType = VectorType::get(destShape, elType);
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resType, rewriter.getZeroAttr(resType));
+ int64_t reductionDim = scanOp.getReductionDim();
+ bool inclusive = scanOp.getInclusive();
+ int64_t destRank = destType.getRank();
+ VectorType initialValueType = scanOp.getInitialValueType();
+ int64_t initialValueRank = initialValueType.getRank();
+
+ SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
+ reductionShape[reductionDim] = 1;
+ VectorType reductionType = VectorType::get(reductionShape, elType);
+ SmallVector<int64_t> offsets(destRank, 0);
+ SmallVector<int64_t> strides(destRank, 1);
+ SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
+ sizes[reductionDim] = 1;
+ ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
+ ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
+
+ Value lastOutput, lastInput;
+ for (int i = 0; i < destShape[reductionDim]; i++) {
+ offsets[reductionDim] = i;
+ ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
+ Value input = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
+ scanStrides);
+ Value output;
+ if (i == 0) {
+ if (inclusive) {
+ output = input;
+ } else {
+ if (initialValueRank == 0) {
+ // ShapeCastOp cannot handle 0-D vectors
+ output = rewriter.create<vector::BroadcastOp>(
+ loc, input.getType(), scanOp.getInitialValue());
+ } else {
+ output = rewriter.create<vector::ShapeCastOp>(
+ loc, input.getType(), scanOp.getInitialValue());
+ }
+ }
+ } else {
+ Value y = inclusive ? input : lastInput;
+ output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
+ assert(output != nullptr);
+ }
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, output, result, offsets, strides);
+ lastOutput = output;
+ lastInput = input;
+ }
+
+ Value reduction;
+ if (initialValueRank == 0) {
+ Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
+ reduction =
+ rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
+ } else {
+ reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
+ lastOutput);
+ }
+
+ rewriter.replaceOp(scanOp, {result, reduction});
+ return success();
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorScanLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
new file mode 100644
index 0000000000000..bd9716cbca94c
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -0,0 +1,177 @@
+//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.shape_cast' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-shape-cast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
+/// vectors progressively on the way to targeting llvm.matrix intrinsics.
+/// This iterates over the most major dimension of the 2-D vector and performs
+/// rewrites into:
+/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
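+///
+/// A sketch of the generated IR for vector<2x3xf32> -> vector<6xf32>
+/// (value names are illustrative):
+///   %0 = vector.extract %src[0] : vector<2x3xf32>
+///   %1 = vector.insert_strided_slice %0, %zero {offsets = [0], strides = [1]}
+///     : vector<3xf32> into vector<6xf32>
+///   %2 = vector.extract %src[1] : vector<2x3xf32>
+///   %3 = vector.insert_strided_slice %2, %1 {offsets = [3], strides = [1]}
+///     : vector<3xf32> into vector<6xf32>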
+class ShapeCastOp2DDownCastRewritePattern
+ : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+ PatternRewriter &rewriter) const override {
+ auto sourceVectorType = op.getSourceVectorType();
+ auto resultVectorType = op.getResultVectorType();
+ if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+ return failure();
+
+ auto loc = op.getLoc();
+ Value desc = rewriter.create<arith::ConstantOp>(
+ loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+ unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
+ for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
+ Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
+ desc = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, vec, desc,
+ /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+ }
+ rewriter.replaceOp(op, desc);
+ return success();
+ }
+};
+
+/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
+/// vectors progressively.
+/// This iterates over the most major dimension of the 2-D vector and performs
+/// rewrites into:
+/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// Note that 1-D extract_strided_slice ops are lowered to efficient vector.shuffle.
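+///
+/// A sketch for vector<6xf32> -> vector<2x3xf32> (value names are
+/// illustrative):
+///   %0 = vector.extract_strided_slice %src
+///     {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+///   %1 = vector.insert %0, %zero [0] : vector<3xf32> into vector<2x3xf32>
+///   %2 = vector.extract_strided_slice %src
+///     {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+///   %3 = vector.insert %2, %1 [1] : vector<3xf32> into vector<2x3xf32>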
+class ShapeCastOp2DUpCastRewritePattern
+ : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+ PatternRewriter &rewriter) const override {
+ auto sourceVectorType = op.getSourceVectorType();
+ auto resultVectorType = op.getResultVectorType();
+ if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+ return failure();
+
+ auto loc = op.getLoc();
+ Value desc = rewriter.create<arith::ConstantOp>(
+ loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+ unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
+ for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
+ Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
+ /*sizes=*/mostMinorVectorSize,
+ /*strides=*/1);
+ desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+ }
+ rewriter.replaceOp(op, desc);
+ return success();
+ }
+};
+
+// We typically should not lower general shape cast operations into data
+// movement instructions, since the assumption is that these casts are
+// optimized away during progressive lowering. For completeness, however,
+// we fall back to a reference implementation that moves all elements
+// into the right place if we get here.
+class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto sourceVectorType = op.getSourceVectorType();
+ auto resultVectorType = op.getResultVectorType();
+
+ // Special case 2D / 1D lowerings with better implementations.
+ // TODO: make it ND / 1D to allow generic ND -> 1D -> MD.
+ int64_t srcRank = sourceVectorType.getRank();
+ int64_t resRank = resultVectorType.getRank();
+ if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+ return failure();
+
+ // Generic ShapeCast lowering path goes all the way down to unrolled scalar
+ // extract/insert chains.
+ // TODO: consider evolving the semantics to only allow 1D source or dest and
+ // drop this potentially very expensive lowering.
+ // Compute number of elements involved in the reshape.
+ int64_t numElts = 1;
+ for (int64_t r = 0; r < srcRank; r++)
+ numElts *= sourceVectorType.getDimSize(r);
+ // Replace with data movement operations:
+ // x[0,0,0] = y[0,0]
+ // x[0,0,1] = y[0,1]
+ // x[0,1,0] = y[0,2]
+ // etc., incrementing the two index vectors "row-major"
+ // within the source and result shape.
+ SmallVector<int64_t> srcIdx(srcRank);
+ SmallVector<int64_t> resIdx(resRank);
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+ for (int64_t i = 0; i < numElts; i++) {
+ if (i != 0) {
+ incIdx(srcIdx, sourceVectorType, srcRank - 1);
+ incIdx(resIdx, resultVectorType, resRank - 1);
+ }
+ Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+ result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+private:
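+ /// Increments the multi-dimensional index `idx` in row-major order within
+ /// the shape of `tp`. For example (illustrative), for vector<2x3xf32> the
+ /// index [0, 2] increments to [1, 0].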
+ static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
+ assert(0 <= r && r < tp.getRank());
+ if (++idx[r] == tp.getDimSize(r)) {
+ idx[r] = 0;
+ incIdx(idx, tp, r - 1);
+ }
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorShapeCastLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<ShapeCastOp2DDownCastRewritePattern,
+ ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
+ patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
similarity index 57%
rename from mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 68d9a349478bf..c2ce9aa10a850 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -14,7 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"
using namespace mlir;
@@ -46,6 +46,11 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}
+//===----------------------------------------------------------------------===//
+// populateVectorTransferPermutationMapLoweringPatterns
+//===----------------------------------------------------------------------===//
+
+namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
@@ -332,6 +337,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
}
};
+} // namespace
+
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
@@ -339,3 +346,239 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
TransferOpReduceRank, TransferWriteNonPermutationLowering>(
patterns.getContext(), benefit);
}
+
+//===----------------------------------------------------------------------===//
+// populateVectorTransferLoweringPatterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Progressive lowering of transfer_read. This pattern supports lowering of
+/// `vector.transfer_read` to a combination of `vector.load` and
+/// `vector.broadcast` if all of the following hold:
+/// - Stride of most minor memref dimension must be 1.
+/// - Out-of-bounds masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+/// result type.
+/// - The permutation map doesn't perform permutation (broadcasting is allowed).
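+///
+/// A sketch of the simplest in-bounds, unmasked case (illustrative values):
+///   %v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
+///     : memref<100xf32>, vector<4xf32>
+/// is rewritten to:
+///   %v = vector.load %mem[%i] : memref<100xf32>, vector<4xf32>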
+struct TransferReadToVectorLoadLowering
+ : public OpRewritePattern<vector::TransferReadOp> {
+ TransferReadToVectorLoadLowering(MLIRContext *context,
+ std::optional<unsigned> maxRank,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+ maxTransferRank(maxRank) {}
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp read,
+ PatternRewriter &rewriter) const override {
+ if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
+ return failure();
+
+ SmallVector<unsigned> broadcastedDims;
+ // Permutations are handled by VectorToSCF or
+ // populateVectorTransferPermutationMapLoweringPatterns.
+ // We let the 0-d corner case pass-through as it is supported.
+ if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
+ &broadcastedDims))
+ return failure();
+
+ auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
+ if (!memRefType)
+ return failure();
+
+ // Non-unit strides are handled by VectorToSCF.
+ if (!vector::isLastMemrefDimUnitStride(memRefType))
+ return failure();
+
+ // If there is broadcasting involved then we first load the unbroadcasted
+ // vector, and then broadcast it with `vector.broadcast`.
+ ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
+ SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
+ vectorShape.end());
+ for (unsigned i : broadcastedDims)
+ unbroadcastedVectorShape[i] = 1;
+ VectorType unbroadcastedVectorType = VectorType::get(
+ unbroadcastedVectorShape, read.getVectorType().getElementType());
+
+ // `vector.load` supports vector types as memref's elements only when the
+ // resulting vector type is the same as the element type.
+ auto memrefElTy = memRefType.getElementType();
+ if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
+ return failure();
+
+ // Otherwise, element types of the memref and the vector must match.
+ if (!memrefElTy.isa<VectorType>() &&
+ memrefElTy != read.getVectorType().getElementType())
+ return failure();
+
+ // Out-of-bounds dims are handled by MaterializeTransferMask.
+ if (read.hasOutOfBoundsDim())
+ return failure();
+
+ // Create vector load op.
+ Operation *loadOp;
+ if (read.getMask()) {
+ Value fill = rewriter.create<vector::SplatOp>(
+ read.getLoc(), unbroadcastedVectorType, read.getPadding());
+ loadOp = rewriter.create<vector::MaskedLoadOp>(
+ read.getLoc(), unbroadcastedVectorType, read.getSource(),
+ read.getIndices(), read.getMask(), fill);
+ } else {
+ loadOp = rewriter.create<vector::LoadOp>(
+ read.getLoc(), unbroadcastedVectorType, read.getSource(),
+ read.getIndices());
+ }
+
+ // Insert a broadcasting op if required.
+ if (!broadcastedDims.empty()) {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ read, read.getVectorType(), loadOp->getResult(0));
+ } else {
+ rewriter.replaceOp(read, loadOp->getResult(0));
+ }
+
+ return success();
+ }
+
+ std::optional<unsigned> maxTransferRank;
+};
+
+/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
+// TODO: we shouldn't cross the vector/scalar domains just for this
+// but atm we lack the infra to avoid it. Possible solutions include:
+// - go directly to LLVM + bitcast
+// - introduce a bitcast op and likely a new pointer dialect
+// - let memref.load/store additionally support the 0-d vector case
+// There are still deeper data layout issues lingering even in this
+// trivial case (for architectures for which this matters).
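+//
+// A sketch for a single-element load (illustrative):
+//   %v = vector.load %mem[%i] : memref<100xf32>, vector<1xf32>
+// becomes:
+//   %s = memref.load %mem[%i] : memref<100xf32>
+//   %v = vector.broadcast %s : f32 to vector<1xf32>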
+struct VectorLoadToMemrefLoadLowering
+ : public OpRewritePattern<vector::LoadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ auto vecType = loadOp.getVectorType();
+ if (vecType.getNumElements() != 1)
+ return failure();
+ auto memrefLoad = rewriter.create<memref::LoadOp>(
+ loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
+ memrefLoad);
+ return success();
+ }
+};
+
+/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
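+///
+/// A sketch for the 0-d case (illustrative):
+///   vector.store %v, %mem[%i, %j] : memref<200x100xf32>, vector<f32>
+/// becomes:
+///   %s = vector.extractelement %v[] : vector<f32>
+///   memref.store %s, %mem[%i, %j] : memref<200x100xf32>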
+struct VectorStoreToMemrefStoreLowering
+ : public OpRewritePattern<vector::StoreOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ auto vecType = storeOp.getVectorType();
+ if (vecType.getNumElements() != 1)
+ return failure();
+ Value extracted;
+ if (vecType.getRank() == 0) {
+ // TODO: Unify once ExtractOp supports 0-d vectors.
+ extracted = rewriter.create<vector::ExtractElementOp>(
+ storeOp.getLoc(), storeOp.getValueToStore());
+ } else {
+ SmallVector<int64_t> indices(vecType.getRank(), 0);
+ extracted = rewriter.create<vector::ExtractOp>(
+ storeOp.getLoc(), storeOp.getValueToStore(), indices);
+ }
+
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
+ return success();
+ }
+};
+
+/// Progressive lowering of transfer_write. This pattern supports lowering of
+/// `vector.transfer_write` to `vector.store` if all of the following hold:
+/// - Stride of most minor memref dimension must be 1.
+/// - Out-of-bounds masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+/// type of the written value.
+/// - The permutation map is the minor identity map (neither permutation nor
+/// broadcasting is allowed).
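+///
+/// A sketch of the simplest in-bounds, unmasked case (illustrative values):
+///   vector.transfer_write %v, %mem[%i] {in_bounds = [true]}
+///     : vector<4xf32>, memref<100xf32>
+/// is rewritten to:
+///   vector.store %v, %mem[%i] : memref<100xf32>, vector<4xf32>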
+struct TransferWriteToVectorStoreLowering
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ TransferWriteToVectorStoreLowering(MLIRContext *context,
+ std::optional<unsigned> maxRank,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+ maxTransferRank(maxRank) {}
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+ PatternRewriter &rewriter) const override {
+ if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
+ return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+ diag << "rank exceeds maxTransferRank: " << write;
+ });
+
+ // Permutations are handled by VectorToSCF or
+ // populateVectorTransferPermutationMapLoweringPatterns.
+ if ( // pass-through for the 0-d corner case.
+ !write.getPermutationMap().isMinorIdentity())
+ return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+ diag << "permutation map is not minor identity: " << write;
+ });
+
+ auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
+ if (!memRefType)
+ return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+ diag << "not a memref type: " << write;
+ });
+
+ // Non-unit strides are handled by VectorToSCF.
+ if (!vector::isLastMemrefDimUnitStride(memRefType))
+ return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+ diag << "most minor stride is not 1: " << write;
+ });
+
+ // `vector.store` supports vector types as memref's elements only when the
+ // type of the vector value being written is the same as the element type.
+ auto memrefElTy = memRefType.getElementType();
+ if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
+ return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+ diag << "elemental type mismatch: " << write;
+ });
+
+ // Otherwise, element types of the memref and the vector must match.
+ if (!memrefElTy.isa<VectorType>() &&
+ memrefElTy != write.getVectorType().getElementType())
+ return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+ diag << "elemental type mismatch: " << write;
+ });
+
+ // Out-of-bounds dims are handled by MaterializeTransferMask.
+ if (write.hasOutOfBoundsDim())
+ return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
+ diag << "out of bounds dim: " << write;
+ });
+ if (write.getMask()) {
+ rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+ write, write.getSource(), write.getIndices(), write.getMask(),
+ write.getVector());
+ } else {
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ write, write.getVector(), write.getSource(), write.getIndices());
+ }
+ return success();
+ }
+
+ std::optional<unsigned> maxTransferRank;
+};
+} // namespace
+
+void mlir::vector::populateVectorTransferLoweringPatterns(
+ RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
+ PatternBenefit benefit) {
+ patterns.add<TransferReadToVectorLoadLowering,
+ TransferWriteToVectorStoreLowering>(patterns.getContext(),
+ maxTransferRank, benefit);
+ patterns
+ .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+ patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
new file mode 100644
index 0000000000000..f6e8b0c445c99
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -0,0 +1,210 @@
+//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.transpose' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-transpose-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
+/// transposed.
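+/// For example (illustrative), the permutation [1, 0, 2, 3] is pruned to
+/// [1, 0] since the two trailing dimensions already stay in place.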
+static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
+ SmallVectorImpl<int64_t> &result) {
+ size_t numTransposedDims = transpose.size();
+ for (size_t transpDim : llvm::reverse(transpose)) {
+ if (transpDim != numTransposedDims - 1)
+ break;
+ numTransposedDims--;
+ }
+
+ result.append(transpose.begin(), transpose.begin() + numTransposedDims);
+}
+
+namespace {
+/// Progressive lowering of TransposeOp.
+/// One:
+/// %x = vector.transpose %y, [1, 0]
+/// is replaced by:
+/// %z = arith.constant dense<0.000000e+00>
+/// %0 = vector.extract %y[0, 0]
+/// %1 = vector.insert %0, %z [0, 0]
+/// ..
+/// %x = vector.insert .., .. [.., ..]
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit),
+ vectorTransformOptions(vectorTransformOptions) {}
+
+ LogicalResult matchAndRewrite(vector::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+
+ Value input = op.getVector();
+ VectorType inputType = op.getSourceVectorType();
+ VectorType resType = op.getResultVectorType();
+
+ // Set up convenience transposition table.
+ SmallVector<int64_t> transp;
+ for (auto attr : op.getTransp())
+ transp.push_back(attr.cast<IntegerAttr>().getInt());
+
+ if (vectorTransformOptions.vectorTransposeLowering ==
+ vector::VectorTransposeLowering::Shuffle &&
+ resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
+ return rewriter.notifyMatchFailure(
+ op, "Options specifies lowering to shuffle");
+
+ // Handle a true 2-D matrix transpose differently when requested.
+ if (vectorTransformOptions.vectorTransposeLowering ==
+ vector::VectorTransposeLowering::Flat &&
+ resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
+ Type flattenedType =
+ VectorType::get(resType.getNumElements(), resType.getElementType());
+ auto matrix =
+ rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
+ auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+ auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+ Value trans = rewriter.create<vector::FlatTransposeOp>(
+ loc, flattenedType, matrix, rows, columns);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+ return success();
+ }
+
+ // Generate unrolled extract/insert ops. We do not unroll the rightmost
+ // (i.e., highest-order) dimensions that are not transposed and leave them
+ // in vector form to improve performance. Therefore, we prune those
+ // dimensions from the shape/transpose data structures used to generate the
+ // extract/insert ops.
+ SmallVector<int64_t> prunedTransp;
+ pruneNonTransposedDims(transp, prunedTransp);
+ size_t numPrunedDims = transp.size() - prunedTransp.size();
+ auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
+ auto prunedInStrides = computeStrides(prunedInShape);
+
+ // Generates the extract/insert operations for every scalar/vector element
+ // of the leftmost transposed dimensions. We traverse every transpose
+ // element using a linearized index that we delinearize to generate the
+ // appropriate indices for the extract/insert operations.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resType, rewriter.getZeroAttr(resType));
+ int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
+
+ for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
+ ++linearIdx) {
+ auto extractIdxs = delinearize(linearIdx, prunedInStrides);
+ SmallVector<int64_t> insertIdxs(extractIdxs);
+ applyPermutationToVector(insertIdxs, prunedTransp);
+ Value extractOp =
+ rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
+ result =
+ rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+private:
+ /// Options to control the vector patterns.
+ vector::VectorTransformsOptions vectorTransformOptions;
+};
+
+/// Rewrite a 2-D vector.transpose as a sequence of:
+/// vector.shape_cast 2D -> 1D
+/// vector.shuffle
+/// vector.shape_cast 1D -> 2D
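+///
+/// For a vector<2x3xf32> transpose this amounts to (illustrative names):
+///   %0 = vector.shape_cast %src : vector<2x3xf32> to vector<6xf32>
+///   %1 = vector.shuffle %0, %0 [0, 3, 1, 4, 2, 5] : vector<6xf32>, vector<6xf32>
+///   %2 = vector.shape_cast %1 : vector<6xf32> to vector<3x2xf32>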
+class TransposeOp2DToShuffleLowering
+ : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ TransposeOp2DToShuffleLowering(
+ vector::VectorTransformsOptions vectorTransformOptions,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit),
+ vectorTransformOptions(vectorTransformOptions) {}
+
+ LogicalResult matchAndRewrite(vector::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+
+ VectorType srcType = op.getSourceVectorType();
+ if (srcType.getRank() != 2)
+ return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
+
+ SmallVector<int64_t> transp;
+ for (auto attr : op.getTransp())
+ transp.push_back(attr.cast<IntegerAttr>().getInt());
+ if (transp[0] != 1 && transp[1] != 0)
+ return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
+
+ if (vectorTransformOptions.vectorTransposeLowering !=
+ VectorTransposeLowering::Shuffle)
+ return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
+
+ int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
+ Value casted = rewriter.create<vector::ShapeCastOp>(
+ loc, VectorType::get({m * n}, srcType.getElementType()),
+ op.getVector());
+ SmallVector<int64_t> mask;
+ mask.reserve(m * n);
+ for (int64_t j = 0; j < n; ++j)
+ for (int64_t i = 0; i < m; ++i)
+ mask.push_back(i * n + j);
+
+ Value shuffled =
+ rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ op, op.getResultVectorType(), shuffled);
+
+ return success();
+ }
+
+private:
+ /// Options to control the vector patterns.
+ vector::VectorTransformsOptions vectorTransformOptions;
+};
+} // namespace
+
+void mlir::vector::populateVectorTransposeLoweringPatterns(
+ RewritePatternSet &patterns, VectorTransformsOptions options,
+ PatternBenefit benefit) {
+ patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
+ options, patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 38062b9893f1a..b0690f63422d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index ee23b5494f707..caf5822256bc6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -11,8 +11,8 @@
//
//===----------------------------------------------------------------------===//
-#include <type_traits>
#include <optional>
+#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -92,11 +92,11 @@ static Value createInBoundsCond(RewriterBase &b,
}
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
+/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// To accommodate for the fact that the original vector.transfer indexing may
+/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only vector.transfer_read case is implemented.
///
@@ -107,11 +107,11 @@ static Value createInBoundsCond(RewriterBase &b,
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
-/// // fastpath, direct cast
+/// // fast path, direct cast
/// memref.cast %A: memref<A...> to compatibleMemRefType
/// scf.yield %view : compatibleMemRefType, index, index
/// } else {
-/// // slowpath, not in-bounds vector.transfer or linalg.copy.
+/// // slow path, not in-bounds vector.transfer or linalg.copy.
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
/// scf.yield %4 : compatibleMemRefType, index, index
// }
@@ -172,12 +172,10 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
resShape[idx] =
(aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
- resStrides[idx] = (aStrides[idx] == bStrides[idx])
- ? aStrides[idx]
- : ShapedType::kDynamic;
+ resStrides[idx] =
+ (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
}
- resOffset =
- (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
+ resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
return MemRefType::get(
resShape, aT.getElementType(),
StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
@@ -634,7 +632,34 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
return success();
}
-LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
+namespace {
+/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
+/// may take an extra filter to perform selection at a finer granularity.
+struct VectorTransferFullPartialRewriter : public RewritePattern {
+ using FilterConstraintType =
+ std::function<LogicalResult(VectorTransferOpInterface op)>;
+
+ explicit VectorTransferFullPartialRewriter(
+ MLIRContext *context,
+ VectorTransformsOptions options = VectorTransformsOptions(),
+ FilterConstraintType filter =
+ [](VectorTransferOpInterface op) { return success(); },
+ PatternBenefit benefit = 1)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
+ filter(std::move(filter)) {}
+
+ /// Performs the rewrite.
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ VectorTransformsOptions options;
+ FilterConstraintType filter;
+};
+
+} // namespace
+
+LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
@@ -642,3 +667,9 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
return failure();
return splitFullAndPartialTransfer(rewriter, xferOp, options);
}
+
+void mlir::vector::populateVectorTransferFullPartialPatterns(
+ RewritePatternSet &patterns, const VectorTransformsOptions &options) {
+ patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
+ options);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index fe59143ebd55f..20fc59e874ab6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -51,102 +51,6 @@
using namespace mlir;
using namespace mlir::vector;
-// Helper to find an index in an affine map.
-static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
- for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
- int64_t idx = map.getDimPosition(i);
- if (idx == index)
- return i;
- }
- return std::nullopt;
-}
-
-// Helper to construct iterator types with one index removed.
-static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
- int64_t index) {
- SmallVector<Attribute> results;
- for (const auto &it : llvm::enumerate(iteratorTypes)) {
- int64_t idx = it.index();
- if (idx == index)
- continue;
- results.push_back(it.value());
- }
- return results;
-}
-
-// Helper to construct an affine map with one index removed.
-static AffineMap adjustMap(AffineMap map, int64_t index,
- PatternRewriter &rewriter) {
- auto *ctx = rewriter.getContext();
- SmallVector<AffineExpr> results;
- for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
- int64_t idx = map.getDimPosition(i);
- if (idx == index)
- continue;
- // Re-insert remaining indices, but renamed when occurring
- // after the removed index.
- auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
- results.push_back(targetExpr);
- }
- return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
-}
-
-// Helper method to possibly drop a dimension in a load.
-// TODO
-static Value reshapeLoad(Location loc, Value val, VectorType type,
- int64_t index, int64_t pos,
- PatternRewriter &rewriter) {
- if (index == -1)
- return val;
- Type lowType = VectorType::Builder(type).dropDim(0);
- // At extraction dimension?
- if (index == 0) {
- auto posAttr = rewriter.getI64ArrayAttr(pos);
- return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
- }
- // Unroll leading dimensions.
- VectorType vType = lowType.cast<VectorType>();
- Type resType = VectorType::Builder(type).dropDim(index);
- auto resVectorType = resType.cast<VectorType>();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resVectorType, rewriter.getZeroAttr(resVectorType));
- for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
- auto posAttr = rewriter.getI64ArrayAttr(d);
- Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
- Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
- posAttr);
- }
- return result;
-}
-
-// Helper method to possibly drop a dimension in a store.
-// TODO
-static Value reshapeStore(Location loc, Value val, Value result,
- VectorType type, int64_t index, int64_t pos,
- PatternRewriter &rewriter) {
- // Unmodified?
- if (index == -1)
- return val;
- // At insertion dimension?
- if (index == 0) {
- auto posAttr = rewriter.getI64ArrayAttr(pos);
- return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
- }
- // Unroll leading dimensions.
- Type lowType = VectorType::Builder(type).dropDim(0);
- VectorType vType = lowType.cast<VectorType>();
- Type insType = VectorType::Builder(vType).dropDim(0);
- for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
- auto posAttr = rewriter.getI64ArrayAttr(d);
- Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
- Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
- Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
- }
- return result;
-}
-
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(llvm::map_range(
@@ -154,61 +58,11 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
[](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
-/// Helper to create arithmetic operation associated with a kind of contraction.
-static std::optional<Value>
-createContractArithOp(Location loc, Value x, Value y, Value acc,
- vector::CombiningKind kind, PatternRewriter &rewriter,
- bool isInt, Value mask = Value()) {
- using vector::CombiningKind;
- Value mul;
-
- if (isInt) {
- if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
- // Only valid for floating point types.
- return std::nullopt;
- mul = rewriter.create<arith::MulIOp>(loc, x, y);
- } else {
- // Float case.
- if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
- kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
- kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
- kind == CombiningKind::XOR)
- // Only valid for integer types.
- return std::nullopt;
- // Special case for fused multiply-add.
- if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
- Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
- if (mask)
- // The fma op doesn't need explicit masking. However, fma ops used in
- // reductions must preserve previous 'acc' values for masked-out lanes.
- fma = selectPassthru(rewriter, mask, fma, acc);
- return fma;
- }
- mul = rewriter.create<arith::MulFOp>(loc, x, y);
- }
-
- if (!acc)
- return std::optional<Value>(mul);
-
- return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
-}
-
-/// Return the positions of the reductions in the given map.
-static SmallVector<int64_t> getReductionIndex(AffineMap map,
- ArrayAttr iteratorTypes) {
- SmallVector<int64_t> dimsIdx;
- for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
- if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
- dimsIdx.push_back(i);
- }
- return dimsIdx;
-}
-
-/// Look for a given dimension in an affine map and return its position. Return
-/// std::nullopt if the dimension is not in the map results.
-static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
- for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
- if (map.getDimPosition(i) == dim)
+// Helper to find an index in an affine map.
+static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
+ for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+ int64_t idx = map.getDimPosition(i);
+ if (idx == index)
return i;
}
return std::nullopt;
@@ -264,735 +118,6 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
}
};
-/// Progressive lowering of BroadcastOp.
-class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::BroadcastOp op,
- PatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- VectorType dstType = op.getResultVectorType();
- VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
- Type eltType = dstType.getElementType();
-
- // Scalar to any vector can use splat.
- if (!srcType) {
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
- return success();
- }
-
- // Determine rank of source and destination.
- int64_t srcRank = srcType.getRank();
- int64_t dstRank = dstType.getRank();
-
- // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
- if (srcRank <= 1 && dstRank == 1) {
- Value ext;
- if (srcRank == 0)
- ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
- else
- ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
- return success();
- }
-
- // Duplicate this rank.
- // For example:
- // %x = broadcast %y : k-D to n-D, k < n
- // becomes:
- // %b = broadcast %y : k-D to (n-1)-D
- // %x = [%b,%b,%b,%b] : n-D
- // becomes:
- // %b = [%y,%y] : (n-1)-D
- // %x = [%b,%b,%b,%b] : n-D
- if (srcRank < dstRank) {
- // Duplication.
- VectorType resType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
- Value bcst =
- rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstType, rewriter.getZeroAttr(dstType));
- for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
- result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
- rewriter.replaceOp(op, result);
- return success();
- }
-
- // Find non-matching dimension, if any.
- assert(srcRank == dstRank);
- int64_t m = -1;
- for (int64_t r = 0; r < dstRank; r++)
- if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
- m = r;
- break;
- }
-
- // All trailing dimensions are the same. Simply pass through.
- if (m == -1) {
- rewriter.replaceOp(op, op.getSource());
- return success();
- }
-
- // Any non-matching dimension forces a stretch along this rank.
- // For example:
- // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
- // becomes:
- // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
- // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
- // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
- // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
- // %x = [%a,%b,%c,%d]
- // becomes:
- // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
- // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
- // %a = [%u, %v]
- // ..
- // %x = [%a,%b,%c,%d]
- VectorType resType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstType, rewriter.getZeroAttr(dstType));
- if (m == 0) {
- // Stetch at start.
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
- Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
- for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
- result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
- } else {
- // Stetch not at start.
- for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
- Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
- result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
- }
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
-/// transposed.
-void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
- SmallVectorImpl<int64_t> &result) {
- size_t numTransposedDims = transpose.size();
- for (size_t transpDim : llvm::reverse(transpose)) {
- if (transpDim != numTransposedDims - 1)
- break;
- numTransposedDims--;
- }
-
- result.append(transpose.begin(), transpose.begin() + numTransposedDims);
-}
-
-/// Progressive lowering of TransposeOp.
-/// One:
-/// %x = vector.transpose %y, [1, 0]
-/// is replaced by:
-/// %z = arith.constant dense<0.000000e+00>
-/// %0 = vector.extract %y[0, 0]
-/// %1 = vector.insert %0, %z [0, 0]
-/// ..
-/// %x = vector.insert .., .. [.., ..]
-class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, PatternBenefit benefit = 1)
- : OpRewritePattern<vector::TransposeOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions) {}
-
- LogicalResult matchAndRewrite(vector::TransposeOp op,
- PatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
-
- Value input = op.getVector();
- VectorType inputType = op.getSourceVectorType();
- VectorType resType = op.getResultVectorType();
-
- // Set up convenience transposition table.
- SmallVector<int64_t> transp;
- for (auto attr : op.getTransp())
- transp.push_back(attr.cast<IntegerAttr>().getInt());
-
- if (vectorTransformOptions.vectorTransposeLowering ==
- vector::VectorTransposeLowering::Shuffle &&
- resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
- return rewriter.notifyMatchFailure(
- op, "Options specifies lowering to shuffle");
-
- // Handle a true 2-D matrix transpose differently when requested.
- if (vectorTransformOptions.vectorTransposeLowering ==
- vector::VectorTransposeLowering::Flat &&
- resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
- Type flattenedType =
- VectorType::get(resType.getNumElements(), resType.getElementType());
- auto matrix =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
- auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
- auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
- Value trans = rewriter.create<vector::FlatTransposeOp>(
- loc, flattenedType, matrix, rows, columns);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
- return success();
- }
-
- // Generate unrolled extract/insert ops. We do not unroll the rightmost
- // (i.e., highest-order) dimensions that are not transposed and leave them
- // in vector form to improve performance. Therefore, we prune those
- // dimensions from the shape/transpose data structures used to generate the
- // extract/insert ops.
- SmallVector<int64_t> prunedTransp;
- pruneNonTransposedDims(transp, prunedTransp);
- size_t numPrunedDims = transp.size() - prunedTransp.size();
- auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
- auto prunedInStrides = computeStrides(prunedInShape);
-
- // Generates the extract/insert operations for every scalar/vector element
- // of the leftmost transposed dimensions. We traverse every transpose
- // element using a linearized index that we delinearize to generate the
- // appropriate indices for the extract/insert operations.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
- int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
-
- for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
- ++linearIdx) {
- auto extractIdxs = delinearize(linearIdx, prunedInStrides);
- SmallVector<int64_t> insertIdxs(extractIdxs);
- applyPermutationToVector(insertIdxs, prunedTransp);
- Value extractOp =
- rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
- result =
- rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
- }
-
- rewriter.replaceOp(op, result);
- return success();
- }
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
-};
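For a concrete picture of this unrolled form, a hypothetical vector<2x2xf32> transpose with permutation [1, 0] would come out roughly as follows (SSA names and shapes are illustrative, not taken from the code above):

  %z = arith.constant dense<0.000000e+00> : vector<2x2xf32>
  %0 = vector.extract %y[0, 0] : vector<2x2xf32>
  %1 = vector.insert %0, %z [0, 0] : f32 into vector<2x2xf32>
  %2 = vector.extract %y[0, 1] : vector<2x2xf32>
  %3 = vector.insert %2, %1 [1, 0] : f32 into vector<2x2xf32>
  // ... two more extract/insert pairs cover the remaining elements (illustrative sketch).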
-
-/// Rewrite a 2-D vector.transpose as a sequence of:
-/// vector.shape_cast 2D -> 1D
-/// vector.shuffle
-/// vector.shape_cast 1D -> 2D
-class TransposeOp2DToShuffleLowering
- : public OpRewritePattern<vector::TransposeOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- TransposeOp2DToShuffleLowering(
- vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, PatternBenefit benefit = 1)
- : OpRewritePattern<vector::TransposeOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions) {}
-
- LogicalResult matchAndRewrite(vector::TransposeOp op,
- PatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
-
- VectorType srcType = op.getSourceVectorType();
- if (srcType.getRank() != 2)
- return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
-
- SmallVector<int64_t> transp;
- for (auto attr : op.getTransp())
- transp.push_back(attr.cast<IntegerAttr>().getInt());
- if (transp[0] != 1 && transp[1] != 0)
- return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
-
- if (vectorTransformOptions.vectorTransposeLowering !=
- VectorTransposeLowering::Shuffle)
- return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
-
- int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
- Value casted = rewriter.create<vector::ShapeCastOp>(
- loc, VectorType::get({m * n}, srcType.getElementType()),
- op.getVector());
- SmallVector<int64_t> mask;
- mask.reserve(m * n);
- for (int64_t j = 0; j < n; ++j)
- for (int64_t i = 0; i < m; ++i)
- mask.push_back(i * n + j);
-
- Value shuffled =
- rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- op, op.getResultVectorType(), shuffled);
-
- return success();
- }
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
-};
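Concretely, under the Shuffle option a hypothetical vector<2x3xf32> to vector<3x2xf32> transpose would lower roughly to (illustrative shapes):

  %0 = vector.shape_cast %y : vector<2x3xf32> to vector<6xf32>
  // the mask interleaves rows and columns of the flattened source (sketch only)
  %1 = vector.shuffle %0, %0 [0, 3, 1, 4, 2, 5] : vector<6xf32>, vector<6xf32>
  %x = vector.shape_cast %1 : vector<6xf32> to vector<3x2xf32>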
-
-/// Progressive lowering of OuterProductOp.
-/// One:
-/// %x = vector.outerproduct %lhs, %rhs, %acc
-/// is replaced by:
-/// %z = zero-result
-/// %0 = vector.extract %lhs[0]
-/// %1 = vector.broadcast %0
-/// %2 = vector.extract %acc[0]
-/// %3 = vector.fma %1, %rhs, %2
-/// %4 = vector.insert %3, %z[0]
-/// ..
-/// %x = vector.insert %.., %..[N-1]
-///
-class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::OuterProductOp op,
- PatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
-
- VectorType lhsType = op.getOperandVectorTypeLHS();
- VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
- VectorType resType = op.getResultVectorType();
- Type eltType = resType.getElementType();
- bool isInt = eltType.isa<IntegerType, IndexType>();
- Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
- vector::CombiningKind kind = op.getKind();
-
- // Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
- Operation *rootOp;
- Value mask;
- if (maskableOp.isMasked()) {
- rewriter.setInsertionPoint(maskableOp.getMaskingOp());
- rootOp = maskableOp.getMaskingOp();
- mask = maskableOp.getMaskingOp().getMask();
- } else {
- rootOp = op;
- }
-
- if (!rhsType) {
- // Special case: AXPY operation.
- Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
- std::optional<Value> mult = createContractArithOp(
- loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
- if (!mult.has_value())
- return failure();
- rewriter.replaceOp(rootOp, *mult);
- return success();
- }
-
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
- for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
- auto pos = rewriter.getI64ArrayAttr(d);
- Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
- Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
- Value r = nullptr;
- if (acc)
- r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
- Value extrMask;
- if (mask)
- extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
-
- std::optional<Value> m = createContractArithOp(
- loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
- if (!m.has_value())
- return failure();
- result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
- }
-
- rewriter.replaceOp(rootOp, result);
- return success();
- }
-};
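As a sketch, a hypothetical f32 outer product of %lhs : vector<2xf32> and %rhs : vector<3xf32> with accumulator %acc : vector<2x3xf32> unrolls along the lhs dimension into roughly:

  %z = arith.constant dense<0.000000e+00> : vector<2x3xf32>
  %0 = vector.extract %lhs[0] : vector<2xf32>
  %1 = vector.broadcast %0 : f32 to vector<3xf32>
  %2 = vector.extract %acc[0] : vector<2x3xf32>
  %3 = vector.fma %1, %rhs, %2 : vector<3xf32>
  %4 = vector.insert %3, %z [0] : vector<3xf32> into vector<2x3xf32>
  // ... the same sequence for row 1 yields the final result (illustrative names).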
-
-/// Lower vector.contract with all size one reduction dimensions to
-/// elementwise ops when possible.
-struct ContractOpToElementwise
- : public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
- ContractOpToElementwise(
- vector::VectorTransformsOptions vectorTransformOptions,
- MLIRContext *context, PatternBenefit benefit = 1,
- const FilterConstraintType &constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
-
- LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
- PatternRewriter &rewriter) const override {
- // TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
- // TODO: Remove native masks from contraction op?
- if (!contractOp.getMasks().empty())
- return failure();
-
- if (failed(filter(contractOp)))
- return failure();
-
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::ParallelArith)
- return failure();
-
- ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
- ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
- AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
- AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
- SmallVector<int64_t> lhsReductionDims =
- getReductionIndex(lhsMap, contractOp.getIteratorTypes());
- SmallVector<int64_t> rhsReductionDims =
- getReductionIndex(rhsMap, contractOp.getIteratorTypes());
- // All the reduction dimensions must have size 1.
- for (int64_t dim : lhsReductionDims) {
- if (lhsShape[dim] != 1)
- return failure();
- }
- for (int64_t dim : rhsReductionDims) {
- if (rhsShape[dim] != 1)
- return failure();
- }
- AffineMap accMap = contractOp.getIndexingMapsArray()[2];
- unsigned numParallelDims = accMap.getNumResults();
- unsigned numLhsDimToBroadcast =
- numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
- unsigned numRhsDimToBroadcast =
- numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
- SmallVector<int64_t> lhsDims;
- SmallVector<int64_t> lhsTranspose;
- SmallVector<int64_t> rhsDims;
- SmallVector<int64_t> rhsTranspose;
- for (int64_t dim : lhsReductionDims)
- lhsTranspose.push_back(numLhsDimToBroadcast + dim);
- for (int64_t dim : rhsReductionDims)
- rhsTranspose.push_back(numRhsDimToBroadcast + dim);
- // Loop through the parallel dimensions to calculate the dimensions to
- // broadcast and to permute in order to extract only parallel dimensions.
- for (unsigned i = 0; i < numParallelDims; i++) {
- std::optional<unsigned> lhsDim =
- getDimPosition(lhsMap, accMap.getDimPosition(i));
- if (lhsDim) {
- lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
- } else {
- // If the parallel dimension doesn't exist we will have to broadcast it.
- lhsDims.push_back(
- contractOp.getResultType().cast<VectorType>().getDimSize(i));
- lhsTranspose.push_back(lhsDims.size() - 1);
- }
- std::optional<unsigned> rhsDim =
- getDimPosition(rhsMap, accMap.getDimPosition(i));
- if (rhsDim) {
- rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
- } else {
- // If the parallel dimension doesn't exist we will have to broadcast it.
- rhsDims.push_back(
- contractOp.getResultType().cast<VectorType>().getDimSize(i));
- rhsTranspose.push_back(rhsDims.size() - 1);
- }
- }
- Value newLhs = contractOp.getLhs();
- Value newRhs = contractOp.getRhs();
- Location loc = contractOp.getLoc();
- if (!lhsDims.empty()) {
- lhsDims.append(lhsShape.begin(), lhsShape.end());
- auto expandedType =
- VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
- newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
- }
- if (!rhsDims.empty()) {
- rhsDims.append(rhsShape.begin(), rhsShape.end());
- auto expandedType =
- VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
- newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
- }
- bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
- newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
- newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
- SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
- SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
- newLhs = rewriter.create<vector::ExtractOp>(
- loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
- newRhs = rewriter.create<vector::ExtractOp>(
- loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
- std::optional<Value> result =
- createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
- contractOp.getKind(), rewriter, isInt);
- rewriter.replaceOp(contractOp, {*result});
- return success();
- }
-
-private:
- /// Options to control the vector patterns.
- vector::VectorTransformsOptions vectorTransformOptions;
- FilterConstraintType filter;
-};
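In rough terms, once the unit-size reduction dimensions have been broadcast/transposed away, what remains is an elementwise multiply folded into the accumulator. A hedged sketch for a float add-contraction whose operands have already been aligned to a vector<4x4xf32> accumulator shape (the names and shapes below are assumptions, not produced verbatim by this pattern):

  // %lhsAligned / %rhsAligned stand in for the broadcast + transpose + extract results.
  %m   = arith.mulf %lhsAligned, %rhsAligned : vector<4x4xf32>
  %res = arith.addf %m, %acc : vector<4x4xf32>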
-
-/// Progressive lowering of ConstantMaskOp.
-/// One:
-/// %x = vector.constant_mask [a,b]
-/// is replaced by:
-/// %z = zero-result
-/// %l = vector.constant_mask [b]
-/// %4 = vector.insert %l, %z[0]
-/// ..
-/// %x = vector.insert %l, %..[a-1]
-/// until a one-dimensional vector is reached. All these operations
-/// will be folded at LLVM IR level.
-class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
- PatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- auto dstType = op.getType();
- auto eltType = dstType.getElementType();
- auto dimSizes = op.getMaskDimSizes();
- int64_t rank = dstType.getRank();
-
- if (rank == 0) {
- assert(dimSizes.size() == 1 &&
- "Expected exactly one dim size for a 0-D vector");
- bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, dstType,
- DenseIntElementsAttr::get(
- VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
- ArrayRef<bool>{value}));
- return success();
- }
-
- // Scalable constant masks can only be lowered for the "none set" case.
- if (dstType.cast<VectorType>().isScalable()) {
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, DenseElementsAttr::get(dstType, false));
- return success();
- }
-
- int64_t trueDim = std::min(dstType.getDimSize(0),
- dimSizes[0].cast<IntegerAttr>().getInt());
-
- if (rank == 1) {
- // Express constant 1-D case in explicit vector form:
- // [T,..,T,F,..,F].
- SmallVector<bool> values(dstType.getDimSize(0));
- for (int64_t d = 0; d < trueDim; d++)
- values[d] = true;
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, dstType, rewriter.getBoolVectorAttr(values));
- return success();
- }
-
- VectorType lowType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
- SmallVector<int64_t> newDimSizes;
- for (int64_t r = 1; r < rank; r++)
- newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
- Value trueVal = rewriter.create<vector::ConstantMaskOp>(
- loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstType, rewriter.getZeroAttr(dstType));
- for (int64_t d = 0; d < trueDim; d++) {
- auto pos = rewriter.getI64ArrayAttr(d);
- result =
- rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-};
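For instance, a hypothetical %x = vector.constant_mask [2, 2] : vector<3x4xi1> would unroll roughly as:

  %z = arith.constant dense<false> : vector<3x4xi1>
  %l = vector.constant_mask [2] : vector<4xi1>
  %0 = vector.insert %l, %z [0] : vector<4xi1> into vector<3x4xi1>
  %x = vector.insert %l, %0 [1] : vector<4xi1> into vector<3x4xi1>   // illustrative sketch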
-
-/// Progressive lowering of CreateMaskOp.
-/// One:
-/// %x = vector.create_mask %a, ... : vector<dx...>
-/// is replaced by:
-/// %l = vector.create_mask ... : vector<...> ; one lower rank
-/// %0 = arith.cmpi "slt", %ci, %a |
-/// %1 = select %0, %l, %zeroes |
-/// %r = vector.insert %1, %pr [i] | d-times
-/// %x = ....
-/// until a one-dimensional vector is reached.
-class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::CreateMaskOp op,
- PatternRewriter &rewriter) const override {
- auto dstType = op.getResult().getType().cast<VectorType>();
- int64_t rank = dstType.getRank();
- if (rank <= 1)
- return rewriter.notifyMatchFailure(
- op, "0-D and 1-D vectors are handled separately");
-
- auto loc = op.getLoc();
- auto eltType = dstType.getElementType();
- int64_t dim = dstType.getDimSize(0);
- Value idx = op.getOperand(0);
-
- VectorType lowType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
- Value trueVal = rewriter.create<vector::CreateMaskOp>(
- loc, lowType, op.getOperands().drop_front());
- Value falseVal = rewriter.create<arith::ConstantOp>(
- loc, lowType, rewriter.getZeroAttr(lowType));
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstType, rewriter.getZeroAttr(dstType));
- for (int64_t d = 0; d < dim; d++) {
- Value bnd =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
- Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
- bnd, idx);
- Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
- auto pos = rewriter.getI64ArrayAttr(d);
- result =
- rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-};
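A hypothetical %x = vector.create_mask %a, %b : vector<3x4xi1> would therefore expand into roughly (illustrative names):

  %l  = vector.create_mask %b : vector<4xi1>
  %f  = arith.constant dense<false> : vector<4xi1>
  %z  = arith.constant dense<false> : vector<3x4xi1>
  %c0 = arith.constant 0 : index
  %p0 = arith.cmpi slt, %c0, %a : index
  %s0 = arith.select %p0, %l, %f : vector<4xi1>
  %r0 = vector.insert %s0, %z [0] : vector<4xi1> into vector<3x4xi1>
  // ... repeated for rows 1 and 2 to yield %x (sketch only).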
-
-/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
-/// vectors progressively on the way to target llvm.matrix intrinsics.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOp2DDownCastRewritePattern
- : public OpRewritePattern<vector::ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
- if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
- return failure();
-
- auto loc = op.getLoc();
- Value desc = rewriter.create<arith::ConstantOp>(
- loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
- unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
- for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
- Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
- desc = rewriter.create<vector::InsertStridedSliceOp>(
- loc, vec, desc,
- /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
- }
- rewriter.replaceOp(op, desc);
- return success();
- }
-};
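A rough example, for a hypothetical vector<2x4xf32> to vector<8xf32> cast:

  %z = arith.constant dense<0.000000e+00> : vector<8xf32>
  %0 = vector.extract %y[0] : vector<2x4xf32>
  %1 = vector.insert_strided_slice %0, %z {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
  %2 = vector.extract %y[1] : vector<2x4xf32>
  %x = vector.insert_strided_slice %2, %1 {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>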
-
-/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
-/// vectors progressively.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
-/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOp2DUpCastRewritePattern
- : public OpRewritePattern<vector::ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
- if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
- return failure();
-
- auto loc = op.getLoc();
- Value desc = rewriter.create<arith::ConstantOp>(
- loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
- unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
- for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
- Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
- /*sizes=*/mostMinorVectorSize,
- /*strides=*/1);
- desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
- }
- rewriter.replaceOp(op, desc);
- return success();
- }
-};
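The mirror image, for a hypothetical vector<8xf32> to vector<2x4xf32> cast, is roughly:

  %z = arith.constant dense<0.000000e+00> : vector<2x4xf32>
  %0 = vector.extract_strided_slice %y {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
  %1 = vector.insert %0, %z [0] : vector<4xf32> into vector<2x4xf32>
  // ... and the slice at offset 4 is inserted at position 1 (illustrative shapes).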
-
-// We typically should not lower general shape cast operations into data
-// movement instructions, since the assumption is that these casts are
-// optimized away during progressive lowering. For completeness, however,
-// we fall back to a reference implementation that moves all elements
-// into the right place if we get here.
-class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- auto sourceVectorType = op.getSourceVectorType();
- auto resultVectorType = op.getResultVectorType();
-
- // Special case 2D/1D lowerings with better implementations.
- // TODO: make it ND/1D to allow generic ND->1D->MD.
- int64_t srcRank = sourceVectorType.getRank();
- int64_t resRank = resultVectorType.getRank();
- if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
- return failure();
-
- // Generic ShapeCast lowering path goes all the way down to unrolled scalar
- // extract/insert chains.
- // TODO: consider evolving the semantics to only allow 1D source or dest and
- // drop this potentially very expensive lowering.
- // Compute number of elements involved in the reshape.
- int64_t numElts = 1;
- for (int64_t r = 0; r < srcRank; r++)
- numElts *= sourceVectorType.getDimSize(r);
- // Replace with data movement operations:
- // x[0,0,0] = y[0,0]
- // x[0,0,1] = y[0,1]
- // x[0,1,0] = y[0,2]
- // etc., incrementing the two index vectors "row-major"
- // within the source and result shape.
- SmallVector<int64_t> srcIdx(srcRank);
- SmallVector<int64_t> resIdx(resRank);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
- for (int64_t i = 0; i < numElts; i++) {
- if (i != 0) {
- incIdx(srcIdx, sourceVectorType, srcRank - 1);
- incIdx(resIdx, resultVectorType, resRank - 1);
- }
- Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
- }
- rewriter.replaceOp(op, result);
- return success();
- }
-
-private:
- static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
- assert(0 <= r && r < tp.getRank());
- if (++idx[r] == tp.getDimSize(r)) {
- idx[r] = 0;
- incIdx(idx, tp, r - 1);
- }
- }
-};
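As a rough illustration, a hypothetical vector<2x3xf32> to vector<3x2xf32> shape cast hitting this fallback would unroll into six scalar moves:

  %z = arith.constant dense<0.000000e+00> : vector<3x2xf32>
  %0 = vector.extract %y[0, 0] : vector<2x3xf32>
  %1 = vector.insert %0, %z [0, 0] : f32 into vector<3x2xf32>
  %2 = vector.extract %y[0, 1] : vector<2x3xf32>
  %3 = vector.insert %2, %1 [0, 1] : f32 into vector<3x2xf32>
  // ... both index vectors keep incrementing row-major until all 6 elements are placed (sketch).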
-
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
@@ -1425,967 +550,6 @@ struct ReorderElementwiseOpsOnTranspose final
}
};
-} // namespace
-
-/// Creates an arith::AddIOp if `isInt` is true, otherwise creates an
-/// arith::AddFOp, using operands `x` and `y`.
-static Value createAdd(Location loc, Value x, Value y, bool isInt,
- PatternRewriter &rewriter) {
- if (isInt)
- return rewriter.create<arith::AddIOp>(loc, x, y);
- return rewriter.create<arith::AddFOp>(loc, x, y);
-}
-
-/// Creates an arith::MulIOp if `isInt` is true, otherwise creates an
-/// arith::MulFOp, using operands `x` and `y`.
-static Value createMul(Location loc, Value x, Value y, bool isInt,
- PatternRewriter &rewriter) {
- if (isInt)
- return rewriter.create<arith::MulIOp>(loc, x, y);
- return rewriter.create<arith::MulFOp>(loc, x, y);
-}
-
-namespace mlir {
-
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
-/// ```
-/// %mta = maybe_transpose
-/// %mtb = maybe_transpose
-/// %flattened_a = vector.shape_cast %mta
-/// %flattened_b = vector.shape_cast %mtb
-/// %flattened_d = vector.matmul %flattened_a, %flattened_b
-/// %mtd = vector.shape_cast %flattened_d
-/// %d = maybe_untranspose %mtd
-/// %e = add %c, %d
-/// ```
-/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
-//
-/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
-/// vector.transpose operations are inserted if the vector.contract op is not a
-/// row-major matrix multiply.
-LogicalResult
-ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rew) const {
- // TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
- // TODO: Remove native masks from contraction op?
- if (!op.getMasks().empty())
- return failure();
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::Matmul)
- return failure();
- if (failed(filter(op)))
- return failure();
-
- auto iteratorTypes = op.getIteratorTypes().getValue();
- if (!isParallelIterator(iteratorTypes[0]) ||
- !isParallelIterator(iteratorTypes[1]) ||
- !isReductionIterator(iteratorTypes[2]))
- return failure();
-
- Type elementType = op.getLhsType().getElementType();
- if (!elementType.isIntOrFloat())
- return failure();
-
- Type dstElementType = op.getType();
- if (auto vecType = dstElementType.dyn_cast<VectorType>())
- dstElementType = vecType.getElementType();
- if (elementType != dstElementType)
- return failure();
-
- // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
- // Bail out if the contraction cannot be put in this form.
- MLIRContext *ctx = op.getContext();
- Location loc = op.getLoc();
- AffineExpr m, n, k;
- bindDims(rew.getContext(), m, n, k);
- // LHS must be A(m, k) or A(k, m).
- Value lhs = op.getLhs();
- auto lhsMap = op.getIndexingMapsArray()[0];
- if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
- lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
- else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
- return failure();
-
- // RHS must be B(k, n) or B(n, k).
- Value rhs = op.getRhs();
- auto rhsMap = op.getIndexingMapsArray()[1];
- if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
- rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
- else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
- return failure();
-
- // At this point lhs and rhs are in row-major.
- VectorType lhsType = lhs.getType().cast<VectorType>();
- VectorType rhsType = rhs.getType().cast<VectorType>();
- int64_t lhsRows = lhsType.getDimSize(0);
- int64_t lhsColumns = lhsType.getDimSize(1);
- int64_t rhsColumns = rhsType.getDimSize(1);
-
- Type flattenedLHSType =
- VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
- lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
-
- Type flattenedRHSType =
- VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
- rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
-
- Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
- rhsColumns);
- mul = rew.create<vector::ShapeCastOp>(
- loc,
- VectorType::get({lhsRows, rhsColumns},
- getElementTypeOrSelf(op.getAcc().getType())),
- mul);
-
- // ACC must be C(m, n) or C(n, m).
- auto accMap = op.getIndexingMapsArray()[2];
- if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
- mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
- else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
- llvm_unreachable("invalid contraction semantics");
-
- Value res =
- elementType.isa<IntegerType>()
- ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
- : static_cast<Value>(
- rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
-
- rew.replaceOp(op, res);
- return success();
-}
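Assuming the operands are already in row-major form (so the maybe_transpose steps are no-ops), a hypothetical 2x3 * 3x4 f32 contraction would come out roughly as (illustrative shapes and names):

  %fa = vector.shape_cast %a : vector<2x3xf32> to vector<6xf32>
  %fb = vector.shape_cast %b : vector<3x4xf32> to vector<12xf32>
  %fd = vector.matmul %fa, %fb {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
          : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
  %d  = vector.shape_cast %fd : vector<8xf32> to vector<2x4xf32>
  %e  = arith.addf %c, %d : vector<2x4xf32>   // sketch only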
-
-namespace {
-
-/// Generate a vector implementation for matmat, matvec and tmatvec.
-/// This unrolls outer-products along the reduction dimension.
-struct UnrolledOuterProductGenerator
- : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
- UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
- : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
- kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
- res(op.getAcc()), lhsType(op.getLhsType()) {
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
- mask = maskableOp.getMaskingOp().getMask();
- }
-
- Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
- if (!v)
- return v;
- return rewriter.create<vector::TransposeOp>(loc, v, perm);
- }
-
- Value promote(Value v, Type dstElementType) {
- Type elementType = v.getType();
- auto vecType = elementType.dyn_cast<VectorType>();
- if (vecType)
- elementType = vecType.getElementType();
- if (elementType == dstElementType)
- return v;
- Type promotedType = dstElementType;
- if (vecType)
- promotedType = VectorType::get(vecType.getShape(), promotedType);
- if (dstElementType.isa<FloatType>())
- return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
- return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
- }
-
- FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
- std::optional<Value> maybeMask = std::nullopt) {
- assert(reductionSize > 0);
- // Incremental support for masking.
- if (mask && !maybeMask.has_value())
- return failure();
-
- Type resElementType = res.getType().cast<VectorType>().getElementType();
- for (int64_t k = 0; k < reductionSize; ++k) {
- Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
- Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
- extractA = promote(extractA, resElementType);
- extractB = promote(extractB, resElementType);
- Value extractMask;
- if (maybeMask.has_value() && maybeMask.value())
- extractMask =
- rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
-
- Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
- loc, res.getType(), extractA, extractB, res, kind);
- res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
- }
- return res;
- }
-
- /// Two outer parallel, one inner reduction (matmat flavor).
- FailureOr<Value> matmat() {
- if (!iters({Par(), Par(), Red()}))
- return failure();
- // Set up the parallel/reduction structure in the right form.
- AffineExpr m, n, k;
- bindDims(rewriter.getContext(), m, n, k);
- // Classical row-major matmul: Just permute the lhs.
- if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
- t(mask, {2, 0, 1}));
- // TODO: may be better to fail and use some vector<k> -> scalar reduction.
- if (layout({{m, k}, {n, k}, {m, n}})) {
- Value tlhs = t(lhs);
- return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
- }
- // No need to permute anything.
- if (layout({{k, m}, {k, n}, {m, n}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
- // Just permute the rhs.
- if (layout({{k, m}, {n, k}, {m, n}}))
- return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
- // Transposed output: swap RHS and LHS.
- // Classical row-major matmul: permute the lhs.
- if (layout({{m, k}, {k, n}, {n, m}}))
- return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
- // TODO: may be better to fail and use some vector<k> -> scalar reduction.
- if (layout({{m, k}, {n, k}, {n, m}})) {
- Value trhs = t(rhs);
- return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
- }
- if (layout({{k, m}, {k, n}, {n, m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
- if (layout({{k, m}, {n, k}, {n, m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
- return failure();
- }
-
- /// One outer parallel, one inner reduction (matvec flavor)
- FailureOr<Value> matvec() {
- if (!iters({Par(), Red()}))
- return failure();
- AffineExpr m, k;
- bindDims(rewriter.getContext(), m, k);
-
- // Case mat-vec: transpose.
- if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
- // Case mat-trans-vec: ready to go.
- if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
- // Case vec-mat: swap and transpose.
- if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
- // Case vec-mat-trans: swap and ready to go.
- if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
- return failure();
- }
-
- //
- // One outer reduction, one inner parallel (tmatvec flavor)
- //
- FailureOr<Value> tmatvec() {
- if (!iters({Red(), Par()}))
- return failure();
- AffineExpr k, m;
- bindDims(rewriter.getContext(), k, m);
-
- // Case mat-vec: transpose.
- if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
- // Case mat-trans-vec: ready to go.
- if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
- // Case vec-mat: swap and transpose.
- if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
- // Case vec-mat-trans: swap and ready to go.
- if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
- return failure();
- }
-
-private:
- vector::CombiningKind kind;
- Value lhs, rhs, res, mask;
- VectorType lhsType;
-};
-} // namespace
-
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to a reduction_size-unrolled sequence:
-/// ```
-/// %at = vector.transpose %a, [1, 0]
-/// %bRow0 = vector.extract %b[0]
-/// %atRow0 = vector.extract %at[0]
-/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
-/// ...
-/// %bRowK = vector.extract %b[K]
-/// %atRowK = vector.extract %at[K]
-/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
-/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
- vector::ContractionOp op, PatternRewriter &rewriter) const {
- // TODO: Remove native masks from contraction op?
- if (!op.getMasks().empty())
- return failure();
-
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::OuterProduct)
- return failure();
-
- if (failed(filter(op)))
- return failure();
-
- // Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
- Operation *rootOp;
- if (maskableOp.isMasked()) {
- rewriter.setInsertionPoint(maskableOp.getMaskingOp());
- rootOp = maskableOp.getMaskingOp();
- } else {
- rootOp = op;
- }
-
- UnrolledOuterProductGenerator e(rewriter, op);
- FailureOr<Value> matmatRes = e.matmat();
- if (succeeded(matmatRes)) {
- rewriter.replaceOp(rootOp, *matmatRes);
- return success();
- }
- FailureOr<Value> matvecRes = e.matvec();
- if (succeeded(matvecRes)) {
- rewriter.replaceOp(rootOp, *matvecRes);
- return success();
- }
- FailureOr<Value> tmatvecRes = e.tmatvec();
- if (succeeded(tmatvecRes)) {
- rewriter.replaceOp(rootOp, *tmatvecRes);
- return success();
- }
-
- return failure();
-}
-
-LogicalResult
-ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const {
- // TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
- // TODO: Remove native masks from contraction op?
- if (!op.getMasks().empty())
- return failure();
-
- if (failed(filter(op)))
- return failure();
-
- if (vectorTransformOptions.vectorContractLowering !=
- vector::VectorContractLowering::Dot)
- return failure();
-
- auto iteratorTypes = op.getIteratorTypes().getValue();
- static constexpr std::array<int64_t, 2> perm = {1, 0};
- Location loc = op.getLoc();
- Value lhs = op.getLhs(), rhs = op.getRhs();
-
- using MapList = ArrayRef<ArrayRef<AffineExpr>>;
- auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
- AffineExpr m, n, k;
- bindDims(rewriter.getContext(), m, n, k);
- SmallVector<AffineMap> maps = op.getIndexingMapsArray();
- //
- // In the following we wish to make the reduction dimension innermost so we
- // can load vectors and just fmul + reduce into a scalar.
- //
- if (isParallelIterator(iteratorTypes[0]) &&
- isParallelIterator(iteratorTypes[1]) &&
- isReductionIterator(iteratorTypes[2])) {
- //
- // Two outer parallel, one inner reduction (matmat flavor).
- //
- if (maps == infer({{m, k}, {k, n}, {m, n}})) {
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
- // No need to permute anything.
- } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
- // This is the classical row-major matmul. Just permute the lhs.
- Value tmp = lhs;
- lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- rhs = tmp;
- } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
- std::swap(lhs, rhs);
- } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
- Value tmp = lhs;
- lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
- } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
- Value tmp = rhs;
- rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- lhs = tmp;
- } else {
- return failure();
- }
- } else if (isParallelIterator(iteratorTypes[0]) &&
- isReductionIterator(iteratorTypes[1])) {
- //
- // One outer parallel, one inner reduction (matvec flavor)
- //
- if (maps == infer({{m, n}, {n}, {m}})) {
- // No need to permute anything.
- } else if (maps == infer({{n, m}, {n}, {m}})) {
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else if (maps == infer({{n}, {m, n}, {m}})) {
- std::swap(lhs, rhs);
- } else if (maps == infer({{n}, {n, m}, {m}})) {
- std::swap(lhs, rhs);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- } else {
- return failure();
- }
- } else {
- return failure();
- }
-
- VectorType dstType = op.getResultType().cast<VectorType>();
- assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
- "Expected dst type of rank 1 or 2");
-
- unsigned rank = dstType.getRank();
- unsigned dstRows = dstType.getShape()[0];
- unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
-
- // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
- Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
- rewriter.getZeroAttr(dstType));
- bool isInt = dstType.getElementType().isa<IntegerType>();
- for (unsigned r = 0; r < dstRows; ++r) {
- Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
- for (unsigned c = 0; c < dstColumns; ++c) {
- Value b = rank == 1
- ? rhs
- : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
- Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
- Value reduced = rewriter.create<vector::ReductionOp>(
- op.getLoc(), vector::CombiningKind::ADD, m);
-
- SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
- : SmallVector<int64_t, 2>{r, c};
- res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
- }
- }
- if (auto acc = op.getAcc())
- res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
- rewriter.replaceOp(op, res);
- return success();
-}
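For the matvec-flavored case, a hypothetical contraction of %lhs : vector<2x4xf32> with %rhs : vector<4xf32> into vector<2xf32> would unroll roughly as:

  %z  = arith.constant dense<0.000000e+00> : vector<2xf32>
  %a0 = vector.extract %lhs[0] : vector<2x4xf32>
  %m0 = arith.mulf %a0, %rhs : vector<4xf32>
  %r0 = vector.reduction <add>, %m0 : vector<4xf32> into f32
  %i0 = vector.insert %r0, %z [0] : f32 into vector<2xf32>
  // ... likewise for row 1, followed by an arith.addf with the accumulator if present (illustrative).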
-
-/// Progressive lowering of ContractionOp.
-/// One:
-/// %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-/// %a = vector.contract with one less free/batch dimension
-/// %b = vector.contract with one less free/batch dimension
-/// ..
-/// %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product.
-///
-/// This only kicks in when either VectorTransformsOptions is set
-/// to DOT or when other contraction patterns fail.
-//
-// TODO: break down into transpose/reshape/cast ops
-// when they become available to avoid code dup
-// TODO: investigate lowering order impact on performance
-LogicalResult
-ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const {
- // TODO: Remove native masks from contraction op?
- if (!op.getMasks().empty())
- return failure();
-
- if (failed(filter(op)))
- return failure();
-
- // TODO: support mixed mode contract lowering.
- if (op.getLhsType().getElementType() !=
- getElementTypeOrSelf(op.getAccType()) ||
- op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
- return failure();
-
- // TODO: the code below assumes the default contraction, make sure it supports
- // other kinds before enabling this lowering.
- if (op.getKind() != vector::CombiningKind::ADD) {
- return rewriter.notifyMatchFailure(
- op, "contractions other than 'add' not supported");
- }
-
- // TODO: implement benefits, cost models.
- MLIRContext *ctx = op.getContext();
- ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
- if (succeeded(pat1.matchAndRewrite(op, rewriter)))
- return success();
- ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
- if (succeeded(pat2.matchAndRewrite(op, rewriter)))
- return success();
- ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
- if (succeeded(pat3.matchAndRewrite(op, rewriter)))
- return success();
- ContractOpToElementwise pat4(vectorTransformOptions, ctx);
- if (succeeded(pat4.matchAndRewrite(op, rewriter)))
- return success();
-
- // Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- Operation *rootOp = op;
- Value mask;
- if (op.isMasked()) {
- rewriter.setInsertionPoint(op.getMaskingOp());
- rootOp = op.getMaskingOp();
- mask = op.getMaskingOp().getMask();
- }
-
- // Find first batch dimension in LHS/RHS, and lower when found.
- std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
- if (!batchDimMap.empty()) {
- int64_t lhsIndex = batchDimMap[0].first;
- int64_t rhsIndex = batchDimMap[0].second;
- auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
- if (failed(newOp))
- return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
- }
-
- // Collect contracting dimensions.
- std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
- op.getContractingDimMap();
- DenseSet<int64_t> lhsContractingDimSet;
- DenseSet<int64_t> rhsContractingDimSet;
- for (auto &dimPair : contractingDimMap) {
- lhsContractingDimSet.insert(dimPair.first);
- rhsContractingDimSet.insert(dimPair.second);
- }
-
- // Find first free dimension in LHS, and lower when found.
- VectorType lhsType = op.getLhsType();
- for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
- if (lhsContractingDimSet.count(lhsIndex) == 0) {
- auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
- if (failed(newOp))
- return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
- }
- }
-
- // Find first free dimension in RHS, and lower when found.
- VectorType rhsType = op.getRhsType();
- for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
- if (rhsContractingDimSet.count(rhsIndex) == 0) {
- auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
- if (failed(newOp))
- return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
- }
- }
-
- // Lower the first remaining reduction dimension.
- if (!contractingDimMap.empty()) {
- auto newOp = lowerReduction(rewriter, op, mask);
- if (failed(newOp))
- return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
- }
-
- return failure();
-}
-
-// Lower one parallel dimension.
-// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
-// TODO: consider reusing existing contract unrolling
-FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
- vector::ContractionOp op,
- int64_t lhsIndex,
- int64_t rhsIndex,
- Value mask) const {
- VectorType lhsType = op.getLhsType();
- VectorType rhsType = op.getRhsType();
- VectorType resType = op.getResultType().cast<VectorType>();
- // Find the iterator type index and result index.
- SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
- int64_t iterIndex = -1;
- int64_t dimSize = -1;
- if (lhsIndex >= 0) {
- iterIndex = iMap[0].getDimPosition(lhsIndex);
- if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
- << " to map to the same dimension";
- });
- dimSize = lhsType.getDimSize(lhsIndex);
- } else if (rhsIndex >= 0) {
- iterIndex = iMap[1].getDimPosition(rhsIndex);
- dimSize = rhsType.getDimSize(rhsIndex);
- }
- if (iterIndex < 0)
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << "expected either lhsIndex=" << lhsIndex
- << " or rhsIndex=" << rhsIndex << " to be nonnegative";
- });
- // value_or(-1) means that we tolerate a dimension not appearing
- // in the result map. That can't happen for actual parallel iterators, but
- // the caller ContractionOpLowering::matchAndRewrite is currently calling
- // lowerParallel also for the case of unit-size reduction dims appearing only
- // on one of LHS or RHS, not both. At the moment, such cases are created by
- // CastAwayContractionLeadingOneDim, so we need to either support that or
- // modify that pattern.
- int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
- if (resIndex == -1 && dimSize != 1)
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << "expected the dimension for iterIndex=" << iterIndex
- << " to either appear in the result map, or to be a unit dimension";
- });
-
- // Construct new iterator types and affine map array attribute.
- std::array<AffineMap, 3> lowIndexingMaps = {
- adjustMap(iMap[0], iterIndex, rewriter),
- adjustMap(iMap[1], iterIndex, rewriter),
- adjustMap(iMap[2], iterIndex, rewriter)};
- auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
- auto lowIter =
- rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
- // Unroll into a series of lower dimensional vector.contract ops.
- Location loc = op.getLoc();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
-
- for (int64_t d = 0; d < dimSize; ++d) {
- auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
- auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
- auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
-
- Value lowMask;
- if (mask)
- lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
- iterIndex, d, rewriter);
-
- Operation *lowContract = rewriter.create<vector::ContractionOp>(
- loc, lhs, rhs, acc, lowAffine, lowIter);
- lowContract = maskOperation(rewriter, lowContract, lowMask);
- result = reshapeStore(loc, lowContract->getResult(0), result, resType,
- resIndex, d, rewriter);
- }
- return result;
-}
-
-// Lower one reduction dimension.
-FailureOr<Value> ContractionOpLowering::lowerReduction(
- PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
- auto loc = op.getLoc();
- VectorType lhsType = op.getLhsType();
- VectorType rhsType = op.getRhsType();
- Type resType = op.getResultType();
- if (resType.isa<VectorType>())
- return rewriter.notifyMatchFailure(op,
- "did not expect a VectorType result");
- bool isInt = resType.isa<IntegerType>();
- // Use iterator index 0.
- int64_t iterIndex = 0;
- SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
- std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
- std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
- if (!lookupLhs.has_value())
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
- });
- if (!lookupRhs.has_value())
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
- });
- int64_t lhsIndex = *lookupLhs;
- int64_t rhsIndex = *lookupRhs;
- int64_t dimSize = lhsType.getDimSize(lhsIndex);
- if (dimSize != rhsType.getDimSize(rhsIndex))
- return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
- diag << "expect LHS dimension " << lhsIndex
- << " to have the same size as RHS dimension " << rhsIndex;
- });
- // Base case.
- if (lhsType.getRank() == 1) {
- if (rhsType.getRank() != 1)
- return rewriter.notifyMatchFailure(
- op, "When LHS has rank 1, expected also RHS to have rank 1");
- Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
- auto kind = vector::CombiningKind::ADD;
-
- Value acc = op.getAcc();
- Operation *reductionOp =
- acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
- : rewriter.create<vector::ReductionOp>(loc, kind, m);
- return maskOperation(rewriter, reductionOp, mask)->getResult(0);
- }
- // Construct new iterator types and affine map array attribute.
- std::array<AffineMap, 3> lowIndexingMaps = {
- adjustMap(iMap[0], iterIndex, rewriter),
- adjustMap(iMap[1], iterIndex, rewriter),
- adjustMap(iMap[2], iterIndex, rewriter)};
- auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
- auto lowIter =
- rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
- // Unroll into a series of lower dimensional vector.contract ops.
- // By feeding the initial accumulator into the first contraction,
- // and the result of each contraction into the next, eventually
- // the sum of all reductions is computed.
- Value result = op.getAcc();
- for (int64_t d = 0; d < dimSize; ++d) {
- auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
- auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
- Value newMask;
- if (mask)
- newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
- iterIndex, d, rewriter);
-
- Operation *newContract = rewriter.create<vector::ContractionOp>(
- loc, lhs, rhs, result, lowAffine, lowIter);
- result = maskOperation(rewriter, newContract, newMask)->getResult(0);
- }
- return result;
-}
-
-} // namespace mlir
-
-/// Progressive lowering of transfer_read. This pattern supports lowering of
-/// `vector.transfer_read` to a combination of `vector.load` and
-/// `vector.broadcast` if all of the following hold:
-/// - Stride of most minor memref dimension must be 1.
-/// - Out-of-bounds masking is not required.
-/// - If the memref's element type is a vector type then it coincides with the
-/// result type.
-/// - The permutation map doesn't perform permutation (broadcasting is allowed).
-struct TransferReadToVectorLoadLowering
- : public OpRewritePattern<vector::TransferReadOp> {
- TransferReadToVectorLoadLowering(MLIRContext *context,
- std::optional<unsigned> maxRank,
- PatternBenefit benefit = 1)
- : OpRewritePattern<vector::TransferReadOp>(context, benefit),
- maxTransferRank(maxRank) {}
-
- LogicalResult matchAndRewrite(vector::TransferReadOp read,
- PatternRewriter &rewriter) const override {
- if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
- return failure();
-
- SmallVector<unsigned> broadcastedDims;
- // Permutations are handled by VectorToSCF or
- // populateVectorTransferPermutationMapLoweringPatterns.
- // We let the 0-d corner case pass-through as it is supported.
- if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
- &broadcastedDims))
- return failure();
-
- auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
- if (!memRefType)
- return failure();
-
- // Non-unit strides are handled by VectorToSCF.
- if (!vector::isLastMemrefDimUnitStride(memRefType))
- return failure();
-
- // If there is broadcasting involved then we first load the unbroadcasted
- // vector, and then broadcast it with `vector.broadcast`.
- ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
- SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
- vectorShape.end());
- for (unsigned i : broadcastedDims)
- unbroadcastedVectorShape[i] = 1;
- VectorType unbroadcastedVectorType = VectorType::get(
- unbroadcastedVectorShape, read.getVectorType().getElementType());
-
- // `vector.load` supports vector types as memref's elements only when the
- // resulting vector type is the same as the element type.
- auto memrefElTy = memRefType.getElementType();
- if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
- return failure();
-
- // Otherwise, element types of the memref and the vector must match.
- if (!memrefElTy.isa<VectorType>() &&
- memrefElTy != read.getVectorType().getElementType())
- return failure();
-
- // Out-of-bounds dims are handled by MaterializeTransferMask.
- if (read.hasOutOfBoundsDim())
- return failure();
-
- // Create vector load op.
- Operation *loadOp;
- if (read.getMask()) {
- Value fill = rewriter.create<vector::SplatOp>(
- read.getLoc(), unbroadcastedVectorType, read.getPadding());
- loadOp = rewriter.create<vector::MaskedLoadOp>(
- read.getLoc(), unbroadcastedVectorType, read.getSource(),
- read.getIndices(), read.getMask(), fill);
- } else {
- loadOp = rewriter.create<vector::LoadOp>(
- read.getLoc(), unbroadcastedVectorType, read.getSource(),
- read.getIndices());
- }
-
- // Insert a broadcasting op if required.
- if (!broadcastedDims.empty()) {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- read, read.getVectorType(), loadOp->getResult(0));
- } else {
- rewriter.replaceOp(read, loadOp->getResult(0));
- }
-
- return success();
- }
-
- std::optional<unsigned> maxTransferRank;
-};
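In the simplest in-bounds, minor-identity case this amounts to something like the following (hypothetical shapes); when broadcast dimensions are present, the narrower vector is loaded first and then re-expanded with vector.broadcast:

  %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true]} : memref<8x16xf32>, vector<16xf32>
  // becomes (illustrative)
  %v = vector.load %mem[%i, %j] : memref<8x16xf32>, vector<16xf32>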
-
-/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
-// TODO: we shouldn't cross the vector/scalar domains just for this
-// but atm we lack the infra to avoid it. Possible solutions include:
-// - go directly to LLVM + bitcast
-// - introduce a bitcast op and likely a new pointer dialect
-// - let memref.load/store additionally support the 0-d vector case
-// There are still deeper data layout issues lingering even in this
-// trivial case (for architectures for which this matters).
-struct VectorLoadToMemrefLoadLowering
- : public OpRewritePattern<vector::LoadOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::LoadOp loadOp,
- PatternRewriter &rewriter) const override {
- auto vecType = loadOp.getVectorType();
- if (vecType.getNumElements() != 1)
- return failure();
- auto memrefLoad = rewriter.create<memref::LoadOp>(
- loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
- memrefLoad);
- return success();
- }
-};
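A single-element example of this rewrite, with hypothetical types:

  %v = vector.load %mem[%i] : memref<8xf32>, vector<1xf32>
  // becomes (illustrative)
  %s = memref.load %mem[%i] : memref<8xf32>
  %v = vector.broadcast %s : f32 to vector<1xf32>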
-
-/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
-struct VectorStoreToMemrefStoreLowering
- : public OpRewritePattern<vector::StoreOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::StoreOp storeOp,
- PatternRewriter &rewriter) const override {
- auto vecType = storeOp.getVectorType();
- if (vecType.getNumElements() != 1)
- return failure();
- Value extracted;
- if (vecType.getRank() == 0) {
- // TODO: Unify once ExtractOp supports 0-d vectors.
- extracted = rewriter.create<vector::ExtractElementOp>(
- storeOp.getLoc(), storeOp.getValueToStore());
- } else {
- SmallVector<int64_t> indices(vecType.getRank(), 0);
- extracted = rewriter.create<vector::ExtractOp>(
- storeOp.getLoc(), storeOp.getValueToStore(), indices);
- }
-
- rewriter.replaceOpWithNewOp<memref::StoreOp>(
- storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
- return success();
- }
-};
-
-/// Progressive lowering of transfer_write. This pattern supports lowering of
-/// `vector.transfer_write` to `vector.store` if all of the following hold:
-/// - Stride of most minor memref dimension must be 1.
-/// - Out-of-bounds masking is not required.
-/// - If the memref's element type is a vector type then it coincides with the
-/// type of the written value.
-/// - The permutation map is the minor identity map (neither permutation nor
-/// broadcasting is allowed).
-struct TransferWriteToVectorStoreLowering
- : public OpRewritePattern<vector::TransferWriteOp> {
- TransferWriteToVectorStoreLowering(MLIRContext *context,
- std::optional<unsigned> maxRank,
- PatternBenefit benefit = 1)
- : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
- maxTransferRank(maxRank) {}
-
- LogicalResult matchAndRewrite(vector::TransferWriteOp write,
- PatternRewriter &rewriter) const override {
- if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
- return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
- diag << "rank exceeds maxTransferRank: " << write;
- });
-
- // Permutations are handled by VectorToSCF or
- // populateVectorTransferPermutationMapLoweringPatterns.
- if ( // pass-through for the 0-d corner case.
- !write.getPermutationMap().isMinorIdentity())
- return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
- diag << "permutation map is not minor identity: " << write;
- });
-
- auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
- if (!memRefType)
- return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
- diag << "not a memref type: " << write;
- });
-
- // Non-unit strides are handled by VectorToSCF.
- if (!vector::isLastMemrefDimUnitStride(memRefType))
- return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
- diag << "most minor stride is not 1: " << write;
- });
-
- // `vector.store` supports vector types as memref's elements only when the
- // type of the vector value being written is the same as the element type.
- auto memrefElTy = memRefType.getElementType();
- if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
- return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
- diag << "elemental type mismatch: " << write;
- });
-
- // Otherwise, element types of the memref and the vector must match.
- if (!memrefElTy.isa<VectorType>() &&
- memrefElTy != write.getVectorType().getElementType())
- return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
- diag << "elemental type mismatch: " << write;
- });
-
- // Out-of-bounds dims are handled by MaterializeTransferMask.
- if (write.hasOutOfBoundsDim())
- return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
- diag << "out of bounds dim: " << write;
- });
- if (write.getMask()) {
- rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
- write, write.getSource(), write.getIndices(), write.getMask(),
- write.getVector());
- } else {
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- write, write.getVector(), write.getSource(), write.getIndices());
- }
- return success();
- }
-
- std::optional<unsigned> maxTransferRank;
-};
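For illustration only (not part of this commit's diff), a minimal sketch of the rewrite the removed TransferWriteToVectorStoreLowering pattern performed, assuming an in-bounds, unmasked, minor-identity 1-D transfer; all value names are hypothetical:

```mlir
// Before: a transfer_write that satisfies all of the conditions listed above.
vector.transfer_write %val, %mem[%i] {in_bounds = [true]}
    : vector<4xf32>, memref<8xf32>

// After: lowered to a plain vector.store (a masked write would instead
// become vector.maskedstore).
vector.store %val, %mem[%i] : memref<8xf32>, vector<4xf32>
```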
-
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(
@@ -2863,202 +1027,6 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
}
};
-namespace {
-
-/// This function checks to see if the vector combining kind
-/// is consistent with the integer or float element type.
-static bool isValidKind(bool isInt, vector::CombiningKind kind) {
- using vector::CombiningKind;
- enum class KindType { FLOAT, INT, INVALID };
- KindType type{KindType::INVALID};
- switch (kind) {
- case CombiningKind::MINF:
- case CombiningKind::MAXF:
- type = KindType::FLOAT;
- break;
- case CombiningKind::MINUI:
- case CombiningKind::MINSI:
- case CombiningKind::MAXUI:
- case CombiningKind::MAXSI:
- case CombiningKind::AND:
- case CombiningKind::OR:
- case CombiningKind::XOR:
- type = KindType::INT;
- break;
- case CombiningKind::ADD:
- case CombiningKind::MUL:
- type = isInt ? KindType::INT : KindType::FLOAT;
- break;
- }
- bool isValidIntKind = (type == KindType::INT) && isInt;
- bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
- return (isValidIntKind || isValidFloatKind);
-}
-
-/// This function constructs the appropriate integer or float
-/// operation given the vector combining kind and operands. The
-/// supported int operations are: add, mul, min (signed/unsigned),
-/// max (signed/unsigned), and, or, xor. The supported float
-/// operations are: add, mul, min and max.
-static Value genOperator(Location loc, Value x, Value y,
- vector::CombiningKind kind,
- PatternRewriter &rewriter) {
- using vector::CombiningKind;
-
- auto elType = x.getType().cast<VectorType>().getElementType();
- bool isInt = elType.isIntOrIndex();
-
- Value combinedResult{nullptr};
- switch (kind) {
- case CombiningKind::ADD:
- if (isInt)
- combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
- else
- combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
- break;
- case CombiningKind::MUL:
- if (isInt)
- combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
- else
- combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
- break;
- case CombiningKind::MINUI:
- combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
- break;
- case CombiningKind::MINSI:
- combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
- break;
- case CombiningKind::MAXUI:
- combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
- break;
- case CombiningKind::MAXSI:
- combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
- break;
- case CombiningKind::AND:
- combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
- break;
- case CombiningKind::OR:
- combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
- break;
- case CombiningKind::XOR:
- combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
- break;
- case CombiningKind::MINF:
- combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
- break;
- case CombiningKind::MAXF:
- combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
- break;
- }
- return combinedResult;
-}
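As a rough illustration (not part of the diff), the same `add` combining kind maps to a different arith op depending on the element type; the operands below are hypothetical:

```mlir
// Integer elements: combining kind <add> becomes arith.addi.
%si = arith.addi %x, %y : vector<2x1xi32>
// Float elements: combining kind <add> becomes arith.addf.
%sf = arith.addf %a, %b : vector<2x1xf32>
```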
-
-/// Convert vector.scan op into arith ops and
-/// vector.insert_strided_slice/extract_strided_slice.
-///
-/// Ex:
-/// ```
-///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 1} :
-///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
-/// ```
-/// Gets converted to:
-/// ```
-///   %cst = arith.constant dense<0> : vector<2x3xi32>
-///   %0 = vector.extract_strided_slice %arg0
-///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
-///     : vector<2x3xi32> to vector<2x1xi32>
-///   %1 = vector.insert_strided_slice %0, %cst
-///     {offsets = [0, 0], strides = [1, 1]}
-///     : vector<2x1xi32> into vector<2x3xi32>
-///   %2 = vector.extract_strided_slice %arg0
-///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
-///     : vector<2x3xi32> to vector<2x1xi32>
-///   %3 = arith.addi %0, %2 : vector<2x1xi32>
-///   %4 = vector.insert_strided_slice %3, %1
-///     {offsets = [0, 1], strides = [1, 1]}
-///     : vector<2x1xi32> into vector<2x3xi32>
-///   %5 = vector.extract_strided_slice %arg0
-///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
-///     : vector<2x3xi32> to vector<2x1xi32>
-///   %6 = arith.addi %3, %5 : vector<2x1xi32>
-///   %7 = vector.insert_strided_slice %6, %4
-///     {offsets = [0, 2], strides = [1, 1]}
-///     : vector<2x1xi32> into vector<2x3xi32>
-///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
-///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
-/// ```
-struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ScanOp scanOp,
- PatternRewriter &rewriter) const override {
- auto loc = scanOp.getLoc();
- VectorType destType = scanOp.getDestType();
- ArrayRef<int64_t> destShape = destType.getShape();
- auto elType = destType.getElementType();
- bool isInt = elType.isIntOrIndex();
- if (!isValidKind(isInt, scanOp.getKind()))
- return failure();
-
- VectorType resType = VectorType::get(destShape, elType);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
- int64_t reductionDim = scanOp.getReductionDim();
- bool inclusive = scanOp.getInclusive();
- int64_t destRank = destType.getRank();
- VectorType initialValueType = scanOp.getInitialValueType();
- int64_t initialValueRank = initialValueType.getRank();
-
- SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
- reductionShape[reductionDim] = 1;
- VectorType reductionType = VectorType::get(reductionShape, elType);
- SmallVector<int64_t> offsets(destRank, 0);
- SmallVector<int64_t> strides(destRank, 1);
- SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
- sizes[reductionDim] = 1;
- ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
- ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
-
- Value lastOutput, lastInput;
- for (int i = 0; i < destShape[reductionDim]; i++) {
- offsets[reductionDim] = i;
- ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
- Value input = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
- scanStrides);
- Value output;
- if (i == 0) {
- if (inclusive) {
- output = input;
- } else {
- if (initialValueRank == 0) {
- // ShapeCastOp cannot handle 0-D vectors
- output = rewriter.create<vector::BroadcastOp>(
- loc, input.getType(), scanOp.getInitialValue());
- } else {
- output = rewriter.create<vector::ShapeCastOp>(
- loc, input.getType(), scanOp.getInitialValue());
- }
- }
- } else {
- Value y = inclusive ? input : lastInput;
- output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
- assert(output != nullptr);
- }
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, output, result, offsets, strides);
- lastOutput = output;
- lastInput = input;
- }
-
- Value reduction;
- if (initialValueRank == 0) {
- Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
- reduction =
- rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
- } else {
- reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
- lastOutput);
- }
-
- rewriter.replaceOp(scanOp, {result, reduction});
- return success();
- }
-};
-
/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
@@ -3157,132 +1125,6 @@ struct CanonicalizeContractMatmulToMMT final
FilterConstraintType filter;
};
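For context (not part of the diff), a sketch of the MMT-shaped `vector.contract` that this canonicalization targets, i.e. a matmul whose RHS is indexed in transposed (n, k) order; shapes and value names are assumed:

```mlir
%res = vector.contract {
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (n, k)>,
                     affine_map<(m, n, k) -> (m, n)>],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>}
  %lhs, %rhs, %acc : vector<4x8xf32>, vector<16x8xf32> into vector<4x16xf32>
```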
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
-/// outermost dimension. For example:
-/// ```
-/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
-/// ... into vector<2x3xf32>
-///
-/// ==>
-///
-/// %0 = arith.constant dense<0.0> : vector<2x3xf32>
-/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
-/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
-/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
-/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
-/// ```
-///
-/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::GatherOp op,
- PatternRewriter &rewriter) const override {
- VectorType resultTy = op.getType();
- if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already flat");
-
- Location loc = op.getLoc();
- Value indexVec = op.getIndexVec();
- Value maskVec = op.getMask();
- Value passThruVec = op.getPassThru();
-
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultTy, rewriter.getZeroAttr(resultTy));
-
- Type subTy = VectorType::get(resultTy.getShape().drop_front(),
- resultTy.getElementType());
-
- for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
- int64_t thisIdx[1] = {i};
-
- Value indexSubVec =
- rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
- Value maskSubVec =
- rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
- Value passThruSubVec =
- rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
- Value subGather = rewriter.create<vector::GatherOp>(
- loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
- passThruSubVec);
- result =
- rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
- }
-
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
-/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
-/// loads/extracts are made conditional using `scf.if` ops.
-struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::GatherOp op,
- PatternRewriter &rewriter) const override {
- VectorType resultTy = op.getType();
- if (resultTy.getRank() != 1)
- return rewriter.notifyMatchFailure(op, "unsupported rank");
-
- Location loc = op.getLoc();
- Type elemTy = resultTy.getElementType();
- // Vector type with a single element. Used to generate `vector.loads`.
- VectorType elemVecTy = VectorType::get({1}, elemTy);
-
- Value condMask = op.getMask();
- Value base = op.getBase();
- Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
- loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
- op.getIndexVec());
- auto baseOffsets = llvm::to_vector(op.getIndices());
- Value lastBaseOffset = baseOffsets.back();
-
- Value result = op.getPassThru();
-
- // Emit a conditional access for each vector element.
- for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
- int64_t thisIdx[1] = {i};
- Value condition =
- rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
- Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
- baseOffsets.back() =
- rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
-
- auto loadBuilder = [&](OpBuilder &b, Location loc) {
- Value extracted;
- if (isa<MemRefType>(base.getType())) {
- // `vector.load` does not support scalar result; emit a vector load
- // and extract the single result instead.
- Value load =
- b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
- int64_t zeroIdx[1] = {0};
- extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
- } else {
- extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
- }
-
- Value newResult =
- b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
- b.create<scf::YieldOp>(loc, newResult);
- };
- auto passThruBuilder = [result](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, result);
- };
-
- result =
- rewriter
- .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
- /*elseBuilder=*/passThruBuilder)
- .getResult(0);
- }
-
- rewriter.replaceOp(op, result);
- return success();
- }
-};
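For illustration (not part of the diff), a rough sketch of what the removed Gather1DToConditionalLoads pattern emitted for one element of a hypothetical 1-D gather over a memref; the conditional access repeats for each element, threading the partial result through:

```mlir
// Hypothetical 1-D gather.
%g = vector.gather %base[%c0][%idxs], %mask, %pass
    : memref<16xf32>, vector<2xi32>, vector<2xi1>, vector<2xf32> into vector<2xf32>

// Roughly becomes, for element 0 (and analogously for element 1):
%cast = arith.index_cast %idxs : vector<2xi32> to vector<2xindex>
%m0 = vector.extract %mask[0] : vector<2xi1>
%i0 = vector.extract %cast[0] : vector<2xindex>
%off0 = arith.addi %c0, %i0 : index
%r0 = scf.if %m0 -> (vector<2xf32>) {
  // vector.load has no scalar form, so load a 1-element vector and extract.
  %ld = vector.load %base[%off0] : memref<16xf32>, vector<1xf32>
  %e = vector.extract %ld[0] : vector<1xf32>
  %ins = vector.insert %e, %pass [0] : f32 into vector<2xf32>
  scf.yield %ins : vector<2xf32>
} else {
  scf.yield %pass : vector<2xf32>
}
```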
-
} // namespace
void mlir::vector::populateVectorMaskMaterializationPatterns(
@@ -3307,33 +1149,6 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
benefit);
}
-void mlir::vector::populateVectorBroadcastLoweringPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<BroadcastOpLowering>(patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorMaskOpLoweringPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
- patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorShapeCastLoweringPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
- patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions options,
- PatternBenefit benefit) {
- patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
- patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
- ContractionOpToOuterProductOpLowering>(
- options, patterns.getContext(), benefit);
-}
-
void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
RewritePatternSet &patterns,
std::function<LogicalResult(vector::ContractionOp)> constraint,
@@ -3342,13 +1157,6 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
std::move(constraint));
}
-void mlir::vector::populateVectorTransposeLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions options,
- PatternBenefit benefit) {
- patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
- options, patterns.getContext(), benefit);
-}
-
void mlir::vector::populateVectorReductionToContractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
@@ -3363,28 +1171,6 @@ void mlir::vector::
patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
}
-void mlir::vector::populateVectorTransferLoweringPatterns(
- RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
- PatternBenefit benefit) {
- patterns.add<TransferReadToVectorLoadLowering,
- TransferWriteToVectorStoreLowering>(patterns.getContext(),
- maxTransferRank, benefit);
- patterns
- .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
- patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorScanLoweringPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
-}
-
-void mlir::vector::populateVectorGatherLoweringPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
- benefit);
-}
-
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f79ca2259fa38..7a4f9cf5e5101 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -148,8 +149,9 @@ struct TestVectorContractionLowering
if (lowerToOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
- patterns.add<ContractionOpToOuterProductOpLowering>(options,
- &getContext());
+ populateVectorContractLoweringPatterns(
+ patterns, options, /*benefit=*/1,
+ /*disableOuterProductLowering=*/true);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
@@ -469,7 +471,7 @@ struct TestVectorTransferFullPartialSplitPatterns
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
else
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
- patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
+ populateVectorTransferFullPartialPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8538c3db59dcd..f565030d63d9f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8539,6 +8539,7 @@ cc_library(
":TransformDialect",
":TransformDialectUtils",
":TransformUtils",
+ ":VectorTransforms",
"//llvm:Support",
],
)