[Mlir-commits] [mlir] [MLIR] [Vector] Added canonicalizer for folding from_elements + transpose (PR #161841)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Oct 16 06:49:48 PDT 2025
================
@@ -6723,6 +6723,61 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(from_elements(...)) into a new from_elements with permuted
+/// operands matching the transposed shape.
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
+public:
+ using Base::Base;
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto fromElementsOp =
+ transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+ if (!fromElementsOp)
+ return failure();
+
+ VectorType srcTy = fromElementsOp.getDest().getType();
+ VectorType dstTy = transposeOp.getType();
+
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+ int64_t rank = srcTy.getRank();
+
+ // Build inverse permutation to map destination indices back to source.
+ SmallVector<int64_t, 4> inversePerm(rank, 0);
+ for (int64_t i = 0; i < rank; ++i)
+ inversePerm[permutation[i]] = i;
+
+ ArrayRef<int64_t> srcShape = srcTy.getShape();
+ ArrayRef<int64_t> dstShape = dstTy.getShape();
+ SmallVector<int64_t, 4> srcIdx(rank, 0);
+ SmallVector<int64_t, 4> dstIdx(rank, 0);
+ SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape);
+ SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape);
+
+ auto elements = fromElementsOp.getElements();
+ SmallVector<Value> newElements;
+ int64_t dstNumElements = dstTy.getNumElements();
+ newElements.reserve(dstNumElements);
+
+ // For each element in destination row-major order, pick the corresponding
+ // source element.
+ for (int64_t lin = 0; lin < dstNumElements; ++lin) {
----------------
banach-space wrote:
I find `lin` a bit too enigmatic. Why not `linearIdx`?
https://github.com/llvm/llvm-project/pull/161841
More information about the Mlir-commits
mailing list