[Mlir-commits] [mlir] [MLIR] [Vector] Added canonicalizer for folding from_elements + transpose (PR #161841)

Keshav Vinayak Jha llvmlistbot at llvm.org
Tue Oct 21 01:20:02 PDT 2025


https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/161841

>From c94bbb7846d0885a63d01b07ed7c8e362fd49689 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Fri, 3 Oct 2025 05:52:21 -0700
Subject: [PATCH 1/6] Added canonicalization (vector.from_elements +
 vector.transpose -> vector.from_elements)

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 61 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 12 +++++
 2 files changed, 72 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..31246f5da49b1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2499,6 +2499,7 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
   return DenseElementsAttr::get(destVecType, convertedElements);
 }
 
+
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
   if (auto res = foldFromElementsToElements(*this))
     return res;
@@ -6723,6 +6724,63 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(from_elements(...)) into a new from_elements with permuted
+/// operands matching the transposed shape.
+class FoldTransposeFromElements final
+    : public OpRewritePattern<TransposeOp> {
+public:
+
+using Base::Base;
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto fromElementsOp =
+        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+    if (!fromElementsOp)
+      return failure();
+
+    VectorType srcTy = fromElementsOp.getDest().getType();
+    VectorType dstTy = transposeOp.getType();
+
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    int64_t rank = srcTy.getRank();
+
+    // Build inverse permutation to map destination indices back to source.
+    SmallVector<int64_t, 4> inversePerm(rank, 0);
+    for (int64_t i = 0; i < rank; ++i)
+      inversePerm[permutation[i]] = i;
+
+    ArrayRef<int64_t> srcShape = srcTy.getShape();
+    ArrayRef<int64_t> dstShape = dstTy.getShape();
+    SmallVector<int64_t, 4> srcIdx(rank, 0);
+    SmallVector<int64_t, 4> dstIdx(rank, 0);
+    SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape);
+    SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape);
+
+    auto elements = fromElementsOp.getElements();
+    SmallVector<Value> newElements;
+    int64_t dstNumElements = dstTy.getNumElements();
+    newElements.reserve(dstNumElements);
+
+    // For each element in destination row-major order, pick the corresponding
+    // source element.
+    for (int64_t lin = 0; lin < dstNumElements; ++lin) {
+      // Pick the destination element index.
+      dstIdx = delinearize(lin, dstStrides);
+      // Map the destination element index to the source element index.
+      for (int64_t j = 0; j < rank; ++j)
+        srcIdx[j] = dstIdx[inversePerm[j]];
+      // Linearize the source element index.
+      int64_t srcLin = linearize(srcIdx, srcStrides);
+      // Add the source element to the new elements.
+      newElements.push_back(elements[srcLin]);
+    }
+
+    rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
+                                                        newElements);
+    return success();
+  }
+};
+
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6823,7 +6881,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+              FoldTransposeSplat, FoldTransposeFromElements,
+              FoldTransposeBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5448976f84760..5f34d144cd472 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
 
 // -----
 
+// CHECK-LABEL: transpose_from_elements_2d
+func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32,
+                                      %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> {
+  %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32>
+  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32>
+  // CHECK-NOT: vector.transpose
+}
+
+// -----
+
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0

>From 6bef6d259f8abf82d48092eae1404d6a2ebbfac7 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Fri, 3 Oct 2025 05:54:15 -0700
Subject: [PATCH 2/6] Formatted

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 31246f5da49b1..7f6313c11ea18 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2499,7 +2499,6 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
   return DenseElementsAttr::get(destVecType, convertedElements);
 }
 
-
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
   if (auto res = foldFromElementsToElements(*this))
     return res;
@@ -6726,11 +6725,9 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
 
 /// Folds transpose(from_elements(...)) into a new from_elements with permuted
 /// operands matching the transposed shape.
-class FoldTransposeFromElements final
-    : public OpRewritePattern<TransposeOp> {
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
 public:
-
-using Base::Base;
+  using Base::Base;
   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
     auto fromElementsOp =
@@ -6776,7 +6773,7 @@ using Base::Base;
     }
 
     rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
-                                                        newElements);
+                                                newElements);
     return success();
   }
 };

>From 70d3d8f7ef66a595f8d4072af9cbfbafd2fe33eb Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Thu, 16 Oct 2025 03:58:20 -0700
Subject: [PATCH 3/6] Addressed comments: 1. Minor nitpicks in code formatting.
 2. More lit tests, covering 1D, 2D, 3D cases.

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 20 ++++++-------
 mlir/test/Dialect/Vector/canonicalize.mlir | 33 ++++++++++++++++++----
 2 files changed, 38 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7f6313c11ea18..75e3a79b22aa9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6742,21 +6742,21 @@ class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
     int64_t rank = srcTy.getRank();
 
     // Build inverse permutation to map destination indices back to source.
-    SmallVector<int64_t, 4> inversePerm(rank, 0);
+    SmallVector<int64_t> inversePerm(rank, 0);
     for (int64_t i = 0; i < rank; ++i)
       inversePerm[permutation[i]] = i;
 
     ArrayRef<int64_t> srcShape = srcTy.getShape();
     ArrayRef<int64_t> dstShape = dstTy.getShape();
-    SmallVector<int64_t, 4> srcIdx(rank, 0);
-    SmallVector<int64_t, 4> dstIdx(rank, 0);
-    SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape);
-    SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape);
+    SmallVector<int64_t> srcIdx(rank, 0);
+    SmallVector<int64_t> dstIdx(rank, 0);
+    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
+    SmallVector<int64_t> dstStrides = computeStrides(dstShape);
 
-    auto elements = fromElementsOp.getElements();
-    SmallVector<Value> newElements;
+    auto elementsOld = fromElementsOp.getElements();
+    SmallVector<Value> elementsNew;
     int64_t dstNumElements = dstTy.getNumElements();
-    newElements.reserve(dstNumElements);
+    elementsNew.reserve(dstNumElements);
 
     // For each element in destination row-major order, pick the corresponding
     // source element.
@@ -6769,11 +6769,11 @@ class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
       // Linearize the source element index.
       int64_t srcLin = linearize(srcIdx, srcStrides);
       // Add the source element to the new elements.
-      newElements.push_back(elements[srcLin]);
+      elementsNew.push_back(elementsOld[srcLin]);
     }
 
     rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
-                                                newElements);
+                                                elementsNew);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5f34d144cd472..d3b92ffb8cc88 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -308,16 +308,39 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
 
 // -----
 
+// CHECK-LABEL: transpose_from_elements_1d
+func.func @transpose_from_elements_1d(%arg0: i32, %arg1: i32) -> vector<2xi32> {
+  %v = vector.from_elements %arg0, %arg1 : vector<2xi32>
+  %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
+  return %t : vector<2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg1 : vector<2xi32>
+  // CHECK-NOT: vector.transpose
+}
+
 // CHECK-LABEL: transpose_from_elements_2d
-func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32,
-                                      %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> {
-  %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32>
-  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
-  return %t : vector<3x2xi32>
+func.func @transpose_from_elements_2d(
+  %arg0: i32, %arg1: i32, %arg2: i32,
+  %arg3: i32, %arg4: i32, %arg5: i32
+) -> vector<3x2xi32> {
+  %arg6 = vector.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : vector<2x3xi32>
+  %arg7 = vector.transpose %arg6, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %arg7 : vector<3x2xi32>
   // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32>
   // CHECK-NOT: vector.transpose
 }
 
+// CHECK-LABEL: transpose_from_elements_3d
+func.func @transpose_from_elements_3d(
+  %arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32,
+  %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32
+) -> vector<2x2x3xi32> {
+  %arg12 = vector.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11 : vector<2x3x2xi32>
+  %arg13 = vector.transpose %arg12, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32>
+  return %arg13 : vector<2x2x3xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg2, %arg4, %arg1, %arg3, %arg5, %arg6, %arg8, %arg10, %arg7, %arg9, %arg11 : vector<2x2x3xi32>
+  // CHECK-NOT: vector.transpose
+}
+
 // -----
 
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {

>From 617267b40f96cd2f21064a1c56125a5afb7e5217 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Thu, 16 Oct 2025 08:06:55 -0700
Subject: [PATCH 4/6] Explainable arg names in lit test

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/test/Dialect/Vector/canonicalize.mlir | 37 +++++++++++++---------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d3b92ffb8cc88..e51eeb9fabbb8 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -309,35 +309,42 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
 // -----
 
 // CHECK-LABEL: transpose_from_elements_1d
-func.func @transpose_from_elements_1d(%arg0: i32, %arg1: i32) -> vector<2xi32> {
-  %v = vector.from_elements %arg0, %arg1 : vector<2xi32>
+func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
+  %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
   %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
   return %t : vector<2xi32>
-  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg1 : vector<2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0:.*]], %[[EL_1:.*]] : vector<2xi32>
   // CHECK-NOT: vector.transpose
 }
 
 // CHECK-LABEL: transpose_from_elements_2d
 func.func @transpose_from_elements_2d(
-  %arg0: i32, %arg1: i32, %arg2: i32,
-  %arg3: i32, %arg4: i32, %arg5: i32
+  %el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
+  %el_1_0: i32, %el_1_1: i32, %el_1_2: i32
 ) -> vector<3x2xi32> {
-  %arg6 = vector.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : vector<2x3xi32>
-  %arg7 = vector.transpose %arg6, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
-  return %arg7 : vector<3x2xi32>
-  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32>
+  %v = vector.from_elements %el_0_0, %el_0_1, %el_0_2, %el_1_0, %el_1_1, %el_1_2 : vector<2x3xi32>
+  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0:.*]], %[[EL_1_0:.*]], %[[EL_0_1:.*]], %[[EL_1_1:.*]], %[[EL_0_2:.*]], %[[EL_1_2:.*]] : vector<3x2xi32>
   // CHECK-NOT: vector.transpose
 }
 
 // CHECK-LABEL: transpose_from_elements_3d
 func.func @transpose_from_elements_3d(
-  %arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32,
-  %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32
+  %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
+  %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
 ) -> vector<2x2x3xi32> {
-  %arg12 = vector.from_elements %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11 : vector<2x3x2xi32>
-  %arg13 = vector.transpose %arg12, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32>
-  return %arg13 : vector<2x2x3xi32>
-  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg2, %arg4, %arg1, %arg3, %arg5, %arg6, %arg8, %arg10, %arg7, %arg9, %arg11 : vector<2x2x3xi32>
+  %v = vector.from_elements
+    %el_0_0_0, %el_0_0_1,
+    %el_0_1_0, %el_0_1_1,
+    %el_0_2_0, %el_0_2_1,
+    %el_1_0_0, %el_1_0_1,
+    %el_1_1_0, %el_1_1_1,
+    %el_1_2_0, %el_1_2_1
+    : vector<2x3x2xi32>
+  %t = vector.transpose %v, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32>
+  return %t : vector<2x2x3xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0_0:.*]], %[[EL_0_1_0:.*]], %[[EL_0_2_0:.*]], %[[EL_0_0_1:.*]], %[[EL_0_1_1:.*]], %[[EL_0_2_1:.*]], %[[EL_1_0_0:.*]], %[[EL_1_1_0:.*]], %[[EL_1_2_0:.*]], %[[EL_1_0_1:.*]], %[[EL_1_1_1:.*]], %[[EL_1_2_1:.*]] : vector<2x2x3xi32>
   // CHECK-NOT: vector.transpose
 }
 

>From 2889f3d6795f562cd611b5f351ae4e8abc02c0fb Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Fri, 17 Oct 2025 06:11:59 -0700
Subject: [PATCH 5/6] Addressed Comments: 1. Changed variable name of linearIdx
 iterator. 2. Moved canonicalizer lit tests to other vector.from_elements
 tests. 3. Added block comments marking the beginning, end, and name of the
 pattern.

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  4 +-
 mlir/test/Dialect/Vector/canonicalize.mlir | 98 ++++++++++++----------
 2 files changed, 58 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 75e3a79b22aa9..7c588a435aa1a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6760,9 +6760,9 @@ class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
 
     // For each element in destination row-major order, pick the corresponding
     // source element.
-    for (int64_t lin = 0; lin < dstNumElements; ++lin) {
+    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
       // Pick the destination element index.
-      dstIdx = delinearize(lin, dstStrides);
+      dstIdx = delinearize(linearIdx, dstStrides);
       // Map the destination element index to the source element index.
       for (int64_t j = 0; j < rank; ++j)
         srcIdx[j] = dstIdx[inversePerm[j]];
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e51eeb9fabbb8..d5ae12f159a88 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -308,48 +308,6 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
 
 // -----
 
-// CHECK-LABEL: transpose_from_elements_1d
-func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
-  %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
-  %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
-  return %t : vector<2xi32>
-  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0:.*]], %[[EL_1:.*]] : vector<2xi32>
-  // CHECK-NOT: vector.transpose
-}
-
-// CHECK-LABEL: transpose_from_elements_2d
-func.func @transpose_from_elements_2d(
-  %el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
-  %el_1_0: i32, %el_1_1: i32, %el_1_2: i32
-) -> vector<3x2xi32> {
-  %v = vector.from_elements %el_0_0, %el_0_1, %el_0_2, %el_1_0, %el_1_1, %el_1_2 : vector<2x3xi32>
-  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
-  return %t : vector<3x2xi32>
-  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0:.*]], %[[EL_1_0:.*]], %[[EL_0_1:.*]], %[[EL_1_1:.*]], %[[EL_0_2:.*]], %[[EL_1_2:.*]] : vector<3x2xi32>
-  // CHECK-NOT: vector.transpose
-}
-
-// CHECK-LABEL: transpose_from_elements_3d
-func.func @transpose_from_elements_3d(
-  %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
-  %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
-) -> vector<2x2x3xi32> {
-  %v = vector.from_elements
-    %el_0_0_0, %el_0_0_1,
-    %el_0_1_0, %el_0_1_1,
-    %el_0_2_0, %el_0_2_1,
-    %el_1_0_0, %el_1_0_1,
-    %el_1_1_0, %el_1_1_1,
-    %el_1_2_0, %el_1_2_1
-    : vector<2x3x2xi32>
-  %t = vector.transpose %v, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32>
-  return %t : vector<2x2x3xi32>
-  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0_0:.*]], %[[EL_0_1_0:.*]], %[[EL_0_2_0:.*]], %[[EL_0_0_1:.*]], %[[EL_0_1_1:.*]], %[[EL_0_2_1:.*]], %[[EL_1_0_0:.*]], %[[EL_1_1_0:.*]], %[[EL_1_2_0:.*]], %[[EL_1_0_1:.*]], %[[EL_1_1_1:.*]], %[[EL_1_2_1:.*]] : vector<2x2x3xi32>
-  // CHECK-NOT: vector.transpose
-}
-
-// -----
-
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0
@@ -3527,6 +3485,62 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
 
 // -----
 
+// +---------------------------------------------------------------------------
+// Tests for FoldTransposeFromElements
+// +---------------------------------------------------------------------------
+
+// CHECK-LABEL: func @transpose_from_elements_1d(
+// CHECK-SAME:  %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
+func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
+  %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
+  %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
+  return %t : vector<2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0]], %[[EL_1]] : vector<2xi32>
+  // CHECK-NOT: vector.transpose
+  // CHECK: return %[[R]]
+}
+
+// CHECK-LABEL: func @transpose_from_elements_2d(
+// CHECK-SAME:  %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
+func.func @transpose_from_elements_2d(
+  %el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
+  %el_1_0: i32, %el_1_1: i32, %el_1_2: i32
+) -> vector<3x2xi32> {
+  %v = vector.from_elements %el_0_0, %el_0_1, %el_0_2, %el_1_0, %el_1_1, %el_1_2 : vector<2x3xi32>
+  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0]], %[[EL_1_0]], %[[EL_0_1]], %[[EL_1_1]], %[[EL_0_2]], %[[EL_1_2]] : vector<3x2xi32>
+  // CHECK-NOT: vector.transpose
+  // CHECK: return %[[R]]
+}
+
+// CHECK-LABEL: func @transpose_from_elements_3d(
+// CHECK-SAME:  %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
+func.func @transpose_from_elements_3d(
+  %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
+  %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
+) -> vector<2x2x3xi32> {
+  %v = vector.from_elements
+    %el_0_0_0, %el_0_0_1,
+    %el_0_1_0, %el_0_1_1,
+    %el_0_2_0, %el_0_2_1,
+    %el_1_0_0, %el_1_0_1,
+    %el_1_1_0, %el_1_1_1,
+    %el_1_2_0, %el_1_2_1
+    : vector<2x3x2xi32>
+  %t = vector.transpose %v, [0, 2, 1] : vector<2x3x2xi32> to vector<2x2x3xi32>
+  return %t : vector<2x2x3xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %[[EL_0_0_0]], %[[EL_0_1_0]], %[[EL_0_2_0]], %[[EL_0_0_1]], %[[EL_0_1_1]], %[[EL_0_2_1]], %[[EL_1_0_0]], %[[EL_1_1_0]], %[[EL_1_2_0]], %[[EL_1_0_1]], %[[EL_1_1_1]], %[[EL_1_2_1]] : vector<2x2x3xi32>
+  // CHECK-NOT: vector.transpose
+  // CHECK: return %[[R]]
+}
+
+// +---------------------------------------------------------------------------
+// End of tests for FoldTransposeFromElements
+// +---------------------------------------------------------------------------
+
+// -----
+
 // Not a DenseElementsAttr, don't fold.
 
 // CHECK-LABEL: func @negative_insert_llvm_undef(

>From 08a08023b3194252ede656b7539b7037fee4b973 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 21 Oct 2025 01:18:49 -0700
Subject: [PATCH 6/6] Added example for folder

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7c588a435aa1a..535192b4e10ad 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6725,6 +6725,18 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
 
 /// Folds transpose(from_elements(...)) into a new from_elements with permuted
 /// operands matching the transposed shape.
+///
+/// Example:
+///
+///   %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12
+///          : vector<2x3xi32>
+///   %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+///
+/// becomes ->
+///
+///   %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12
+///          : vector<3x2xi32>
+///
 class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
 public:
   using Base::Base;



More information about the Mlir-commits mailing list