[Mlir-commits] [mlir] 00092f9 - [MLIR] [Vector] Added canonicalizer for folding from_elements + transpose (#161841)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 21 01:40:02 PDT 2025


Author: Keshav Vinayak Jha
Date: 2025-10-21T14:09:57+05:30
New Revision: 00092f9bdd1d5037a5f4c8f3059e31e32aee6e8d

URL: https://github.com/llvm/llvm-project/commit/00092f9bdd1d5037a5f4c8f3059e31e32aee6e8d
DIFF: https://github.com/llvm/llvm-project/commit/00092f9bdd1d5037a5f4c8f3059e31e32aee6e8d.diff

LOG: [MLIR] [Vector] Added canonicalizer for folding from_elements + transpose (#161841)

## Description
Adds a new canonicalizer that folds
`vector.transpose(vector.from_elements)` => `vector.from_elements`.
This canonicalization reorders the input elements for
`vector.from_elements` and adjusts the output shape to match the effect of
the transpose op, which eliminates the need for the transpose.

## Testing
Added 1-D, 2-D, and 3-D vector lit tests that verify the rewrite.

---------

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 45c54c7587c69..ad8255a95cb4e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6835,6 +6835,73 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(from_elements(...)) into a single from_elements whose
+/// operands are listed in the row-major order of the transposed shape, which
+/// makes the transpose unnecessary.
+///
+/// Example:
+///
+///   %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12
+///       : vector<2x3xi32>
+///   %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+///
+/// becomes ->
+///
+///   %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12
+///       : vector<3x2xi32>
+///
+/// Works for any rank: each result position is delinearized, mapped through
+/// the permutation to a source position, and re-linearized to select the
+/// matching operand. Only the transpose is replaced; the original
+/// from_elements op is not modified by this pattern.
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
+public:
+  using Base::Base;
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // Only match when the transposed vector comes directly from from_elements.
+    auto fromElementsOp =
+        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+    if (!fromElementsOp)
+      return failure();
+
+    VectorType srcTy = fromElementsOp.getDest().getType();
+    VectorType dstTy = transposeOp.getType();
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    int64_t rank = srcTy.getRank();
+
+    // Row-major strides convert between a position in an operand list and a
+    // multi-dimensional index, for both the source and destination shapes.
+    SmallVector<int64_t> srcStrides = computeStrides(srcTy.getShape());
+    SmallVector<int64_t> dstStrides = computeStrides(dstTy.getShape());
+
+    auto elementsOld = fromElementsOp.getElements();
+    int64_t dstNumElements = dstTy.getNumElements();
+    SmallVector<Value> elementsNew;
+    elementsNew.reserve(dstNumElements);
+
+    // Scratch buffer for the source index, reused across iterations.
+    SmallVector<int64_t> srcIdx(rank, 0);
+
+    // Walk the destination in row-major order and, for each position, pick
+    // the source element that the transpose moves there.
+    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
+      SmallVector<int64_t> dstIdx = delinearize(linearIdx, dstStrides);
+      // Result dimension `i` is source dimension `permutation[i]`, so the
+      // destination coordinates scatter straight into the source index; no
+      // inverse permutation needs to be materialized.
+      for (int64_t i = 0; i < rank; ++i)
+        srcIdx[permutation[i]] = dstIdx[i];
+      int64_t srcLin = linearize(srcIdx, srcStrides);
+      elementsNew.push_back(elementsOld[srcLin]);
+    }
+
+    rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
+                                                elementsNew);
+    return success();
+  }
+};
+
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6935,7 +7002,9 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
+  // FoldTransposeFromElements removes transposes of from_elements results.
   results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+              FoldTransposeSplat, FoldTransposeFromElements,
+              FoldTransposeBroadcast>(context);
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 59774f92cac36..084f49fca212f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3530,6 +3530,62 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
 
 // -----
 
+// +---------------------------------------------------------------------------
+// Tests for FoldTransposeFromElements
+// +---------------------------------------------------------------------------
+
+// CHECK-LABEL: func @transpose_from_elements_1d(
+// CHECK-SAME:  %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
+func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
+  %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
+  %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
+  return %t : vector<2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0]], %[[EL_1]] : vector<2xi32>
+  // CHECK-NOT: vector.transpose
+  // CHECK: return %[[R]]
+}
+
+// Transposing 2x3 -> 3x2 must interleave the two source rows.
+// CHECK-LABEL: func @transpose_from_elements_2d(
+// CHECK-SAME:  %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
+func.func @transpose_from_elements_2d(
+  %el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
+  %el_1_0: i32, %el_1_1: i32, %el_1_2: i32) -> vector<3x2xi32> {
+  %v = vector.from_elements %el_0_0, %el_0_1, %el_0_2, %el_1_0, %el_1_1, %el_1_2 : vector<2x3xi32>
+  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0]], %[[EL_1_0]], %[[EL_0_1]], %[[EL_1_1]], %[[EL_0_2]], %[[EL_1_2]] : vector<3x2xi32>
+  // CHECK-NOT: vector.transpose
+  // CHECK: return %[[R]]
+}
+
+// [0, 2, 1] swaps the two inner dimensions; the outer one is unchanged.
+// CHECK-LABEL: func @transpose_from_elements_3d(
+// CHECK-SAME:  %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
+func.func @transpose_from_elements_3d(
+  %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
+  %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
+) -> vector<2x2x3xi32> {
+  %v = vector.from_elements
+    %el_0_0_0, %el_0_0_1,
+    %el_0_1_0, %el_0_1_1,
+    %el_0_2_0, %el_0_2_1,
+    %el_1_0_0, %el_1_0_1,
+    %el_1_1_0, %el_1_1_1,
+    %el_1_2_0, %el_1_2_1 : vector<2x3x2xi32>
+  %t = vector.transpose %v, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32>
+  return %t : vector<2x2x3xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0_0]], %[[EL_0_1_0]], %[[EL_0_2_0]], %[[EL_0_0_1]], %[[EL_0_1_1]], %[[EL_0_2_1]], %[[EL_1_0_0]], %[[EL_1_1_0]], %[[EL_1_2_0]], %[[EL_1_0_1]], %[[EL_1_1_1]], %[[EL_1_2_1]] : vector<2x2x3xi32>
+  // CHECK-NOT: vector.transpose
+  // CHECK: return %[[R]]
+}
+
+// +---------------------------------------------------------------------------
+// End of tests for FoldTransposeFromElements
+// +---------------------------------------------------------------------------
+
+// -----
+
 // Not a DenseElementsAttr, don't fold.
 
 // CHECK-LABEL: func @negative_insert_llvm_undef(


        


More information about the Mlir-commits mailing list