[Mlir-commits] [mlir] [mlir][vector] Modernize `vector.transpose` op (PR #72594)
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 20 02:20:04 PST 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/72594
From a0e42e7d0f2a457d158bccb9aed27e14d9dcd4a4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 20 Nov 2023 19:17:55 +0900
Subject: [PATCH] [mlir][vector] Modernize `vector.transpose` op
* Declare arguments/results with `let` statements.
* Rename `transp` to `permutation`.
* Change the type of the `permutation` attribute (formerly `transp`) from `I64ArrayAttr` to `DenseI64ArrayAttr`.
BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 17 +++--
mlir/include/mlir/IR/AffineMap.h | 2 +
.../VectorToArmSME/VectorToArmSME.cpp | 7 +-
.../Conversion/VectorToGPU/VectorToGPU.cpp | 7 +-
.../Dialect/Arith/Transforms/IntNarrowing.cpp | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 68 +++++++------------
.../Transforms/LowerVectorTranspose.cpp | 4 +-
.../Vector/Transforms/VectorTransforms.cpp | 20 +++---
.../Vector/Transforms/VectorUnroll.cpp | 3 +-
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 7 +-
mlir/lib/IR/AffineMap.cpp | 6 ++
11 files changed, 60 insertions(+), 83 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e80a6446dfb24..1397d4caf1d9d61 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2436,14 +2436,13 @@ def Vector_TransposeOp :
Vector_Op<"transpose", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
PredOpTrait<"operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
- Results<(outs AnyVectorOfAnyRank:$result)> {
+ TCresVTEtIsSameAsOpBase<0, 0>>]> {
let summary = "vector transpose operation";
let description = [{
Takes a n-D vector and returns the transposed n-D vector defined by
the permutation of ranks in the n-sized integer array attribute (in case
of 0-D vectors the array attribute must be empty).
+
In the operation
```mlir
@@ -2452,7 +2451,7 @@ def Vector_TransposeOp :
to vector<d_trans[0] x .. x d_trans[n-1] x f32>
```
- the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+ the `permutation` array [i_1, .., i_n] must be a permutation of [0, .., n-1].
Example:
@@ -2464,8 +2463,13 @@ def Vector_TransposeOp :
[c, f] ]
```
}];
+
+ let arguments = (ins AnyVectorOfAnyRank:$vector,
+ DenseI64ArrayAttr:$permutation);
+ let results = (outs AnyVectorOfAnyRank:$result);
+
let builders = [
- OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
+ OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
@@ -2474,10 +2478,9 @@ def Vector_TransposeOp :
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
- void getTransp(SmallVectorImpl<int64_t> &results);
}];
let assemblyFormat = [{
- $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+ $vector `,` $permutation attr-dict `:` type($vector) `to` type($result)
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 0e4a8d363946432..640a52343307bda 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -103,6 +103,8 @@ class AffineMap {
/// (i.e. `[1,1,2]` is an invalid permutation).
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
+ static AffineMap getPermutationMap(ArrayRef<int64_t> permutation,
+ MLIRContext *context);
/// Returns an affine map with `numDims` input dimensions and results
/// specified by `targets`.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..01c782676068d9a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -436,12 +436,9 @@ struct TransposeOpToArmSMELowering
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
- SmallVector<int64_t> transp;
- for (auto attr : transposeOp.getTransp())
- transp.push_back(cast<IntegerAttr>(attr).getInt());
-
// Bail unless this is a true 2-D matrix transpose.
- if (transp[0] != 1 || transp[1] != 0)
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+ if (permutation[0] != 1 || permutation[1] != 0)
return failure();
OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1126c2c20758c7a..429d1137b6f3781 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -473,13 +473,8 @@ struct CombineTransferReadOpTranspose final
if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
- SmallVector<int64_t, 2> perm;
- op.getTransp(perm);
- SmallVector<unsigned, 2> permU;
- for (int64_t o : perm)
- permU.push_back(unsigned(o));
AffineMap permutationMap =
- AffineMap::getPermutationMap(permU, op.getContext());
+ AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
AffineMap newMap =
permutationMap.compose(transferReadOp.getPermutationMap());
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 1084fbc890053b9..79fabd6ed2e99a2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -712,7 +712,7 @@ struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
VectorType newTy =
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
Value newTranspose = rewriter.create<vector::TransposeOp>(
- op.getLoc(), newTy, ext->getIn(), op.getTransp());
+ op.getLoc(), newTy, ext->getIn(), op.getPermutation());
ext->recreateAndReplace(rewriter, op, newTranspose);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 432c11e3c449e0e..6793f902a1e59e5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1456,9 +1456,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
if (!nextTransposeOp)
return failure();
- auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
- AffineMap m = inversePermutation(
- AffineMap::getPermutationMap(permutation, extractOp.getContext()));
+ AffineMap m = inversePermutation(AffineMap::getPermutationMap(
+ nextTransposeOp.getPermutation(), extractOp.getContext()));
extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
return success();
}
@@ -5376,20 +5375,20 @@ LogicalResult TypeCastOp::verify() {
//===----------------------------------------------------------------------===//
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
- Value vector, ArrayRef<int64_t> transp) {
+ Value vector, ArrayRef<int64_t> permutation) {
VectorType vt = llvm::cast<VectorType>(vector.getType());
SmallVector<int64_t, 4> transposedShape(vt.getRank());
SmallVector<bool, 4> transposedScalableDims(vt.getRank());
- for (unsigned i = 0; i < transp.size(); ++i) {
- transposedShape[i] = vt.getShape()[transp[i]];
- transposedScalableDims[i] = vt.getScalableDims()[transp[i]];
+ for (unsigned i = 0; i < permutation.size(); ++i) {
+ transposedShape[i] = vt.getShape()[permutation[i]];
+ transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
}
result.addOperands(vector);
result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
transposedScalableDims));
- result.addAttribute(TransposeOp::getTranspAttrName(result.name),
- builder.getI64ArrayAttr(transp));
+ result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
+ builder.getDenseI64ArrayAttr(permutation));
}
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
@@ -5401,13 +5400,12 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate identity transpose ops. This happens when the dimensions of the
// input vector remain in their original order after the transpose operation.
- SmallVector<int64_t, 4> transp;
- getTransp(transp);
+ ArrayRef<int64_t> perm = getPermutation();
// Check if the permutation of the dimensions contains sequential values:
// {0, 1, 2, ...}.
- for (int64_t i = 0, e = transp.size(); i < e; i++) {
- if (transp[i] != i)
+ for (int64_t i = 0, e = perm.size(); i < e; i++) {
+ if (perm[i] != i)
return {};
}
@@ -5421,20 +5419,19 @@ LogicalResult vector::TransposeOp::verify() {
if (vectorType.getRank() != rank)
return emitOpError("vector result rank mismatch: ") << rank;
// Verify transposition array.
- auto transpAttr = getTransp().getValue();
- int64_t size = transpAttr.size();
+ ArrayRef<int64_t> perm = getPermutation();
+ int64_t size = perm.size();
if (rank != size)
return emitOpError("transposition length mismatch: ") << size;
SmallVector<bool, 8> seen(rank, false);
- for (const auto &ta : llvm::enumerate(transpAttr)) {
- int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
- if (i < 0 || i >= rank)
- return emitOpError("transposition index out of range: ") << i;
- if (seen[i])
- return emitOpError("duplicate position index: ") << i;
- seen[i] = true;
- if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
- return emitOpError("dimension size mismatch at: ") << i;
+ for (const auto &ta : llvm::enumerate(perm)) {
+ if (ta.value() < 0 || ta.value() >= rank)
+ return emitOpError("transposition index out of range: ") << ta.value();
+ if (seen[ta.value()])
+ return emitOpError("duplicate position index: ") << ta.value();
+ seen[ta.value()] = true;
+ if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
+ return emitOpError("dimension size mismatch at: ") << ta.value();
}
return success();
}
@@ -5452,13 +5449,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
- // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
- auto getPermutation = [](vector::TransposeOp transpose) {
- SmallVector<int64_t, 4> permutation;
- transpose.getTransp(permutation);
- return permutation;
- };
-
// Composes two permutations: result[i] = permutation1[permutation2[i]].
auto composePermutations = [](ArrayRef<int64_t> permutation1,
ArrayRef<int64_t> permutation2) {
@@ -5475,12 +5465,11 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
return failure();
SmallVector<int64_t, 4> permutation = composePermutations(
- getPermutation(parentTransposeOp), getPermutation(transposeOp));
+ parentTransposeOp.getPermutation(), transposeOp.getPermutation());
// Replace 'transposeOp' with a new transpose operation.
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
- parentTransposeOp.getVector(),
- vector::getVectorSubscriptAttr(rewriter, permutation));
+ parentTransposeOp.getVector(), permutation);
return success();
}
};
@@ -5539,8 +5528,7 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
// Get the transpose permutation and apply it to the vector.create_mask or
// vector.constant_mask operands.
- SmallVector<int64_t> permutation;
- transpOp.getTransp(permutation);
+ ArrayRef<int64_t> permutation = transpOp.getPermutation();
if (createMaskOp) {
auto maskOperands = createMaskOp.getOperands();
@@ -5583,9 +5571,7 @@ class FoldTransposeWithNonScalableUnitDimsToShapeCast final
PatternRewriter &rewriter) const override {
Value input = transpOp.getVector();
VectorType resType = transpOp.getResultVectorType();
-
- SmallVector<int64_t> permutation;
- transpOp.getTransp(permutation);
+ ArrayRef<int64_t> permutation = transpOp.getPermutation();
if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
@@ -5611,10 +5597,6 @@ void vector::TransposeOp::getCanonicalizationPatterns(
FoldTransposeWithNonScalableUnitDimsToShapeCast>(context);
}
-void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
- populateFromInt64AttrArray(getTransp(), results);
-}
-
//===----------------------------------------------------------------------===//
// ConstantMaskOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 25a53b31163432e..9475d273c116260 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -327,9 +327,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
VectorType resType = op.getResultVectorType();
// Set up convenience transposition table.
- SmallVector<int64_t> transp;
- for (auto attr : op.getTransp())
- transp.push_back(cast<IntegerAttr>(attr).getInt());
+ ArrayRef<int64_t> transp = op.getPermutation();
if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
succeeded(isTranspose2DSlice(op)))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 713f9cb72c82cec..a20c8aeeb6f7108 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -212,8 +212,7 @@ struct CombineContractABTranspose final
if (!transposeOp)
continue;
AffineMap permutationMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(transposeOp.getTransp()),
- contractOp.getContext());
+ transposeOp.getPermutation(), contractOp.getContext());
map = inversePermutation(permutationMap).compose(map);
*operand = transposeOp.getVector();
changed = true;
@@ -279,13 +278,13 @@ struct CombineContractResultTranspose final
// Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
// To index into A in contract, we need revert(f)(g(C)) -> A.
- auto accTMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(accTOp.getTransp()), context);
+ auto accTMap =
+ AffineMap::getPermutationMap(accTOp.getPermutation(), context);
// Contract performs g(C) -> D. Result transpose performs h(D) -> E.
// To index into E in contract, we need h(g(C)) -> E.
- auto resTMap = AffineMap::getPermutationMap(
- extractVector<unsigned>(resTOp.getTransp()), context);
+ auto resTMap =
+ AffineMap::getPermutationMap(resTOp.getPermutation(), context);
auto combinedResMap = resTMap.compose(contractMap);
// The accumulator and result share the same indexing map. So they should be
@@ -490,7 +489,7 @@ struct ReorderElementwiseOpsOnTranspose final
// Make sure all operands are transpose/constant ops and collect their
// transposition maps.
- SmallVector<ArrayAttr> transposeMaps;
+ SmallVector<ArrayRef<int64_t>> transposeMaps;
transposeMaps.reserve(op->getNumOperands());
// Record the initial type before transposition. We'll use its shape later.
// Any type will do here as we will check all transpose maps are the same.
@@ -498,7 +497,7 @@ struct ReorderElementwiseOpsOnTranspose final
for (Value operand : op->getOperands()) {
auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
if (transposeOp) {
- transposeMaps.push_back(transposeOp.getTransp());
+ transposeMaps.push_back(transposeOp.getPermutation());
srcType = transposeOp.getSourceVectorType();
} else if (!matchPattern(operand, m_Constant())) {
return failure();
@@ -517,7 +516,7 @@ struct ReorderElementwiseOpsOnTranspose final
// If there are constant operands, we need to insert inverse transposes for
// them. Calculate the inverse order first.
- auto order = extractVector<unsigned>(transposeMaps.front());
+ auto order = transposeMaps.front();
SmallVector<int64_t> invOrder(order.size());
for (int i = 0, e = order.size(); i < e; ++i)
invOrder[order[i]] = i;
@@ -532,8 +531,7 @@ struct ReorderElementwiseOpsOnTranspose final
srcType.getShape(),
cast<VectorType>(operand.getType()).getElementType());
srcValues.push_back(rewriter.create<vector::TransposeOp>(
- operand.getLoc(), vectorType, operand,
- rewriter.getI64ArrayAttr(invOrder)));
+ operand.getLoc(), vectorType, operand, invOrder));
}
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 4cfac7de29ee76f..78b041255443c30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -537,8 +537,7 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
// Prepare the result vector;
Value result = rewriter.create<arith::ConstantOp>(
loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
- SmallVector<int64_t> permutation;
- transposeOp.getTransp(permutation);
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
// Unroll the computation.
for (SmallVector<int64_t> elementOffsets :
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 467a521f9eada96..48cd67ad86c63fb 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -87,14 +87,11 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
if (srcGtOneDims.size() != 2)
return failure();
- SmallVector<int64_t> transp;
- for (auto attr : op.getTransp())
- transp.push_back(cast<IntegerAttr>(attr).getInt());
-
// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
// transpose pattens. Otherwise, these patterns are not applicable.
- if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
+ if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
+ op.getPermutation()))
return failure();
return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3b3391882d04791..c2804626635947e 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -254,6 +254,12 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
assert(permutationMap.isPermutation() && "Invalid permutation vector");
return permutationMap;
}
+AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation,
+ MLIRContext *context) {
+ SmallVector<unsigned> perm = llvm::map_to_vector(
+ permutation, [](int64_t i) { return static_cast<unsigned>(i); });
+ return AffineMap::getPermutationMap(perm, context);
+}
AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
ArrayRef<unsigned> targets,
More information about the Mlir-commits mailing list