[Mlir-commits] [mlir] 8e4c806 - [mlir][Linalg] NFC - Add additional control to lower vector.shape_cast ops
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Oct 27 01:14:20 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-27T08:12:57Z
New Revision: 8e4c806ed5a481e4d2163c8330f3c3c024d61a36
URL: https://github.com/llvm/llvm-project/commit/8e4c806ed5a481e4d2163c8330f3c3c024d61a36
DIFF: https://github.com/llvm/llvm-project/commit/8e4c806ed5a481e4d2163c8330f3c3c024d61a36.diff
LOG: [mlir][Linalg] NFC - Add additional control to lower vector.shape_cast ops
This also moves the vector.transfer permutation map rewrite patterns into a new patterns file.
Differential Revision: https://reviews.llvm.org/D112575
Added:
mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Vector/CMakeLists.txt
mlir/lib/Dialect/Vector/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index cfa38d71c2ba3..27688b5492530 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -877,49 +877,63 @@ struct LinalgEnablingOptions {
/// Vector lowering options control how ops are lowered down to 1-D and scf.for
/// form.
struct LinalgVectorLoweringOptions {
- /// Maximal transfer rank under which we do not lower further.
- int64_t maxTransferRank = 1;
- LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
- maxTransferRank = val;
- return *this;
- }
- /// Vector lowering operations may result in surprising behavior when
- /// composing multiple codegen strategies and must be enabled explicitly.
- bool transferLowering = true;
- LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
- transferLowering = val;
- return *this;
- }
- /// Enable lowering of vector.transpose.
- bool transposeLowering = false;
- LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
- transposeLowering = val;
+ /// Enable lowering of vector.contract.
+ /// In a progressive lowering of vectors, this would be the 1st step.
+ bool contractionLowering = false;
+ LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
+ contractionLowering = val;
return *this;
}
/// Enable lowering of vector.multi_reduce.
+ /// In a progressive lowering of vectors, this would be the 2nd step.
bool multiReductionLowering = false;
LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
multiReductionLowering = val;
return *this;
}
- /// Enable lowering of vector.contract.
- bool contractionLowering = false;
- LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
- contractionLowering = val;
- return *this;
- }
/// Trigger full / partial vector.transfer splits.
+ /// In a progressive lowering of vectors, this would be the 3rd step.
bool transferPartialRewrite = false;
LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
transferPartialRewrite = val;
return *this;
}
/// Enable lowering of vector.transfer to scf.
+ /// In a progressive lowering of vectors, this would be the 4th step.
bool transferToSCFConversion = false;
LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
transferToSCFConversion = val;
return *this;
}
+ /// Maximal transfer rank under which we do not lower further.
+ int64_t maxTransferRank = 1;
+ LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
+ maxTransferRank = val;
+ return *this;
+ }
+ /// Vector lowering operations may result in surprising behavior when
+ /// composing multiple codegen strategies and must be enabled explicitly.
+ /// In a progressive lowering of vectors, this would be the 5th step.
+ bool transferLowering = true;
+ LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
+ transferLowering = val;
+ return *this;
+ }
+ /// Enable lowering of vector.shape_cast to insert/extract.
+ /// In a progressive lowering of vectors, this would be the 6th step.
+ bool shapeCastLowering = true;
+ LinalgVectorLoweringOptions &enableShapeCastLowering(bool val = true) {
+ shapeCastLowering = val;
+ return *this;
+ }
+ /// Enable lowering of vector.transpose.
+ /// In a progressive lowering of vectors, this would be the 7th step.
+ bool transposeLowering = false;
+ LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
+ transposeLowering = val;
+ return *this;
+ }
+
/// Configure the post staged-patterns late vector.transfer to scf
/// conversion.
VectorTransferToSCFOptions vectorTransferToSCFOptions;
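For illustration, a minimal sketch of how a client might configure these options: the option and setter names come from the struct above, each setter returns *this so the calls chain, and the surrounding strategy code that would consume `options` is not shown and is assumed to exist.

  // Hypothetical configuration sketch: enable every stage, in the order the
  // members are listed above. Only the setter names are taken from this
  // commit; the chosen values are illustrative.
  LinalgVectorLoweringOptions options =
      LinalgVectorLoweringOptions()
          .enableContractionLowering()
          .enableMultiReductionLowering()
          .enableTransferPartialRewrite()
          .enableTransferToSCFConversion()
          .setMaxTransferRank(1)
          .enableTransferLowering()
          .enableShapeCastLowering()
          .enableVectorTransposeLowering();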
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index c6f4ba4bc0e59..a29683431344f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -81,12 +81,6 @@ void populateVectorTransferLoweringPatterns(
RewritePatternSet &patterns,
llvm::Optional<unsigned> maxTransferRank = llvm::None);
-/// Collect a set of transfer read/write lowering patterns that simplify the
-/// permutation map (e.g., converting it to a minor identity map) by inserting
-/// broadcasts and transposes.
-void populateVectorTransferPermutationMapLoweringPatterns(
- RewritePatternSet &patterns);
-
/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);
diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
index 47375c56673f8..587f334bc0473 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -169,6 +169,64 @@ void populateVectorContractLoweringPatterns(
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
+//===----------------------------------------------------------------------===//
+// Vector.transfer patterns.
+//===----------------------------------------------------------------------===//
+/// Collect a set of transfer read/write lowering patterns that simplify the
+/// permutation map (e.g., converting it to a minor identity map) by inserting
+/// broadcasts and transposes. More specifically:
+///
+/// [TransferReadPermutationLowering]
+/// Lower transfer_read op with permutation into a transfer_read with a
+/// permutation map composed of leading zeros followed by a minor identity +
+/// vector.transpose op.
+/// Ex:
+/// vector.transfer_read ...
+/// permutation_map: (d0, d1, d2) -> (0, d1)
+/// into:
+/// %v = vector.transfer_read ...
+/// permutation_map: (d0, d1, d2) -> (d1, 0)
+/// vector.transpose %v, [1, 0]
+///
+/// vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
+/// into:
+/// %v = vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
+/// vector.transpose %v, [0, 1, 3, 2, 4]
+/// Note that an alternative is to transform it to linalg.transpose +
+/// vector.transfer_read to do the transpose in memory instead.
+///
+/// [TransferWritePermutationLowering]
+/// Lower transfer_write op with permutation into a transfer_write with a
+/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
+/// Ex:
+/// vector.transfer_write %v ...
+/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
+/// into:
+/// %tmp = vector.transpose %v, [2, 0, 1]
+/// vector.transfer_write %tmp ...
+/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+///
+/// vector.transfer_write %v ...
+/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
+/// into:
+/// %tmp = vector.transpose %v, [1, 0]
+/// %v = vector.transfer_write %tmp ...
+/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+///
+/// [TransferOpReduceRank]
+/// Lower transfer_read op with broadcast in the leading dimensions into
+/// transfer_read of lower rank + vector.broadcast.
+/// Ex: vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
+/// into:
+/// %v = vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
+/// vector.broadcast %v
+void populateVectorTransferPermutationMapLoweringPatterns(
+ RewritePatternSet &patterns);
+
/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops to operate on the largest contiguous vector.
/// These patterns are useful when lowering to dialects with 1d vector type
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 1aacd4919c8a3..2fd4959884981 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -20,8 +20,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 97831866ffd08..4462bbe2147a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -264,34 +264,44 @@ struct LinalgStrategyLowerVectorsPass
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
- if (options.transferLowering) {
- vector::populateVectorTransferLoweringPatterns(patterns,
- options.maxTransferRank);
- }
- if (options.transposeLowering) {
- vector::populateVectorTransposeLoweringPatterns(
- patterns, options.vectorTransformOptions);
- }
- if (options.multiReductionLowering) {
- vector::populateVectorMultiReductionLoweringPatterns(
- patterns,
- options.vectorTransformOptions.vectorMultiReductionLowering);
- }
+ // In a progressive lowering of vectors, this would be the 1st step.
if (options.contractionLowering) {
patterns.add<ContractionOpToOuterProductOpLowering,
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
options.vectorTransformOptions, context);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
+ // In a progressive lowering of vectors, this would be the 2nd step.
+ if (options.multiReductionLowering) {
+ vector::populateVectorMultiReductionLoweringPatterns(
+ patterns,
+ options.vectorTransformOptions.vectorMultiReductionLowering);
+ }
+ // In a progressive lowering of vectors, this would be the 3rd step.
if (options.transferPartialRewrite) {
patterns.add<vector::VectorTransferFullPartialRewriter>(
context, options.vectorTransformOptions);
}
+ // In a progressive lowering of vectors, this would be the 4th step.
+ if (options.transferLowering) {
+ vector::populateVectorTransferLoweringPatterns(patterns,
+ options.maxTransferRank);
+ }
+ // In a progressive lowering of vectors, this would be the 5th step.
if (options.transferToSCFConversion) {
- populateVectorToSCFConversionPatterns(patterns,
- options.vectorTransferToSCFOptions);
+ populateVectorToSCFConversionPatterns(
+ patterns, options.vectorTransferToSCFOptions.setTargetRank(
+ options.maxTransferRank));
+ }
+ // In a progressive lowering of vectors, this would be the 6th step.
+ if (options.shapeCastLowering) {
+ vector::populateVectorShapeCastLoweringPatterns(patterns);
+ }
+ // In a progressive lowering of vectors, this would be the 7th step.
+ if (options.transposeLowering) {
+ vector::populateVectorTransposeLoweringPatterns(
+ patterns, options.vectorTransformOptions);
}
- vector::populateVectorShapeCastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index f620a370c8359..abd961626e599 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVector
VectorMultiDimReductionTransforms.cpp
VectorOps.cpp
VectorTransferOpTransforms.cpp
+ VectorTransferPermutationMapRewritePatterns.cpp
VectorTransforms.cpp
VectorUtils.cpp
diff --git a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
new file mode 100644
index 0000000000000..3f5c3127a286c
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
@@ -0,0 +1,260 @@
+//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns for the permutation_map attribute of
+// vector.transfer operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Transpose a vector transfer op's `in_bounds` attribute according to given
+/// indices.
+static ArrayAttr
+transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+ const SmallVector<unsigned> &permutation) {
+ SmallVector<bool> newInBoundsValues;
+ for (unsigned pos : permutation)
+ newInBoundsValues.push_back(
+ attr.getValue()[pos].cast<BoolAttr>().getValue());
+ return builder.getBoolArrayAttr(newInBoundsValues);
+}
+/// Lower transfer_read op with permutation into a transfer_read with a
+/// permutation map composed of leading zeros followed by a minor identity +
+/// vector.transpose op.
+/// Ex:
+/// vector.transfer_read ...
+/// permutation_map: (d0, d1, d2) -> (0, d1)
+/// into:
+/// %v = vector.transfer_read ...
+/// permutation_map: (d0, d1, d2) -> (d1, 0)
+/// vector.transpose %v, [1, 0]
+///
+/// vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
+/// into:
+/// %v = vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
+/// vector.transpose %v, [0, 1, 3, 2, 4]
+/// Note that an alternative is to transform it to linalg.transpose +
+/// vector.transfer_read to do the transpose in memory instead.
+struct TransferReadPermutationLowering
+ : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<unsigned> permutation;
+ AffineMap map = op.permutation_map();
+ if (map.getNumResults() == 0)
+ return failure();
+ if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+ return failure();
+ AffineMap permutationMap =
+ map.getPermutationMap(permutation, op.getContext());
+ if (permutationMap.isIdentity())
+ return failure();
+
+ permutationMap = map.getPermutationMap(permutation, op.getContext());
+ // Calculate the map of the new read by applying the inverse permutation.
+ permutationMap = inversePermutation(permutationMap);
+ AffineMap newMap = permutationMap.compose(map);
+ // Apply the reverse transpose to deduce the type of the transfer_read.
+ ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
+ SmallVector<int64_t> newVectorShape(originalShape.size());
+ for (auto pos : llvm::enumerate(permutation)) {
+ newVectorShape[pos.value()] = originalShape[pos.index()];
+ }
+
+ // Transpose mask operand.
+ Value newMask;
+ if (op.mask()) {
+ // Remove unused dims from the permutation map, e.g.:
+ // (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
+ // comp = (d0, d1, d2) -> (d2, 0, d1, 0, d0)
+ auto comp = compressUnusedDims(map);
+ // Get positions of remaining result dims.
+ // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0, d0)
+ // maskTransposeIndices = [2, 1, 0]
+ SmallVector<int64_t> maskTransposeIndices;
+ for (unsigned i = 0; i < comp.getNumResults(); ++i) {
+ if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
+ maskTransposeIndices.push_back(expr.getPosition());
+ }
+
+ newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(),
+ maskTransposeIndices);
+ }
+
+ // Transpose in_bounds attribute.
+ ArrayAttr newInBounds =
+ op.in_bounds() ? transposeInBoundsAttr(
+ rewriter, op.in_bounds().getValue(), permutation)
+ : ArrayAttr();
+
+ // Generate new transfer_read operation.
+ VectorType newReadType =
+ VectorType::get(newVectorShape, op.getVectorType().getElementType());
+ Value newRead = rewriter.create<vector::TransferReadOp>(
+ op.getLoc(), newReadType, op.source(), op.indices(), newMap,
+ op.padding(), newMask, newInBounds);
+
+ // Transpose result of transfer_read.
+ SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
+ rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
+ transposePerm);
+ return success();
+ }
+};
+
+/// Lower transfer_write op with permutation into a transfer_write with a
+/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
+/// Ex:
+/// vector.transfer_write %v ...
+/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
+/// into:
+/// %tmp = vector.transpose %v, [2, 0, 1]
+/// vector.transfer_write %tmp ...
+/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+///
+/// vector.transfer_write %v ...
+/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
+/// into:
+/// %tmp = vector.transpose %v, [1, 0]
+/// %v = vector.transfer_write %tmp ...
+/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+struct TransferWritePermutationLowering
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.isZeroD())
+ return failure();
+
+ SmallVector<unsigned> permutation;
+ AffineMap map = op.permutation_map();
+ if (map.isMinorIdentity())
+ return failure();
+ if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+ return failure();
+
+ // Remove unused dims from the permutation map, e.g.:
+ // (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
+ // comp = (d0, d1, d2) -> (d2, d0, d1)
+ auto comp = compressUnusedDims(map);
+ // Get positions of remaining result dims.
+ SmallVector<int64_t> indices;
+ llvm::transform(comp.getResults(), std::back_inserter(indices),
+ [](AffineExpr expr) {
+ return expr.dyn_cast<AffineDimExpr>().getPosition();
+ });
+
+ // Transpose mask operand.
+ Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
+ op.getLoc(), op.mask(), indices)
+ : Value();
+
+ // Transpose in_bounds attribute.
+ ArrayAttr newInBounds =
+ op.in_bounds() ? transposeInBoundsAttr(
+ rewriter, op.in_bounds().getValue(), permutation)
+ : ArrayAttr();
+
+ // Generate new transfer_write operation.
+ Value newVec =
+ rewriter.create<vector::TransposeOp>(op.getLoc(), op.vector(), indices);
+ auto newMap = AffineMap::getMinorIdentityMap(
+ map.getNumDims(), map.getNumResults(), rewriter.getContext());
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ op, Type(), newVec, op.source(), op.indices(), newMap, newMask,
+ newInBounds);
+
+ return success();
+ }
+};
+
+/// Lower transfer_read op with broadcast in the leading dimensions into
+/// transfer_read of lower rank + vector.broadcast.
+/// Ex: vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
+/// into:
+/// %v = vector.transfer_read ...
+/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
+/// vector.broadcast %v
+struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp op,
+ PatternRewriter &rewriter) const override {
+ AffineMap map = op.permutation_map();
+ unsigned numLeadingBroadcast = 0;
+ for (auto expr : map.getResults()) {
+ auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
+ if (!dimExpr || dimExpr.getValue() != 0)
+ break;
+ numLeadingBroadcast++;
+ }
+ // If there are no leading zeros in the map there is nothing to do.
+ if (numLeadingBroadcast == 0)
+ return failure();
+ VectorType originalVecType = op.getVectorType();
+ unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
+ // Calculate new map, vector type and masks without the leading zeros.
+ AffineMap newMap = AffineMap::get(
+ map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
+ op.getContext());
+ // Only remove the leading zeros if the rest of the map is a minor identity
+ // with broadcasting. Otherwise we first want to permute the map.
+ if (!newMap.isMinorIdentityWithBroadcasting())
+ return failure();
+
+ // TODO: support zero-dimension vectors natively. See:
+ // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
+ // In the meantime, lower these to a scalar load when they pop up.
+ if (reducedShapeRank == 0) {
+ Value newRead = rewriter.create<memref::LoadOp>(
+ op.getLoc(), originalVecType.getElementType(), op.source(),
+ op.indices());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
+ newRead);
+ return success();
+ }
+ SmallVector<int64_t> newShape = llvm::to_vector<4>(
+ originalVecType.getShape().take_back(reducedShapeRank));
+ // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
+ if (newShape.empty())
+ return failure();
+ VectorType newReadType =
+ VectorType::get(newShape, originalVecType.getElementType());
+ ArrayAttr newInBounds =
+ op.in_bounds()
+ ? rewriter.getArrayAttr(
+ op.in_boundsAttr().getValue().take_back(reducedShapeRank))
+ : ArrayAttr();
+ Value newRead = rewriter.create<vector::TransferReadOp>(
+ op.getLoc(), newReadType, op.source(), op.indices(), newMap,
+ op.padding(), op.mask(), newInBounds);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
+ newRead);
+ return success();
+ }
+};
+
+void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<TransferReadPermutationLowering,
+ TransferWritePermutationLowering, TransferOpReduceRank>(
+ patterns.getContext());
+}
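As a usage sketch, mirroring the call sequence in LinalgStrategyLowerVectorsPass above, these patterns could be applied on their own roughly as follows; `funcOp` stands for a FuncOp obtained elsewhere and is an assumption, not part of this file, and the greedy rewrite driver header is assumed to be included.

  // Hypothetical standalone driver for the permutation-map rewrites only.
  static LogicalResult lowerTransferPermutationMaps(FuncOp funcOp) {
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
    // Apply the collected patterns greedily until a fixed point is reached.
    return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  }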
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 4c7ef516fd927..efb22b9ab6be6 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2882,240 +2882,6 @@ struct TransferWriteToVectorStoreLowering
llvm::Optional<unsigned> maxTransferRank;
};
-/// Transpose a vector transfer op's `in_bounds` attribute according to given
-/// indices.
-static ArrayAttr
-transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
- const SmallVector<unsigned> &permutation) {
- SmallVector<bool> newInBoundsValues;
- for (unsigned pos : permutation)
- newInBoundsValues.push_back(
- attr.getValue()[pos].cast<BoolAttr>().getValue());
- return builder.getBoolArrayAttr(newInBoundsValues);
-}
-
-/// Lower transfer_read op with permutation into a transfer_read with a
-/// permutation map composed of leading zeros followed by a minor identiy +
-/// vector.transpose op.
-/// Ex:
-/// vector.transfer_read ...
-/// permutation_map: (d0, d1, d2) -> (0, d1)
-/// into:
-/// %v = vector.transfer_read ...
-/// permutation_map: (d0, d1, d2) -> (d1, 0)
-/// vector.transpose %v, [1, 0]
-///
-/// vector.transfer_read ...
-/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
-/// into:
-/// %v = vector.transfer_read ...
-/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
-/// vector.transpose %v, [0, 1, 3, 2, 4]
-/// Note that an alternative is to transform it to linalg.transpose +
-/// vector.transfer_read to do the transpose in memory instead.
-struct TransferReadPermutationLowering
- : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::TransferReadOp op,
- PatternRewriter &rewriter) const override {
- SmallVector<unsigned> permutation;
- AffineMap map = op.permutation_map();
- if (map.getNumResults() == 0)
- return failure();
- if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
- return failure();
- AffineMap permutationMap =
- map.getPermutationMap(permutation, op.getContext());
- if (permutationMap.isIdentity())
- return failure();
-
- permutationMap = map.getPermutationMap(permutation, op.getContext());
- // Caluclate the map of the new read by applying the inverse permutation.
- permutationMap = inversePermutation(permutationMap);
- AffineMap newMap = permutationMap.compose(map);
- // Apply the reverse transpose to deduce the type of the transfer_read.
- ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
- SmallVector<int64_t> newVectorShape(originalShape.size());
- for (auto pos : llvm::enumerate(permutation)) {
- newVectorShape[pos.value()] = originalShape[pos.index()];
- }
-
- // Transpose mask operand.
- Value newMask;
- if (op.mask()) {
- // Remove unused dims from the permutation map. E.g.:
- // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
- // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
- auto comp = compressUnusedDims(map);
- // Get positions of remaining result dims.
- // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0)
- // maskTransposeIndices = [ 2, 1, 0]
- SmallVector<int64_t> maskTransposeIndices;
- for (unsigned i = 0; i < comp.getNumResults(); ++i) {
- if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
- maskTransposeIndices.push_back(expr.getPosition());
- }
-
- newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(),
- maskTransposeIndices);
- }
-
- // Transpose in_bounds attribute.
- ArrayAttr newInBounds =
- op.in_bounds() ? transposeInBoundsAttr(
- rewriter, op.in_bounds().getValue(), permutation)
- : ArrayAttr();
-
- // Generate new transfer_read operation.
- VectorType newReadType =
- VectorType::get(newVectorShape, op.getVectorType().getElementType());
- Value newRead = rewriter.create<vector::TransferReadOp>(
- op.getLoc(), newReadType, op.source(), op.indices(), newMap,
- op.padding(), newMask, newInBounds);
-
- // Transpose result of transfer_read.
- SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
- transposePerm);
- return success();
- }
-};
-
-/// Lower transfer_write op with permutation into a transfer_write with a
-/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
-/// Ex:
-/// vector.transfer_write %v ...
-/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
-/// into:
-/// %tmp = vector.transpose %v, [2, 0, 1]
-/// vector.transfer_write %tmp ...
-/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
-///
-/// vector.transfer_write %v ...
-/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
-/// into:
-/// %tmp = vector.transpose %v, [1, 0]
-/// %v = vector.transfer_write %tmp ...
-/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
-struct TransferWritePermutationLowering
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::TransferWriteOp op,
- PatternRewriter &rewriter) const override {
- if (op.isZeroD())
- return failure();
-
- SmallVector<unsigned> permutation;
- AffineMap map = op.permutation_map();
- if (map.isMinorIdentity())
- return failure();
- if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
- return failure();
-
- // Remove unused dims from the permutation map. E.g.:
- // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
- // comp = (d0, d1, d2) -> (d2, d0, d1)
- auto comp = compressUnusedDims(map);
- // Get positions of remaining result dims.
- SmallVector<int64_t> indices;
- llvm::transform(comp.getResults(), std::back_inserter(indices),
- [](AffineExpr expr) {
- return expr.dyn_cast<AffineDimExpr>().getPosition();
- });
-
- // Transpose mask operand.
- Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
- op.getLoc(), op.mask(), indices)
- : Value();
-
- // Transpose in_bounds attribute.
- ArrayAttr newInBounds =
- op.in_bounds() ? transposeInBoundsAttr(
- rewriter, op.in_bounds().getValue(), permutation)
- : ArrayAttr();
-
- // Generate new transfer_write operation.
- Value newVec =
- rewriter.create<vector::TransposeOp>(op.getLoc(), op.vector(), indices);
- auto newMap = AffineMap::getMinorIdentityMap(
- map.getNumDims(), map.getNumResults(), rewriter.getContext());
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, Type(), newVec, op.source(), op.indices(), newMap, newMask,
- newInBounds);
-
- return success();
- }
-};
-
-/// Lower transfer_read op with broadcast in the leading dimensions into
-/// transfer_read of lower rank + vector.broadcast.
-/// Ex: vector.transfer_read ...
-/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
-/// into:
-/// %v = vector.transfer_read ...
-/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
-/// vector.broadcast %v
-struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::TransferReadOp op,
- PatternRewriter &rewriter) const override {
- AffineMap map = op.permutation_map();
- unsigned numLeadingBroadcast = 0;
- for (auto expr : map.getResults()) {
- auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
- if (!dimExpr || dimExpr.getValue() != 0)
- break;
- numLeadingBroadcast++;
- }
- // If there are no leading zeros in the map there is nothing to do.
- if (numLeadingBroadcast == 0)
- return failure();
- VectorType originalVecType = op.getVectorType();
- unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
- // Calculate new map, vector type and masks without the leading zeros.
- AffineMap newMap = AffineMap::get(
- map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
- op.getContext());
- // Only remove the leading zeros if the rest of the map is a minor identity
- // with broadasting. Otherwise we first want to permute the map.
- if (!newMap.isMinorIdentityWithBroadcasting())
- return failure();
-
- // TODO: support zero-dimension vectors natively. See:
- // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
- // In the meantime, lower these to a scalar load when they pop up.
- if (reducedShapeRank == 0) {
- Value newRead = rewriter.create<memref::LoadOp>(
- op.getLoc(), originalVecType.getElementType(), op.source(),
- op.indices());
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
- newRead);
- return success();
- }
- SmallVector<int64_t> newShape = llvm::to_vector<4>(
- originalVecType.getShape().take_back(reducedShapeRank));
- // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
- if (newShape.empty())
- return failure();
- VectorType newReadType =
- VectorType::get(newShape, originalVecType.getElementType());
- ArrayAttr newInBounds =
- op.in_bounds()
- ? rewriter.getArrayAttr(
- op.in_boundsAttr().getValue().take_back(reducedShapeRank))
- : ArrayAttr();
- Value newRead = rewriter.create<vector::TransferReadOp>(
- op.getLoc(), newReadType, op.source(), op.indices(), newMap,
- op.padding(), op.mask(), newInBounds);
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
- newRead);
- return success();
- }
-};
-
// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
static VectorType trimLeadingOneDims(VectorType oldType) {
@@ -3891,23 +3657,6 @@ void mlir::vector::populateVectorReductionToContractPatterns(
CombineContractTranspose>(patterns.getContext());
}
-void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
- RewritePatternSet &patterns) {
- patterns.add<TransferReadPermutationLowering,
- TransferWritePermutationLowering, TransferOpReduceRank>(
- patterns.getContext());
-}
-
-void mlir::vector::populateVectorTransferLoweringPatterns(
- RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
- patterns.add<TransferReadToVectorLoadLowering,
- TransferWriteToVectorStoreLowering>(patterns.getContext(),
- maxTransferRank);
- patterns
- .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
- patterns.getContext());
-}
-
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options) {
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
@@ -3920,3 +3669,13 @@ void mlir::vector::
RewritePatternSet &patterns) {
patterns.add<DropInnerMostUnitDims>(patterns.getContext());
}
+
+void mlir::vector::populateVectorTransferLoweringPatterns(
+ RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
+ patterns.add<TransferReadToVectorLoadLowering,
+ TransferWriteToVectorStoreLowering>(patterns.getContext(),
+ maxTransferRank);
+ patterns
+ .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+ patterns.getContext());
+}