[Mlir-commits] [mlir] [mlir][vector] Better transfer_read(transfer_write) canonicalization (PR #72617)
Matthias Springer
llvmlistbot at llvm.org
Thu Nov 16 23:38:01 PST 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/72617
This change improves the canonicalization of `transfer_read(transfer_write)` IR patterns where the two transfer ops access the same chunk of the shaped value (store-load forwarding). The existing rewrite pattern did not support cases where the two transfer ops operate on vectors of different rank (i.e., different rank-reduced/extended unit dims).
The previous pattern generated a combination of `vector.transpose` and `vector.broadcast`. The new pattern generates a combination of `vector.transpose`, `vector.broadcast` and `vector.extract`. In cases where no `vector.extract` is needed, other canonicalization patterns/foldings simplify the IR such that the same IR as with the previous pattern is produced.
Depends on #72594 and #72616. Review only the top commit.
>From a4b57fdb77f637f82bb556d74f146138bbf1fab8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 17 Nov 2023 11:10:50 +0900
Subject: [PATCH 1/3] [mlir][vector] Modernize `vector.transpose` op
* Declare arguments/results with `let` statements.
* Rename `transp` to `permutation`.
* Change type of `transp` from `I64ArrayAttr` to `DenseI64ArrayAttr`.
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 17 +++--
mlir/include/mlir/IR/AffineMap.h | 2 +
.../VectorToArmSME/VectorToArmSME.cpp | 7 +-
.../Conversion/VectorToGPU/VectorToGPU.cpp | 7 +-
.../Dialect/Arith/Transforms/IntNarrowing.cpp | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 64 +++++++------------
.../Transforms/LowerVectorTranspose.cpp | 4 +-
.../Vector/Transforms/VectorTransforms.cpp | 20 +++---
.../Vector/Transforms/VectorUnroll.cpp | 3 +-
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 7 +-
mlir/lib/IR/AffineMap.cpp | 6 ++
11 files changed, 59 insertions(+), 80 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e80a6446dfb24..1397d4caf1d9d61 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2436,14 +2436,13 @@ def Vector_TransposeOp :
Vector_Op<"transpose", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
PredOpTrait<"operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
- Results<(outs AnyVectorOfAnyRank:$result)> {
+ TCresVTEtIsSameAsOpBase<0, 0>>]> {
let summary = "vector transpose operation";
let description = [{
Takes a n-D vector and returns the transposed n-D vector defined by
the permutation of ranks in the n-sized integer array attribute (in case
of 0-D vectors the array attribute must be empty).
+
In the operation
```mlir
@@ -2452,7 +2451,7 @@ def Vector_TransposeOp :
to vector<d_trans[0] x .. x d_trans[n-1] x f32>
```
- the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+ the `permutation` array [i_1, .., i_n] must be a permutation of [0, .., n-1].
Example:
@@ -2464,8 +2463,13 @@ def Vector_TransposeOp :
[c, f] ]
```
}];
+
+ let arguments = (ins AnyVectorOfAnyRank:$vector,
+ DenseI64ArrayAttr:$permutation);
+ let results = (outs AnyVectorOfAnyRank:$result);
+
let builders = [
- OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
+ OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
@@ -2474,10 +2478,9 @@ def Vector_TransposeOp :
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
- void getTransp(SmallVectorImpl<int64_t> &results);
}];
let assemblyFormat = [{
- $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+ $vector `,` $permutation attr-dict `:` type($vector) `to` type($result)
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 713aef767edf669..981f3d392cbc98c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -103,6 +103,8 @@ class AffineMap {
/// (i.e. `[1,1,2]` is an invalid permutation).
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
+ static AffineMap getPermutationMap(ArrayRef<int64_t> permutation,
+ MLIRContext *context);
/// Returns an affine map with `numDims` input dimensions and results
/// specified by `targets`.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..01c782676068d9a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -436,12 +436,9 @@ struct TransposeOpToArmSMELowering
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
- SmallVector<int64_t> transp;
- for (auto attr : transposeOp.getTransp())
- transp.push_back(cast<IntegerAttr>(attr).getInt());
-
// Bail unless this is a true 2-D matrix transpose.
- if (transp[0] != 1 || transp[1] != 0)
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+ if (permutation[0] != 1 || permutation[1] != 0)
return failure();
OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1126c2c20758c7a..429d1137b6f3781 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -473,13 +473,8 @@ struct CombineTransferReadOpTranspose final
if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
- SmallVector<int64_t, 2> perm;
- op.getTransp(perm);
- SmallVector<unsigned, 2> permU;
- for (int64_t o : perm)
- permU.push_back(unsigned(o));
AffineMap permutationMap =
- AffineMap::getPermutationMap(permU, op.getContext());
+ AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
AffineMap newMap =
permutationMap.compose(transferReadOp.getPermutationMap());
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 1084fbc890053b9..79fabd6ed2e99a2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -712,7 +712,7 @@ struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
VectorType newTy =
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
Value newTranspose = rewriter.create<vector::TransposeOp>(
- op.getLoc(), newTy, ext->getIn(), op.getTransp());
+ op.getLoc(), newTy, ext->getIn(), op.getPermutation());
ext->recreateAndReplace(rewriter, op, newTranspose);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95f49fa32bc0ae2..c7b74701fdbc8f2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1456,9 +1456,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
if (!nextTransposeOp)
return failure();
- auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
- AffineMap m = inversePermutation(
- AffineMap::getPermutationMap(permutation, extractOp.getContext()));
+ AffineMap m = inversePermutation(AffineMap::getPermutationMap(
+ nextTransposeOp.getPermutation(), extractOp.getContext()));
extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
return success();
}
@@ -5376,20 +5375,20 @@ LogicalResult TypeCastOp::verify() {
//===----------------------------------------------------------------------===//
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
- Value vector, ArrayRef<int64_t> transp) {
+ Value vector, ArrayRef<int64_t> permutation) {
VectorType vt = llvm::cast<VectorType>(vector.getType());
SmallVector<int64_t, 4> transposedShape(vt.getRank());
SmallVector<bool, 4> transposedScalableDims(vt.getRank());
- for (unsigned i = 0; i < transp.size(); ++i) {
- transposedShape[i] = vt.getShape()[transp[i]];
- transposedScalableDims[i] = vt.getScalableDims()[transp[i]];
+ for (unsigned i = 0; i < permutation.size(); ++i) {
+ transposedShape[i] = vt.getShape()[permutation[i]];
+ transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
}
result.addOperands(vector);
result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
transposedScalableDims));
- result.addAttribute(TransposeOp::getTranspAttrName(result.name),
- builder.getI64ArrayAttr(transp));
+ result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
+ builder.getDenseI64ArrayAttr(permutation));
}
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
@@ -5401,13 +5400,12 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate identity transpose ops. This happens when the dimensions of the
// input vector remain in their original order after the transpose operation.
- SmallVector<int64_t, 4> transp;
- getTransp(transp);
+ ArrayRef<int64_t> perm = getPermutation();
// Check if the permutation of the dimensions contains sequential values:
// {0, 1, 2, ...}.
- for (int64_t i = 0, e = transp.size(); i < e; i++) {
- if (transp[i] != i)
+ for (int64_t i = 0, e = perm.size(); i < e; i++) {
+ if (perm[i] != i)
return {};
}
@@ -5421,20 +5419,19 @@ LogicalResult vector::TransposeOp::verify() {
if (vectorType.getRank() != rank)
return emitOpError("vector result rank mismatch: ") << rank;
// Verify transposition array.
- auto transpAttr = getTransp().getValue();
- int64_t size = transpAttr.size();
+ ArrayRef<int64_t> perm = getPermutation();
+ int64_t size = perm.size();
if (rank != size)
return emitOpError("transposition length mismatch: ") << size;
SmallVector<bool, 8> seen(rank, false);
- for (const auto &ta : llvm::enumerate(transpAttr)) {
- int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
- if (i < 0 || i >= rank)
- return emitOpError("transposition index out of range: ") << i;
- if (seen[i])
- return emitOpError("duplicate position index: ") << i;
- seen[i] = true;
- if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
- return emitOpError("dimension size mismatch at: ") << i;
+ for (const auto &ta : llvm::enumerate(perm)) {
+ if (ta.value() < 0 || ta.value() >= rank)
+ return emitOpError("transposition index out of range: ") << ta.value();
+ if (seen[ta.value()])
+ return emitOpError("duplicate position index: ") << ta.value();
+ seen[ta.value()] = true;
+ if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
+ return emitOpError("dimension size mismatch at: ") << ta.value();
}
return success();
}
@@ -5452,13 +5449,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
- // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
- auto getPermutation = [](vector::TransposeOp transpose) {
- SmallVector<int64_t, 4> permutation;
- transpose.getTransp(permutation);
- return permutation;
- };
-
// Composes two permutations: result[i] = permutation1[permutation2[i]].
auto composePermutations = [](ArrayRef<int64_t> permutation1,
ArrayRef<int64_t> permutation2) {
@@ -5475,12 +5465,11 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
return failure();
SmallVector<int64_t, 4> permutation = composePermutations(
- getPermutation(parentTransposeOp), getPermutation(transposeOp));
+ parentTransposeOp.getPermutation(), transposeOp.getPermutation());
// Replace 'transposeOp' with a new transpose operation.
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
- parentTransposeOp.getVector(),
- vector::getVectorSubscriptAttr(rewriter, permutation));
+ parentTransposeOp.getVector(), permutation);
return success();
}
};
@@ -5539,8 +5528,7 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
// Get the transpose permutation and apply it to the vector.create_mask or
// vector.constant_mask operands.
- SmallVector<int64_t> permutation;
- transpOp.getTransp(permutation);
+ ArrayRef<int64_t> permutation = transpOp.getPermutation();
if (createMaskOp) {
auto maskOperands = createMaskOp.getOperands();
@@ -5572,10 +5560,6 @@ void vector::TransposeOp::getCanonicalizationPatterns(
TransposeFolder, FoldTransposeSplat>(context);
}
-void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
- populateFromInt64AttrArray(getTransp(), results);
-}
-
//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index dee786007c80630..97f6caca1b25ccc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -327,9 +327,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
VectorType resType = op.getResultVectorType();
// Set up convenience transposition table.
- SmallVector<int64_t> transp;
- for (auto attr : op.getTransp())
- transp.push_back(cast<IntegerAttr>(attr).getInt());
+ ArrayRef<int64_t> transp = op.getPermutation();
if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
succeeded(isTranspose2DSlice(op)))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 713f9cb72c82cec..a20c8aeeb6f7108 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -212,8 +212,7 @@ struct CombineContractABTranspose final
if (!transposeOp)
continue;
AffineMap permutationMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(transposeOp.getTransp()),
- contractOp.getContext());
+ transposeOp.getPermutation(), contractOp.getContext());
map = inversePermutation(permutationMap).compose(map);
*operand = transposeOp.getVector();
changed = true;
@@ -279,13 +278,13 @@ struct CombineContractResultTranspose final
// Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
// To index into A in contract, we need revert(f)(g(C)) -> A.
- auto accTMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(accTOp.getTransp()), context);
+ auto accTMap =
+ AffineMap::getPermutationMap(accTOp.getPermutation(), context);
// Contract performs g(C) -> D. Result transpose performs h(D) -> E.
// To index into E in contract, we need h(g(C)) -> E.
- auto resTMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(resTOp.getTransp()), context);
+ auto resTMap =
+ AffineMap::getPermutationMap(resTOp.getPermutation(), context);
auto combinedResMap = resTMap.compose(contractMap);
// The accumulator and result share the same indexing map. So they should be
@@ -490,7 +489,7 @@ struct ReorderElementwiseOpsOnTranspose final
// Make sure all operands are transpose/constant ops and collect their
// transposition maps.
- SmallVector<ArrayAttr> transposeMaps;
+ SmallVector<ArrayRef<int64_t>> transposeMaps;
transposeMaps.reserve(op->getNumOperands());
// Record the initial type before transposition. We'll use its shape later.
// Any type will do here as we will check all transpose maps are the same.
@@ -498,7 +497,7 @@ struct ReorderElementwiseOpsOnTranspose final
for (Value operand : op->getOperands()) {
auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
if (transposeOp) {
- transposeMaps.push_back(transposeOp.getTransp());
+ transposeMaps.push_back(transposeOp.getPermutation());
srcType = transposeOp.getSourceVectorType();
} else if (!matchPattern(operand, m_Constant())) {
return failure();
@@ -517,7 +516,7 @@ struct ReorderElementwiseOpsOnTranspose final
// If there are constant operands, we need to insert inverse transposes for
// them. Calculate the inverse order first.
- auto order = extractVector<unsigned>(transposeMaps.front());
+ auto order = transposeMaps.front();
SmallVector<int64_t> invOrder(order.size());
for (int i = 0, e = order.size(); i < e; ++i)
invOrder[order[i]] = i;
@@ -532,8 +531,7 @@ struct ReorderElementwiseOpsOnTranspose final
srcType.getShape(),
cast<VectorType>(operand.getType()).getElementType());
srcValues.push_back(rewriter.create<vector::TransposeOp>(
- operand.getLoc(), vectorType, operand,
- rewriter.getI64ArrayAttr(invOrder)));
+ operand.getLoc(), vectorType, operand, invOrder));
}
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 4cfac7de29ee76f..78b041255443c30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -537,8 +537,7 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
// Prepare the result vector;
Value result = rewriter.create<arith::ConstantOp>(
loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
- SmallVector<int64_t> permutation;
- transposeOp.getTransp(permutation);
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
// Unroll the computation.
for (SmallVector<int64_t> elementOffsets :
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 467a521f9eada96..48cd67ad86c63fb 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -87,14 +87,11 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
if (srcGtOneDims.size() != 2)
return failure();
- SmallVector<int64_t> transp;
- for (auto attr : op.getTransp())
- transp.push_back(cast<IntegerAttr>(attr).getInt());
-
// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
// transpose pattens. Otherwise, these patterns are not applicable.
- if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
+ if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
+ op.getPermutation()))
return failure();
return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 93a8d048e0a61d5..80a26a595edee0a 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -236,6 +236,12 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
assert(permutationMap.isPermutation() && "Invalid permutation vector");
return permutationMap;
}
+AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation,
+ MLIRContext *context) {
+ SmallVector<unsigned> perm = llvm::map_to_vector(
+ permutation, [](int64_t i) { return static_cast<unsigned>(i); });
+ return AffineMap::getPermutationMap(perm, context);
+}
AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
ArrayRef<unsigned> targets,
>From 5b48e0b0f5eba3f2b0029617ad7313aa10a0c266 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 17 Nov 2023 16:29:19 +0900
Subject: [PATCH 2/3] [mlir][vector] Add extract(transpose(broadcast(x)))
canonicalization
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 9 ++-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 73 ++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 14 ++++
3 files changed, 93 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1397d4caf1d9d61..49860cadcd12c26 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -386,10 +386,15 @@ def Vector_BroadcastOp :
return ::llvm::cast<VectorType>(getVector().getType());
}
- /// Return the dimensions of the result vector that were formerly ones in the
- /// source tensor and thus correspond to "dim-1" broadcasting.
+ /// Return the dimensions of the result vector that were formerly ones in
+ /// the source vector and thus correspond to "dim-1" broadcasting.
llvm::SetVector<int64_t> computeBroadcastedUnitDims();
+ /// Return the dimensions of the result vector that were newly added to the
+ /// source vector via rank extension. These are all the dimensions that were
+ /// not "dim-1" broadcasted.
+ llvm::SetVector<int64_t> computeRankExtendedDims();
+
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of dim-1
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7b74701fdbc8f2..957143d6c13e9e4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1897,6 +1897,62 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
}
};
+/// Canonicalize extract(transpose(broadcast))) constructs, where the broadcast
+/// adds a new dimension and the extraction removes it again.
+class ExtractOpTransposedBroadcastDim final
+ : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ // Skip vector.extract ops that do not remove any dimensions.
+ if (extractOp.getNumIndices() == 0)
+ return failure();
+ // Look for extract(transpose(broadcast(x))) pattern.
+ auto transposeOp =
+ extractOp.getVector().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp || transposeOp.getPermutation().empty())
+ return failure();
+ auto broadcastOp =
+ transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastOp)
+ return failure();
+ // Check if the first dimension that is being removed by the vector.extract
+ // was added by the vector.broadcast.
+ int64_t removedDim = transposeOp.getPermutation()[0];
+ llvm::SetVector<int64_t> rankExtendedDims =
+ broadcastOp.computeRankExtendedDims();
+ if (!rankExtendedDims.contains(removedDim))
+ return failure();
+
+ // 1. Create new vector.broadcast without the removed dimension.
+ SmallVector<int64_t> newBroadcastShape(
+ broadcastOp.getResultVectorType().getShape());
+ newBroadcastShape.erase(newBroadcastShape.begin() + removedDim);
+ auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+ broadcastOp.getLoc(),
+ VectorType::get(newBroadcastShape,
+ broadcastOp.getResultVectorType().getElementType()),
+ broadcastOp.getSource());
+
+ // 2. Create new vector.transpose.
+ SmallVector<int64_t> newPermutation;
+ for (int64_t dim : transposeOp.getPermutation().drop_front())
+ newPermutation.push_back(dim < transposeOp.getPermutation()[0] ? dim
+ : dim - 1);
+ auto newTranspose = rewriter.create<vector::TransposeOp>(
+ transposeOp.getLoc(), newBroadcast, newPermutation);
+
+ // 3. Create new vector.extract without the outermost dimension.
+ SmallVector<OpFoldResult> mixedPositions = extractOp.getMixedPosition();
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, newTranspose, ArrayRef(mixedPositions).drop_front());
+
+ return success();
+ }
+};
+
// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
public:
@@ -2062,7 +2118,8 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ ExtractOpFromBroadcast, ExtractOpTransposedBroadcastDim,
+ ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
}
@@ -2112,6 +2169,20 @@ llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
getResultVectorType().getShape());
}
+llvm::SetVector<int64_t> BroadcastOp::computeRankExtendedDims() {
+ llvm::SetVector<int64_t> broadcastedUnitDims = computeBroadcastedUnitDims();
+ llvm::SetVector<int64_t> result;
+ auto vecSrcType = dyn_cast<VectorType>(getSourceType());
+ int64_t rankDiff =
+ vecSrcType ? getResultVectorType().getRank() - vecSrcType.getRank()
+ : getResultVectorType().getRank();
+ for (int64_t i = 0; i < rankDiff; ++i) {
+ if (!broadcastedUnitDims.contains(i))
+ result.insert(i);
+ }
+ return result;
+}
+
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of dim-1
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d341..a6b4f7f2717da81 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2524,3 +2524,17 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
return %r : vector<1x100x4x5xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @extract_of_transposed_broadcast_dim(
+// CHECK-SAME: %[[arg0:.*]]: vector<4x1xf32>
+// CHECK: %[[bc:.*]] = vector.broadcast %[[arg0]] : vector<4x1xf32> to vector<100x5x4x1xf32>
+// CHECK: %[[tp:.*]] = vector.transpose %[[bc]], [3, 0, 2, 1] : vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+// CHECK: return %[[tp]]
+func.func @extract_of_transposed_broadcast_dim(%arg0: vector<4x1xf32>) -> vector<1x100x4x5xf32> {
+ %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<100x5x1x4x1xf32>
+ %1 = vector.transpose %0, [2, 4, 0, 3, 1] : vector<100x5x1x4x1xf32> to vector<1x1x100x4x5xf32>
+ %2 = vector.extract %1[0] : vector<1x100x4x5xf32> from vector<1x1x100x4x5xf32>
+ return %2 : vector<1x100x4x5xf32>
+}
>From 1d58592e08d61d6875023088c2e4b006a34a895f Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 17 Nov 2023 16:31:58 +0900
Subject: [PATCH 3/3] [mlir][vector] Better `transfer_read(transfer_write)`
canonicalization
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 214 +++++++++++++++------
mlir/test/Dialect/Vector/canonicalize.mlir | 10 +-
2 files changed, 164 insertions(+), 60 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 957143d6c13e9e4..cf7c3c6c1a395ac 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4077,37 +4077,43 @@ void TransferReadOp::getEffects(
}
namespace {
-/// Store to load forwarding for transfer operations with permuation maps.
-/// Even if the permutation maps are different we can still propagate the store
-/// into the load if the size of the dimensions read and written match. Then we
-/// can replace the transfer_read + transfer_write by vector.broadcast and
-/// vector.transpose.
-/// Example:
-/// ```
-/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
-/// {in_bounds = [true, true],
-/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
-/// vector<4x1xf32>, tensor<4x4x4xf32>
-/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
-/// {in_bounds = [true, true, true, true],
-/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
-/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
-/// ```
-/// To:
+/// Store to load forwarding for transfer operations with permutation maps.
+/// Even if the permutation maps and/or the rank of the read/written vectors are
+/// different, we can still propagate the store into the load if the accessed
+/// chunk of the shaped value matches.
+///
+/// The vector.transfer_read op is replaced by 3 ops:
+/// 1. A broadcast of the written vector with all broadcast dims of the reading
+/// op and unit dims for all shaped value dimensions that are not transfer
+/// dimensions of the writing op.
+/// 2. A transposition of the broadcasted value to account for differences
+/// in the permutation maps of the reading/writing op.
+/// 3. An extraction that removes shaped value dimensions that are not transfer
+/// dimensions of the reading op.
+///
+/// Running example:
/// ```
-/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
-/// %r = vector.transpose %0, [3, 0, 2, 1] :
-/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+/// %0 = vector.transfer_write %vec to %s[%a, %b, %c, %d, %e, %f]
+/// {permutation_map = affine_map<(d0, d1, d2, d3, d4, d5)
+/// -> (d2, d1, d4, d5)>}
+/// : vector<5x6x7x8xf32>, tensor<?x?x?x?x?x?xf32>
+/// %1 = vector.transfer_read %0[%a, %b, %c, %d, %e, %f]
+/// {permutation_map = affine_map<(d0, d1, d2, d3, d4, d5)
+/// -> (d1, d2, 0, d4, 0, d5, d0)>}
+/// : tensor<?x?x?x?x?x?xf32>, vector<6x5x100x7x200x8x1xf32>
/// ```
-struct TransferReadAfterWriteToBroadcast
- : public OpRewritePattern<TransferReadOp> {
+struct TransferReadAfterWrite : public OpRewritePattern<TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
+ Location loc = readOp.getLoc();
if (readOp.hasOutOfBoundsDim() ||
!llvm::isa<RankedTensorType>(readOp.getShapedType()))
return failure();
+ if (readOp.getShapedType().getElementType() !=
+ readOp.getVectorType().getElementType())
+ return failure();
auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
if (!defWrite)
return failure();
@@ -4116,42 +4122,140 @@ struct TransferReadAfterWriteToBroadcast
if (readOp.getTransferChunkAccessed() !=
defWrite.getTransferChunkAccessed())
return failure();
- // TODO: Support cases where a dim is explicitly written but implicitly
- // read (i.e., a unit dim that is rank reduced).
- if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
- getUnusedDimsBitVector({defWrite.getPermutationMap()}))
- return failure();
- if (readOp.getIndices() != defWrite.getIndices() ||
- readOp.getMask() != defWrite.getMask())
+ if (readOp.getIndices() != defWrite.getIndices())
return failure();
- Value vec = defWrite.getVector();
- // TODO: loop through the chain of transfer_write if we can prove that they
- // don't overlap with the transfer_read. This requires improving
- // `isDisjointTransferIndices` helper.
- AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
- AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
- AffineMap map = readMap.compose(writeMap);
- if (map.getNumResults() == 0)
+ Type elementType = readOp.getVectorType().getElementType();
+ if (elementType != defWrite.getVectorType().getElementType())
return failure();
- // Calculate the permutation to apply to go from the vector stored to the
- // vector read.
- SmallVector<unsigned> permutation;
- if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+ if (defWrite.getShapedType().getElementType() !=
+ defWrite.getVectorType().getElementType())
return failure();
- Location loc = readOp.getLoc();
- // Calculate the broadcast shape by applying the reverse permutation to the
- // final shape we want.
- ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
- SmallVector<int64_t> broadcastShape(destShape.size());
- for (const auto &pos : llvm::enumerate(permutation))
- broadcastShape[pos.value()] = destShape[pos.index()];
- VectorType broadcastedType = VectorType::get(
- broadcastShape, defWrite.getVectorType().getElementType());
- vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
- SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
- transposePerm);
+ // 1. Add rank-reduced unit dimensions and broadcast dimension to input
+ // vector %vec. Broadcast dimensions are added at the beginning, followed
+ // by rank-reduced unit dims, followed by the dimensions of %vec.
+ //
+ // %bc = vector.broadcast %vec
+ // : vector<5x6x7x8xf32> to vector<100x200x1x1x5x6x7x8xf32>
+ // | | \|
+ // broadcast dims |
+ // |
+ // rank-reduced dims (corresponding to %a and %d)
+
+ // Gather broadcast dimensions of the transfer_read.
+ SmallVector<int64_t> broadcastedShape;
+ int64_t numBroadcastDims = 0;
+ for (int64_t i = 0, e = readOp.getTransferRank(); i < e; ++i) {
+ if (readOp.isBroadcastDim(i)) {
+ broadcastedShape.push_back(readOp.getVectorType().getDimSize(i));
+ ++numBroadcastDims;
+ }
+ }
+ // Append unit dims for rank-reduced (unused) dimensions in the
+ // transfer_write.
+ // Note: `getLeadingShapedRank` is a misnomer: the dimensions that do not
+ // participate in the transfer are not necessarily leading dimensions.
+ broadcastedShape.append(defWrite.getLeadingShapedRank(), 1);
+ // Append input vector (%vec) shape.
+ llvm::append_range(broadcastedShape, defWrite.getVectorType().getShape());
+ // Emit vector.broadcast op.
+ Value broadcasted = rewriter.create<vector::BroadcastOp>(
+ loc, VectorType::get(broadcastedShape, elementType),
+ defWrite.getVector());
+
+ // 2. Transpose the broadcasted vector. Dimensions that are not needed must
+ // be placed at the beginning (because vector.extract can remove only
+ // leading dimensions).
+
+ // Build a mapping (`shapedDimToVecDim`) from shaped value dims to dims of
+ // the broadcasted vector. This is essentially an inverted version of the
+ // transfer_write permutation map that takes into account the newly added
+ // unit dims.
+ // %b %f
+ // \ |
+ // Example: broadcasted vector type: vector<100x200x1x1x5x6x7x8xf32>
+ // / | \ \
+ // / %d | %e
+ // %a %c
+ // mapping = [2, 5, 4, 3, 6, 7]
+
+ // Initialize the mapping with -1.
+ SmallVector<int64_t> shapedDimToVecDim(defWrite.getShapedType().getRank(),
+ -1);
+ // Fill in the dimensions from the inverted transfer_write permutation map.
+ int64_t numUnitDims = defWrite.getLeadingShapedRank();
+ for (const auto &it :
+ llvm::enumerate(defWrite.getPermutationMap().getResults())) {
+ shapedDimToVecDim[cast<AffineDimExpr>(it.value()).getPosition()] =
+ it.index() + numUnitDims + numBroadcastDims;
+ }
+ // Fill in missing unused dims (of the transfer_write) with the broadcasted
+ // unit dims (which are placed right after the broadcast dims).
+ int64_t nextUnitDim = numBroadcastDims;
+ for (int64_t i = 0, e = shapedDimToVecDim.size(); i < e; ++i) {
+ if (shapedDimToVecDim[i] == -1)
+ shapedDimToVecDim[i] = nextUnitDim++;
+ }
+ assert(nextUnitDim == numBroadcastDims + numUnitDims &&
+ "unexpected number of unit dims");
+
+ // Compute permutation. All dims that are not needed by the transfer_read
+ // are placed at the beginning.
+ SmallVector<int64_t> permutation(broadcastedShape.size(), -1);
+ // Helper data structure to keep track of dims that were not used yet.
+ SmallVector<int64_t> remainingDims =
+ llvm::to_vector(llvm::seq<int64_t>(0, broadcastedShape.size()));
+ int64_t numUnneededDims =
+ broadcastedShape.size() - readOp.getVectorType().getRank();
+ int64_t nextBroadcastDim = 0;
+ for (int64_t i = 0, e = readOp.getVectorType().getRank(); i < e; ++i) {
+ if (readOp.isBroadcastDim(i)) {
+ // This transfer_read result dim is a broadcast.
+ permutation[numUnneededDims + i] = nextBroadcastDim;
+ auto it = llvm::find(remainingDims, nextBroadcastDim);
+ assert(it != remainingDims.end() && "could not find broadcast dim");
+ remainingDims.erase(it);
+ nextBroadcastDim++;
+ continue;
+ }
+      // This transfer_read result dim is a dimension of the shaped value. Look
+      // up its position in the broadcasted vector in the mapping.
+ int64_t shapedValueDim =
+ cast<AffineDimExpr>(readOp.getPermutationMap().getResult(i))
+ .getPosition();
+ permutation[numUnneededDims + i] = shapedDimToVecDim[shapedValueDim];
+ auto it = llvm::find(remainingDims, shapedDimToVecDim[shapedValueDim]);
+ assert(it != remainingDims.end() && "could not find regular dim");
+ remainingDims.erase(it);
+ }
+
+ // Fill up the dimensions at the beginning with all remaining dims.
+ assert(remainingDims.size() == numUnneededDims &&
+ "unexpected number of remaining dims");
+ for (int64_t i = 0; i < numUnneededDims; ++i) {
+ // All unneeded dims must be unit dimensions. Otherwise, the two transfer
+ // ops would be accessing different chunks.
+ assert(broadcastedShape[remainingDims[i]] == 1 && "expected unit dim");
+ permutation[i] = remainingDims[i];
+ }
+
+ // Build vector.transpose op.
+ //
+ // unneeded dim (%d) broadcast dims
+ // \ / \
+ // %tp = vector.transpose %bc, [3, 5, 4, 0, 6, 1, 7, 2]
+ // : vector<100x200x1x1x5x6x7x8xf32> to vector<1x6x5x100x7x200x8x1xf32>
+ Value transposed = rewriter.create<vector::TransposeOp>(
+ defWrite.getLoc(), broadcasted, permutation);
+
+ // 3. Remove unneeded dims.
+ //
+ // %1 = vector.extract %tp[0]
+ // : vector<6x5x100x7x200x8x1xf32> from vector<1x6x5x100x7x200x8x1xf32>
+ SmallVector<int64_t> extractPositions(numUnneededDims, 0);
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(readOp, transposed,
+ extractPositions);
+
return success();
}
};
@@ -4159,7 +4263,7 @@ struct TransferReadAfterWriteToBroadcast
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<TransferReadAfterWriteToBroadcast>(context);
+ results.add<TransferReadAfterWrite>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a6b4f7f2717da81..308e0602ee46295 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2503,12 +2503,12 @@ func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1>{
// -----
-// TODO: This IR could be canonicalized but the canonicalization pattern is not
-// smart enough. For now, just make sure that we do not crash.
-
// CHECK-LABEL: func.func @load_store_forwarding_rank_mismatch(
-// CHECK: vector.transfer_write
-// CHECK: vector.transfer_read
+// CHECK-SAME: %[[v0:.*]]: vector<4x1x1xf32>
+// CHECK: %[[bc:.*]] = vector.broadcast %[[v0]] : vector<4x1x1xf32> to vector<100x5x4x1x1xf32>
+// CHECK: %[[tp:.*]] = vector.transpose %[[bc]], [4, 3, 0, 2, 1] : vector<100x5x4x1x1xf32> to vector<1x1x100x4x5xf32>
+// CHECK: %[[extract:.*]] = vector.extract %[[tp]][0] : vector<1x100x4x5xf32> from vector<1x1x100x4x5xf32>
+// CHECK: return %[[extract]]
func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: tensor<4x4x4xf32>) -> (vector<1x100x4x5xf32>) {
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
More information about the Mlir-commits
mailing list