[Mlir-commits] [mlir] [Vector] Added canonicalizer for folding from_elements + transpose (PR #161841)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 3 06:02:37 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Keshav Vinayak Jha (keshavvinayak01)
<details>
<summary>Changes</summary>
## Description
Adds a new canonicalizer that folds `vector.transpose(vector.from_elements(...))` => `vector.from_elements`. This canonicalization reorders the input elements of `vector.from_elements` and adjusts the output shape to match the effect of the transpose op, eliminating the need for the transpose.
## Testing
Added a 2D vector lit test that verifies the rewrite produces the expected permuted `vector.from_elements`.
---
Full diff: https://github.com/llvm/llvm-project/pull/161841.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+57-1)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..7f6313c11ea18 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6723,6 +6723,61 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(from_elements(...)) into a new from_elements with permuted
+/// operands matching the transposed shape.
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
+public:
+ using Base::Base;
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto fromElementsOp =
+ transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+ if (!fromElementsOp)
+ return failure();
+
+ VectorType srcTy = fromElementsOp.getDest().getType();
+ VectorType dstTy = transposeOp.getType();
+
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+ int64_t rank = srcTy.getRank();
+
+ // Build inverse permutation to map destination indices back to source.
+ SmallVector<int64_t, 4> inversePerm(rank, 0);
+ for (int64_t i = 0; i < rank; ++i)
+ inversePerm[permutation[i]] = i;
+
+ ArrayRef<int64_t> srcShape = srcTy.getShape();
+ ArrayRef<int64_t> dstShape = dstTy.getShape();
+ SmallVector<int64_t, 4> srcIdx(rank, 0);
+ SmallVector<int64_t, 4> dstIdx(rank, 0);
+ SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape);
+ SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape);
+
+ auto elements = fromElementsOp.getElements();
+ SmallVector<Value> newElements;
+ int64_t dstNumElements = dstTy.getNumElements();
+ newElements.reserve(dstNumElements);
+
+ // For each element in destination row-major order, pick the corresponding
+ // source element.
+ for (int64_t lin = 0; lin < dstNumElements; ++lin) {
+ // Pick the destination element index.
+ dstIdx = delinearize(lin, dstStrides);
+ // Map the destination element index to the source element index.
+ for (int64_t j = 0; j < rank; ++j)
+ srcIdx[j] = dstIdx[inversePerm[j]];
+ // Linearize the source element index.
+ int64_t srcLin = linearize(srcIdx, srcStrides);
+ // Add the source element to the new elements.
+ newElements.push_back(elements[srcLin]);
+ }
+
+ rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
+ newElements);
+ return success();
+ }
+};
+
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6823,7 +6878,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
- FoldTransposeSplat, FoldTransposeBroadcast>(context);
+ FoldTransposeSplat, FoldTransposeFromElements,
+ FoldTransposeBroadcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5448976f84760..5f34d144cd472 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
// -----
+// CHECK-LABEL: transpose_from_elements_2d
+func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32,
+ %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> {
+ %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32>
+ %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+ return %t : vector<3x2xi32>
+ // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32>
+ // CHECK-NOT: vector.transpose
+}
+
+// -----
+
func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%0 = vector.constant_mask [2, 2] : vector<4x3xi1>
%1 = vector.extract_strided_slice %0
``````````
</details>
https://github.com/llvm/llvm-project/pull/161841
More information about the Mlir-commits
mailing list