[Mlir-commits] [mlir] [mlir][vector] Better transfer_read(transfer_write) canonicalization (PR #72617)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 16 23:38:29 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This change improves the canonicalization of `transfer_read(transfer_write)` IR patterns where the two transfer ops access the same chunk of the shaped value (store-load forwarding). The existing rewrite pattern did not support cases where the two transfer ops operate on vectors of different rank (i.e., different rank-reduced/extended unit dims).
The previous pattern generated a combination of `vector.transpose` and `vector.broadcast`. The new pattern generates a combination of `vector.transpose`, `vector.broadcast` and `vector.extract`. In cases where no `vector.extract` is needed, other canonicalization patterns/foldings simplify the IR such that the same IR as with the previous pattern is produced.
Depends on #<!-- -->72594 and #<!-- -->72616. Review only the top commit.
---
Patch is 37.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72617.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+17-9)
- (modified) mlir/include/mlir/IR/AffineMap.h (+2)
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+2-5)
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-6)
- (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+255-96)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+1-3)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+9-11)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+1-2)
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-5)
- (modified) mlir/lib/IR/AffineMap.cpp (+6)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+19-5)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e80a6446dfb24..49860cadcd12c26 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -386,10 +386,15 @@ def Vector_BroadcastOp :
return ::llvm::cast<VectorType>(getVector().getType());
}
- /// Return the dimensions of the result vector that were formerly ones in the
- /// source tensor and thus correspond to "dim-1" broadcasting.
+ /// Return the dimensions of the result vector that were formerly ones in
+ /// the source vector and thus correspond to "dim-1" broadcasting.
llvm::SetVector<int64_t> computeBroadcastedUnitDims();
+ /// Return the dimensions of the result vector that were newly added to the
+ /// source vector via rank extension. These are all the dimensions that were
+ /// not "dim-1" broadcasted.
+ llvm::SetVector<int64_t> computeRankExtendedDims();
+
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of dim-1
@@ -2436,14 +2441,13 @@ def Vector_TransposeOp :
Vector_Op<"transpose", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
PredOpTrait<"operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
- Results<(outs AnyVectorOfAnyRank:$result)> {
+ TCresVTEtIsSameAsOpBase<0, 0>>]> {
let summary = "vector transpose operation";
let description = [{
Takes a n-D vector and returns the transposed n-D vector defined by
the permutation of ranks in the n-sized integer array attribute (in case
of 0-D vectors the array attribute must be empty).
+
In the operation
```mlir
@@ -2452,7 +2456,7 @@ def Vector_TransposeOp :
to vector<d_trans[0] x .. x d_trans[n-1] x f32>
```
- the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+ the `permutation` array [i_1, .., i_n] must be a permutation of [0, .., n-1].
Example:
@@ -2464,8 +2468,13 @@ def Vector_TransposeOp :
[c, f] ]
```
}];
+
+ let arguments = (ins AnyVectorOfAnyRank:$vector,
+ DenseI64ArrayAttr:$permutation);
+ let results = (outs AnyVectorOfAnyRank:$result);
+
let builders = [
- OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
+ OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
@@ -2474,10 +2483,9 @@ def Vector_TransposeOp :
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
- void getTransp(SmallVectorImpl<int64_t> &results);
}];
let assemblyFormat = [{
- $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+ $vector `,` $permutation attr-dict `:` type($vector) `to` type($result)
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 713aef767edf669..981f3d392cbc98c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -103,6 +103,8 @@ class AffineMap {
/// (i.e. `[1,1,2]` is an invalid permutation).
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
+ static AffineMap getPermutationMap(ArrayRef<int64_t> permutation,
+ MLIRContext *context);
/// Returns an affine map with `numDims` input dimensions and results
/// specified by `targets`.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..01c782676068d9a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -436,12 +436,9 @@ struct TransposeOpToArmSMELowering
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
- SmallVector<int64_t> transp;
- for (auto attr : transposeOp.getTransp())
- transp.push_back(cast<IntegerAttr>(attr).getInt());
-
// Bail unless this is a true 2-D matrix transpose.
- if (transp[0] != 1 || transp[1] != 0)
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+ if (permutation[0] != 1 || permutation[1] != 0)
return failure();
OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1126c2c20758c7a..429d1137b6f3781 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -473,13 +473,8 @@ struct CombineTransferReadOpTranspose final
if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
- SmallVector<int64_t, 2> perm;
- op.getTransp(perm);
- SmallVector<unsigned, 2> permU;
- for (int64_t o : perm)
- permU.push_back(unsigned(o));
AffineMap permutationMap =
- AffineMap::getPermutationMap(permU, op.getContext());
+ AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
AffineMap newMap =
permutationMap.compose(transferReadOp.getPermutationMap());
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 1084fbc890053b9..79fabd6ed2e99a2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -712,7 +712,7 @@ struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
VectorType newTy =
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
Value newTranspose = rewriter.create<vector::TransposeOp>(
- op.getLoc(), newTy, ext->getIn(), op.getTransp());
+ op.getLoc(), newTy, ext->getIn(), op.getPermutation());
ext->recreateAndReplace(rewriter, op, newTranspose);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95f49fa32bc0ae2..cf7c3c6c1a395ac 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1456,9 +1456,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
if (!nextTransposeOp)
return failure();
- auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
- AffineMap m = inversePermutation(
- AffineMap::getPermutationMap(permutation, extractOp.getContext()));
+ AffineMap m = inversePermutation(AffineMap::getPermutationMap(
+ nextTransposeOp.getPermutation(), extractOp.getContext()));
extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
return success();
}
@@ -1898,6 +1897,62 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
}
};
+/// Canonicalize extract(transpose(broadcast))) constructs, where the broadcast
+/// adds a new dimension and the extraction removes it again.
+class ExtractOpTransposedBroadcastDim final
+ : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ // Skip vector.extract ops that do not remove any dimensions.
+ if (extractOp.getNumIndices() == 0)
+ return failure();
+ // Look for extract(transpose(broadcast(x))) pattern.
+ auto transposeOp =
+ extractOp.getVector().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp || transposeOp.getPermutation().empty())
+ return failure();
+ auto broadcastOp =
+ transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastOp)
+ return failure();
+ // Check if the first dimension that is being removed by the vector.extract
+ // was added by the vector.broadcast.
+ int64_t removedDim = transposeOp.getPermutation()[0];
+ llvm::SetVector<int64_t> rankExtendedDims =
+ broadcastOp.computeRankExtendedDims();
+ if (!rankExtendedDims.contains(removedDim))
+ return failure();
+
+ // 1. Create new vector.broadcast without the removed dimension.
+ SmallVector<int64_t> newBroadcastShape(
+ broadcastOp.getResultVectorType().getShape());
+ newBroadcastShape.erase(newBroadcastShape.begin() + removedDim);
+ auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+ broadcastOp.getLoc(),
+ VectorType::get(newBroadcastShape,
+ broadcastOp.getResultVectorType().getElementType()),
+ broadcastOp.getSource());
+
+ // 2. Create new vector.transpose.
+ SmallVector<int64_t> newPermutation;
+ for (int64_t dim : transposeOp.getPermutation().drop_front())
+ newPermutation.push_back(dim < transposeOp.getPermutation()[0] ? dim
+ : dim - 1);
+ auto newTranspose = rewriter.create<vector::TransposeOp>(
+ transposeOp.getLoc(), newBroadcast, newPermutation);
+
+ // 3. Create new vector.extract without the outermost dimension.
+ SmallVector<OpFoldResult> mixedPositions = extractOp.getMixedPosition();
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, newTranspose, ArrayRef(mixedPositions).drop_front());
+
+ return success();
+ }
+};
+
// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
public:
@@ -2063,7 +2118,8 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ ExtractOpFromBroadcast, ExtractOpTransposedBroadcastDim,
+ ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
}
@@ -2113,6 +2169,20 @@ llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
getResultVectorType().getShape());
}
+llvm::SetVector<int64_t> BroadcastOp::computeRankExtendedDims() {
+ llvm::SetVector<int64_t> broadcastedUnitDims = computeBroadcastedUnitDims();
+ llvm::SetVector<int64_t> result;
+ auto vecSrcType = dyn_cast<VectorType>(getSourceType());
+ int64_t rankDiff =
+ vecSrcType ? getResultVectorType().getRank() - vecSrcType.getRank()
+ : getResultVectorType().getRank();
+ for (int64_t i = 0; i < rankDiff; ++i) {
+ if (!broadcastedUnitDims.contains(i))
+ result.insert(i);
+ }
+ return result;
+}
+
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of dim-1
@@ -4007,37 +4077,43 @@ void TransferReadOp::getEffects(
}
namespace {
-/// Store to load forwarding for transfer operations with permuation maps.
-/// Even if the permutation maps are different we can still propagate the store
-/// into the load if the size of the dimensions read and written match. Then we
-/// can replace the transfer_read + transfer_write by vector.broadcast and
-/// vector.transpose.
-/// Example:
-/// ```
-/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
-/// {in_bounds = [true, true],
-/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
-/// vector<4x1xf32>, tensor<4x4x4xf32>
-/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
-/// {in_bounds = [true, true, true, true],
-/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
-/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
-/// ```
-/// To:
+/// Store to load forwarding for transfer operations with permutation maps.
+/// Even if the permutation maps and/or the rank of the read/written vectors are
+/// different, we can still propagate the store into the load if the accessed
+/// chunk of the shaped value matches.
+///
+/// The vector.transfer_read op is replaced by 3 ops:
+/// 1. A broadcast of the written vector with all broadcast dims of the reading
+/// op and unit dims for all shaped value dimensions that are not transfer
+/// dimensions of the writing op.
+/// 2. A transposition of the broadcasted value to account for differences
+/// in the permutation maps of the reading/writing op.
+/// 3. An extraction that removes shaped value dimensions that are not transfer
+/// dimensions of the reading op.
+///
+/// Running example:
/// ```
-/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
-/// %r = vector.transpose %0, [3, 0, 2, 1] :
-/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+/// %0 = vector.transfer_write %vec to %s[%a, %b, %c, %d, %e, %f]
+/// {permutation_map = affine_map<(d0, d1, d2, d3, d4, d5)
+/// -> (d2, d1, d4, d5)>}
+/// : vector<5x6x7x8xf32>, tensor<?x?x?x?x?x?xf32>
+/// %1 = vector.transfer_read %0[%a, %b, %c, %d, %e, %f]
+/// {permutation_map = affine_map<(d0, d1, d2, d3, d4, d5)
+/// -> (d1, d2, 0, d4, 0, d5, d0)>}
+/// : tensor<?x?x?x?x?x?xf32>, vector<6x5x100x7x200x8x1xf32>
/// ```
-struct TransferReadAfterWriteToBroadcast
- : public OpRewritePattern<TransferReadOp> {
+struct TransferReadAfterWrite : public OpRewritePattern<TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
+ Location loc = readOp.getLoc();
if (readOp.hasOutOfBoundsDim() ||
!llvm::isa<RankedTensorType>(readOp.getShapedType()))
return failure();
+ if (readOp.getShapedType().getElementType() !=
+ readOp.getVectorType().getElementType())
+ return failure();
auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
if (!defWrite)
return failure();
@@ -4046,42 +4122,140 @@ struct TransferReadAfterWriteToBroadcast
if (readOp.getTransferChunkAccessed() !=
defWrite.getTransferChunkAccessed())
return failure();
- // TODO: Support cases where a dim is explicitly written but implicitly
- // read (i.e., a unit dim that is rank reduced).
- if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
- getUnusedDimsBitVector({defWrite.getPermutationMap()}))
- return failure();
- if (readOp.getIndices() != defWrite.getIndices() ||
- readOp.getMask() != defWrite.getMask())
+ if (readOp.getIndices() != defWrite.getIndices())
return failure();
- Value vec = defWrite.getVector();
- // TODO: loop through the chain of transfer_write if we can prove that they
- // don't overlap with the transfer_read. This requires improving
- // `isDisjointTransferIndices` helper.
- AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
- AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
- AffineMap map = readMap.compose(writeMap);
- if (map.getNumResults() == 0)
+ Type elementType = readOp.getVectorType().getElementType();
+ if (elementType != defWrite.getVectorType().getElementType())
return failure();
- // Calculate the permutation to apply to go from the vector stored to the
- // vector read.
- SmallVector<unsigned> permutation;
- if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+ if (defWrite.getShapedType().getElementType() !=
+ defWrite.getVectorType().getElementType())
return failure();
- Location loc = readOp.getLoc();
- // Calculate the broadcast shape by applying the reverse permutation to the
- // final shape we want.
- ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
- SmallVector<int64_t> broadcastShape(destShape.size());
- for (const auto &pos : llvm::enumerate(permutation))
- broadcastShape[pos.value()] = destShape[pos.index()];
- VectorType broadcastedType = VectorType::get(
- broadcastShape, defWrite.getVectorType().getElementType());
- vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
- SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
- transposePerm);
+ // 1. Add rank-reduced unit dimensions and broadcast dimension to input
+ // vector %vec. Broadcast dimensions are added at the beginning, followed
+ // by rank-reduced unit dims, followed by the dimensions of %vec.
+ //
+ // %bc = vector.broadcast %vec
+ // : vector<5x6x7x8xf32> to vector<100x200x1x1x5x6x7x8xf32>
+ // | | \|
+ // broadcast dims |
+ // |
+ // rank-reduced dims (corresponding to %a and %d)
+
+ // Gather broadcast dimensions of the transfer_read.
+ SmallVector<int64_t> broadcastedShape;
+ int64_t numBroadcastDims = 0;
+ for (int64_t i = 0, e = readOp.getTransferRank(); i < e; ++i) {
+ if (readOp.isBroadcastDim(i)) {
+ broadcastedShape.push_back(readOp.getVectorType().getDimSize(i));
+ ++numBroadcastDims;
+ }
+ }
+ // Append unit dims for rank-reduced (unused) dimensions in the
+ // transfer_write.
+ // Note: `getLeadingShapedRank` is a misnomer: the dimensions that do not
+ // participate in the transfer are not necessarily leading dimensions.
+ broadcastedShape.append(defWrite.getLeadingShapedRank(), 1);
+ // Append input vector (%vec) shape.
+ llvm::append_range(broadcastedShape, defWrite.getVectorType().getShape());
+ // Emit vector.broadcast op.
+ Value broadcasted = rewriter.create<vector::BroadcastOp>(
+ loc, VectorType::get(broadcastedShape, elementType),
+ defWrite.getVector());
+
+ // 2. Transpose the broadcasted vector. Dimensions that are not needed must
+ // be placed at the beginning (because vector.extract can remove only
+ // leading dimensions).
+
+ // Build a mapping (`shapedDimToVecDim`) from shaped value dims to dims of
+ // the broadcasted vector. This is essentially an inverted version of the
+ // transfer_write permutation map that takes into account the newly added
+ // unit dims.
+ // %b %f
+ // \ |
+ // Example: broadcasted vector type: vector<100x200x1x1x5x6x7x8xf32>
+ // / | \ \
+ // / %d | %e
+ // %a %c
+ // mapping = [2, 5, 4, 3, 6, 7]
+
+ // Initialize the mapping with -1.
+ SmallVector<int64_t> shapedDimToVecDim(defWrite.getShapedType().getRank(),
+ -1);
+ // Fill in the dimensions from the inverted transfer_write permutation map.
+ int64_t numUnitDims = defWrite.getLeadingShapedRank();
+ for (const auto &it :
+ llvm::enumerate(defWrite.getPermutationMap().getResults())) {
+ shapedDimToVecDim[cast<AffineDimExpr>(it.value()).getPosition()] =
+ it.index() + numUnitDims + numBroadcastDims;
+ }
+ // Fill in missing unused dims (of the transfer_write) with the broadcasted
+ // unit dims (which are placed right after the broadcast dims).
+ int64_t nextUnitDim = numBroadcastDims;
+ for (int64_t i = 0, e = shapedDimToVecDim.size(); i < e; ++i) {
+ if (shapedDimToVecDim[i] == -1)
+ shapedDimTo...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/72617
More information about the Mlir-commits
mailing list