[Mlir-commits] [mlir] [mlir][vector]Enable DropUnitDimFromTransposeOp (PR #93007)
Han-Chung Wang
llvmlistbot at llvm.org
Thu May 30 15:03:14 PDT 2024
================
@@ -1695,6 +1695,77 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// Removes unit dimensions from a transpose op. Generates a vector.shape_cast
+/// on the operand and result to match types.
+///
+/// Ex:
+/// ```
+/// %tr = vector.transpose %arg0, [3, 1, 2, 0] : vector<1x4x1x2xf32> to
+/// vector<2x4x1x1xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+/// %sc0 = vector.shape_cast %arg0 : vector<1x4x1x2xf32> to vector<4x2xf32>
+/// %tr = vector.transpose %sc0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+/// %sc1 = vector.shape_cast %tr : vector<2x4xf32> to vector<2x4x1x1xf32>
+/// ```
+struct DropUnitDimFromTransposeOp final
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceVectorType = transposeOp.getSourceVectorType();
+ if (sourceVectorType.getRank() < 2)
+ return failure();
+
+ VectorType newVType = sourceVectorType;
+ SmallVector<int64_t> newPerm =
+ llvm::to_vector(transposeOp.getPermutation());
+ unsigned removedDims = 0;
+ auto shape = sourceVectorType.getShape();
+ for (const auto &dim : llvm::enumerate(shape)) {
+ if (dim.value() == 1 &&
+ !sourceVectorType.getScalableDims()[dim.index()]) {
+ newVType =
+ VectorType::Builder(newVType).dropDim(dim.index() - removedDims);
+ for (unsigned permutationIdx = 0; permutationIdx < newPerm.size();
+ ++permutationIdx) {
+ // Erase from permutation map the dropped unary dimension.
+ if ((unsigned)newPerm[permutationIdx] == dim.index() - removedDims) {
+ newPerm.erase(newPerm.begin() + permutationIdx);
----------------
hanhanW wrote:
Same here, this is inefficient (i.e., O(N^3)).
https://github.com/llvm/llvm-project/pull/93007
More information about the Mlir-commits
mailing list