[Mlir-commits] [mlir] [mlir][vector] Add extract(transpose(broadcast(x))) canonicalization (PR #72616)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 16 23:32:55 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Fold IR patterns in which a `vector.extract` removes a dimension that was added by a `vector.broadcast`, and a `vector.transpose` sits between these two ops.
Depends on #72594. Review only the top commit.
---
Patch is 24.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72616.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+17-9)
- (modified) mlir/include/mlir/IR/AffineMap.h (+2)
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+2-5)
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-6)
- (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+96-41)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+1-3)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+9-11)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+1-2)
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-5)
- (modified) mlir/lib/IR/AffineMap.cpp (+6)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e80a6446dfb24..49860cadcd12c26 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -386,10 +386,15 @@ def Vector_BroadcastOp :
return ::llvm::cast<VectorType>(getVector().getType());
}
- /// Return the dimensions of the result vector that were formerly ones in the
- /// source tensor and thus correspond to "dim-1" broadcasting.
+ /// Return the dimensions of the result vector that were formerly ones in
+ /// the source vector and thus correspond to "dim-1" broadcasting.
llvm::SetVector<int64_t> computeBroadcastedUnitDims();
+ /// Return the dimensions of the result vector that were newly added to the
+ /// source vector via rank extension, i.e., the leading result dimensions
+ /// that have no counterpart in the source vector.
+ llvm::SetVector<int64_t> computeRankExtendedDims();
+
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of dim-1
@@ -2436,14 +2441,13 @@ def Vector_TransposeOp :
Vector_Op<"transpose", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
PredOpTrait<"operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
- Results<(outs AnyVectorOfAnyRank:$result)> {
+ TCresVTEtIsSameAsOpBase<0, 0>>]> {
let summary = "vector transpose operation";
let description = [{
Takes a n-D vector and returns the transposed n-D vector defined by
the permutation of ranks in the n-sized integer array attribute (in case
of 0-D vectors the array attribute must be empty).
+
In the operation
```mlir
@@ -2452,7 +2456,7 @@ def Vector_TransposeOp :
to vector<d_trans[0] x .. x d_trans[n-1] x f32>
```
- the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+ the `permutation` array [i_1, .., i_n] must be a permutation of [0, .., n-1].
Example:
@@ -2464,8 +2468,13 @@ def Vector_TransposeOp :
[c, f] ]
```
}];
+
+ let arguments = (ins AnyVectorOfAnyRank:$vector,
+ DenseI64ArrayAttr:$permutation);
+ let results = (outs AnyVectorOfAnyRank:$result);
+
let builders = [
- OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
+ OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
@@ -2474,10 +2483,9 @@ def Vector_TransposeOp :
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
- void getTransp(SmallVectorImpl<int64_t> &results);
}];
let assemblyFormat = [{
- $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+ $vector `,` $permutation attr-dict `:` type($vector) `to` type($result)
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 713aef767edf669..981f3d392cbc98c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -103,6 +103,8 @@ class AffineMap {
/// (i.e. `[1,1,2]` is an invalid permutation).
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
+ static AffineMap getPermutationMap(ArrayRef<int64_t> permutation,
+ MLIRContext *context);
/// Returns an affine map with `numDims` input dimensions and results
/// specified by `targets`.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..01c782676068d9a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -436,12 +436,9 @@ struct TransposeOpToArmSMELowering
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
- SmallVector<int64_t> transp;
- for (auto attr : transposeOp.getTransp())
- transp.push_back(cast<IntegerAttr>(attr).getInt());
-
// Bail unless this is a true 2-D matrix transpose.
- if (transp[0] != 1 || transp[1] != 0)
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+ if (permutation[0] != 1 || permutation[1] != 0)
return failure();
OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1126c2c20758c7a..429d1137b6f3781 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -473,13 +473,8 @@ struct CombineTransferReadOpTranspose final
if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
- SmallVector<int64_t, 2> perm;
- op.getTransp(perm);
- SmallVector<unsigned, 2> permU;
- for (int64_t o : perm)
- permU.push_back(unsigned(o));
AffineMap permutationMap =
- AffineMap::getPermutationMap(permU, op.getContext());
+ AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
AffineMap newMap =
permutationMap.compose(transferReadOp.getPermutationMap());
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 1084fbc890053b9..79fabd6ed2e99a2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -712,7 +712,7 @@ struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
VectorType newTy =
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
Value newTranspose = rewriter.create<vector::TransposeOp>(
- op.getLoc(), newTy, ext->getIn(), op.getTransp());
+ op.getLoc(), newTy, ext->getIn(), op.getPermutation());
ext->recreateAndReplace(rewriter, op, newTranspose);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95f49fa32bc0ae2..957143d6c13e9e4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1456,9 +1456,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
if (!nextTransposeOp)
return failure();
- auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
- AffineMap m = inversePermutation(
- AffineMap::getPermutationMap(permutation, extractOp.getContext()));
+ AffineMap m = inversePermutation(AffineMap::getPermutationMap(
+ nextTransposeOp.getPermutation(), extractOp.getContext()));
extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
return success();
}
@@ -1898,6 +1897,62 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
}
};
+/// Canonicalize extract(transpose(broadcast(x))) constructs, where the
+/// broadcast adds a new dimension and the extraction removes it again.
+class ExtractOpTransposedBroadcastDim final
+    : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Skip vector.extract ops that do not remove any dimensions.
+    if (extractOp.getNumIndices() == 0)
+      return failure();
+    // Look for extract(transpose(broadcast(x))) pattern.
+    auto transposeOp =
+        extractOp.getVector().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp || transposeOp.getPermutation().empty())
+      return failure();
+    auto broadcastOp =
+        transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp)
+      return failure();
+    // The outermost dim removed by the vector.extract is broadcast result dim
+    // `permutation[0]`; it must have been added by rank extension.
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    int64_t removedDim = permutation[0];
+    if (!broadcastOp.computeRankExtendedDims().contains(removedDim))
+      return failure();
+
+    // 1. New vector.broadcast without the removed dim (keep scalable flags).
+    VectorType oldBroadcastType = broadcastOp.getResultVectorType();
+    SmallVector<int64_t> newBroadcastShape(oldBroadcastType.getShape());
+    SmallVector<bool> newScalableDims(oldBroadcastType.getScalableDims());
+    newBroadcastShape.erase(newBroadcastShape.begin() + removedDim);
+    newScalableDims.erase(newScalableDims.begin() + removedDim);
+    auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+        broadcastOp.getLoc(),
+        VectorType::get(newBroadcastShape, oldBroadcastType.getElementType(),
+                        newScalableDims),
+        broadcastOp.getSource());
+
+    // 2. New vector.transpose; dims after `removedDim` shift down by one.
+    SmallVector<int64_t> newPermutation;
+    for (int64_t dim : permutation.drop_front())
+      newPermutation.push_back(dim < removedDim ? dim : dim - 1);
+    auto newTranspose = rewriter.create<vector::TransposeOp>(
+        transposeOp.getLoc(), newBroadcast, newPermutation);
+
+    // 3. Create new vector.extract without the outermost dimension.
+    SmallVector<OpFoldResult> mixedPositions = extractOp.getMixedPosition();
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, newTranspose, ArrayRef(mixedPositions).drop_front());
+
+    return success();
+  }
+};
+
// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
public:
@@ -2063,7 +2118,8 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ ExtractOpFromBroadcast, ExtractOpTransposedBroadcastDim,
+ ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
}
@@ -2113,6 +2169,20 @@ llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
getResultVectorType().getShape());
}
+llvm::SetVector<int64_t> BroadcastOp::computeRankExtendedDims() {
+  // The source is aligned with the trailing result dims, so rank extension
+  // adds exactly the leading `rankDiff` dims. Dim-1 broadcasted dims always
+  // have a source counterpart, so they never fall into this range.
+  auto vecSrcType = dyn_cast<VectorType>(getSourceType());
+  int64_t rankDiff =
+      vecSrcType ? getResultVectorType().getRank() - vecSrcType.getRank()
+                 : getResultVectorType().getRank();
+  llvm::SetVector<int64_t> result;
+  for (int64_t i = 0; i < rankDiff; ++i)
+    result.insert(i);
+  return result;
+}
+
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of dim-1
@@ -5376,20 +5446,20 @@ LogicalResult TypeCastOp::verify() {
//===----------------------------------------------------------------------===//
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
- Value vector, ArrayRef<int64_t> transp) {
+ Value vector, ArrayRef<int64_t> permutation) {
VectorType vt = llvm::cast<VectorType>(vector.getType());
SmallVector<int64_t, 4> transposedShape(vt.getRank());
SmallVector<bool, 4> transposedScalableDims(vt.getRank());
- for (unsigned i = 0; i < transp.size(); ++i) {
- transposedShape[i] = vt.getShape()[transp[i]];
- transposedScalableDims[i] = vt.getScalableDims()[transp[i]];
+ for (unsigned i = 0; i < permutation.size(); ++i) {
+ transposedShape[i] = vt.getShape()[permutation[i]];
+ transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
}
result.addOperands(vector);
result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
transposedScalableDims));
- result.addAttribute(TransposeOp::getTranspAttrName(result.name),
- builder.getI64ArrayAttr(transp));
+ result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
+ builder.getDenseI64ArrayAttr(permutation));
}
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
@@ -5401,13 +5471,12 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate identity transpose ops. This happens when the dimensions of the
// input vector remain in their original order after the transpose operation.
- SmallVector<int64_t, 4> transp;
- getTransp(transp);
+ ArrayRef<int64_t> perm = getPermutation();
// Check if the permutation of the dimensions contains sequential values:
// {0, 1, 2, ...}.
- for (int64_t i = 0, e = transp.size(); i < e; i++) {
- if (transp[i] != i)
+ for (int64_t i = 0, e = perm.size(); i < e; i++) {
+ if (perm[i] != i)
return {};
}
@@ -5421,20 +5490,19 @@ LogicalResult vector::TransposeOp::verify() {
if (vectorType.getRank() != rank)
return emitOpError("vector result rank mismatch: ") << rank;
// Verify transposition array.
- auto transpAttr = getTransp().getValue();
- int64_t size = transpAttr.size();
+ ArrayRef<int64_t> perm = getPermutation();
+ int64_t size = perm.size();
if (rank != size)
return emitOpError("transposition length mismatch: ") << size;
SmallVector<bool, 8> seen(rank, false);
- for (const auto &ta : llvm::enumerate(transpAttr)) {
- int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
- if (i < 0 || i >= rank)
- return emitOpError("transposition index out of range: ") << i;
- if (seen[i])
- return emitOpError("duplicate position index: ") << i;
- seen[i] = true;
- if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
- return emitOpError("dimension size mismatch at: ") << i;
+ for (const auto &ta : llvm::enumerate(perm)) {
+ if (ta.value() < 0 || ta.value() >= rank)
+ return emitOpError("transposition index out of range: ") << ta.value();
+ if (seen[ta.value()])
+ return emitOpError("duplicate position index: ") << ta.value();
+ seen[ta.value()] = true;
+ if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
+ return emitOpError("dimension size mismatch at: ") << ta.value();
}
return success();
}
@@ -5452,13 +5520,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
- // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
- auto getPermutation = [](vector::TransposeOp transpose) {
- SmallVector<int64_t, 4> permutation;
- transpose.getTransp(permutation);
- return permutation;
- };
-
// Composes two permutations: result[i] = permutation1[permutation2[i]].
auto composePermutations = [](ArrayRef<int64_t> permutation1,
ArrayRef<int64_t> permutation2) {
@@ -5475,12 +5536,11 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
return failure();
SmallVector<int64_t, 4> permutation = composePermutations(
- getPermutation(parentTransposeOp), getPermutation(transposeOp));
+ parentTransposeOp.getPermutation(), transposeOp.getPermutation());
// Replace 'transposeOp' with a new transpose operation.
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
- parentTransposeOp.getVector(),
- vector::getVectorSubscriptAttr(rewriter, permutation));
+ parentTransposeOp.getVector(), permutation);
return success();
}
};
@@ -5539,8 +5599,7 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
// Get the transpose permutation and apply it to the vector.create_mask or
// vector.constant_mask operands.
- SmallVector<int64_t> permutation;
- transpOp.getTransp(permutation);
+ ArrayRef<int64_t> permutation = transpOp.getPermutation();
if (createMaskOp) {
auto maskOperands = createMaskOp.getOperands();
@@ -5572,10 +5631,6 @@ void vector::TransposeOp::getCanonicalizationPatterns(
TransposeFolder, FoldTransposeSplat>(context);
}
-void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
- populateFromInt64AttrArray(getTransp(), results);
-}
-
//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index dee786007c80630..97f6caca1b25ccc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -327,9 +327,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
VectorType resType = op.getResultVectorType();
// Set up convenience transposition table.
- SmallVector<int64_t> transp;
- for (auto attr : op.getTransp())
- transp.push_back(cast<IntegerAttr>(attr).getInt());
+ ArrayRef<int64_t> transp = op.getPermutation();
if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
succeeded(isTranspose2DSlice(op)))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 713f9cb72c82cec..a20c8aeeb6f7108 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -212,8 +212,7 @@ struct CombineContractABTranspose final
if (!transposeOp)
continue;
AffineMap permutationMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(transposeOp.getTransp()),
- contractOp.getContext());
+ transposeOp.getPermutation(), contractOp.getContext());
map = inversePermutation(permutationMap).compose(map);
*operand = transposeOp.getVector();
changed = true;
@@ -279,13 +278,13 @@ struct CombineContractResultTranspose final
// Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
// To index into A in contract, we need revert(f)(g(C)) -> A.
- auto accTMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(accTOp.getTransp()), context);
+ auto accTMap =
+ AffineMap::getPermutationMap(accTOp.getPermutation(), context);
// Contract performs g(C) -> D. Result transpose performs h(D) -> E.
// To index into E in contract, we need h(g(C)) -> E.
- auto resTMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(resTOp.getTransp()), context);
+ auto resTMap =
+ AffineMap::getPermutationMap(resTOp.getPermutation(), context);
auto combinedResMap = resTMap.compose(contractMap);
// The accumulator and result share the same indexing map. So they should be
@@ -490,7 +489,7 @@ struct ReorderElementwiseOpsOnTranspose final
// Make sure all operands are transpose/constant ops and collect their
// transposition maps.
- SmallVector<ArrayAttr> transposeMaps;
+ SmallVector<ArrayRef<int64_t>> transposeMaps;
transposeMaps.reserve(op->getNumOperands());
// Record the initial type before transposition. We'll use its shape later.
// Any type will do here as we will check all transpose maps are the same.
@@ -498,7 +497,7 @@ struct ReorderElementwiseOpsOnTranspose final
for (Value operand : op->getOperands()) {
auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/72616
More information about the Mlir-commits mailing list