[Mlir-commits] [mlir] [mlir][vector] Adds ToElementsToTargetShape pattern. (PR #166476)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Nov 6 12:32:31 PST 2025


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/166476

>From fbbf0e4113818f7ace97e4804679d579f8144a27 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:22:58 -0500
Subject: [PATCH 01/11] [mlir][vector] Use getShapeForUnroll's default
 implementation.

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp         | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43172ff2082df..ccea764cfc579 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -361,7 +361,7 @@ def Vector_MultiDimReductionOp :
 
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [Pure,
-     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index daef0ba02100a..3e125e5c1f37b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2782,10 +2782,6 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
-std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getResultVectorType().getShape());
-}
-
 /// Return the dimensions of the result vector that were formerly ones in the
 /// source tensor and thus correspond to "dim-1" broadcasting.
 static llvm::SetVector<int64_t>

>From 1964d161457e71208189065fc3cf82f2341e26e7 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:33:14 -0500
Subject: [PATCH 02/11] [mlir][vector] Use getShapeForUnroll's default
 implementation.

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp         | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ccea764cfc579..1d3f70a9813f7 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2758,7 +2758,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
 def Vector_TransposeOp :
   Vector_Op<"transpose", [Pure,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
-    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]> {
   let summary = "vector transpose operation";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e125e5c1f37b..2d5580ec0ff81 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6716,10 +6716,6 @@ LogicalResult vector::TransposeOp::verify() {
   return success();
 }
 
-std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getResultVectorType().getShape());
-}
-
 void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                     SetIntRangeFn setResultRanges) {
   setResultRanges(getResult(), argRanges.front());

>From a0c6e4f90d38ab2609ebfce99fc1b28c623aeb11 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:39:13 -0500
Subject: [PATCH 03/11] [mlir][vector] Use getShapeForUnroll's default
 implementation.

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp         | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1d3f70a9813f7..fd6196a156d0f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2057,7 +2057,7 @@ def Vector_GatherOp :
   Vector_Op<"gather", [
     DeclareOpInterfaceMethods<MaskableOpInterface>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
-    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>
     DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
   ]>,
     Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2d5580ec0ff81..cac8defb4d078 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5982,10 +5982,6 @@ Type GatherOp::getExpectedMaskType() {
                          vecType.getScalableDims());
 }
 
-std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getVectorType().getShape());
-}
-
 /// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
 static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
   auto vecType = dyn_cast<VectorType>(indexVec.getType());

>From a6cbe0b42db5de0609455d3b1b575c006f6d3e4d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:43:37 -0500
Subject: [PATCH 04/11] [mlir][vector] Use getShapeForUnroll's default
 implementation.

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp         | 8 --------
 2 files changed, 1 insertion(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fd6196a156d0f..fa613a86ad793 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -732,7 +732,7 @@ def Vector_ExtractOp :
 def Vector_FMAOp :
   Op<Vector_Dialect, "fma", [
        Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
-       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+       DeclareOpInterfaceMethods<VectorUnrollOpInterface>
      ] # ElementwiseMappable.traits>,
     Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
                    VectorOfAnyRankOf<[AnyFloat]>:$rhs,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cac8defb4d078..b56e98dd6b595 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2374,14 +2374,6 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
     results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
 }
 
-//===----------------------------------------------------------------------===//
-// FmaOp
-//===----------------------------------------------------------------------===//
-
-std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getVectorType().getShape());
-}
-
 //===----------------------------------------------------------------------===//
 // ToElementsOp
 //===----------------------------------------------------------------------===//

>From cd648dac74e3d607e4bf13c3e8bc7c65b0d5c698 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:47:12 -0500
Subject: [PATCH 05/11] [mlir][vector] Use getShapeForUnroll's default
 implementation.

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp         | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fa613a86ad793..a85ea2e128e1f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1245,7 +1245,7 @@ def Vector_ExtractStridedSliceOp :
 def Vector_TransferReadOp :
   Vector_Op<"transfer_read", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
-      DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+      DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b56e98dd6b595..f126f8dd6c4dd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5088,10 +5088,6 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
-std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getVectorType().getShape());
-}
-
 void TransferReadOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {

>From 5103187a4f7b4676bc2125297a632b1d8419f9be Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 17:12:57 -0500
Subject: [PATCH 06/11] [mlir][vector] Use getShapeForUnroll's default
 implementation.

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp         | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index a85ea2e128e1f..acfa578a184b8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1653,7 +1653,7 @@ def Vector_TransferWriteOp :
 }
 
 def Vector_LoadOp : Vector_Op<"load", [
-    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
     DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
   ]> {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f126f8dd6c4dd..b030b060c6ba0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5762,10 +5762,6 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
-std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getVectorType().getShape());
-}
-
 FailureOr<std::optional<SmallVector<Value>>>
 LoadOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),

>From 71e53e7f294286f280b012367515f53a81b2cdb9 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 17:21:35 -0500
Subject: [PATCH 07/11] Fix documentation

---
 mlir/include/mlir/Interfaces/VectorInterfaces.td | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 6838c16fdf0fe..1223f5c0704ab 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -24,9 +24,8 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
   let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Return the shape ratio of unrolling to the target vector shape
-        `targetShape`. Return `std::nullopt` if the op cannot be unrolled to the
-        target vector shape.
+        Return the shape of this op's vector, used to decide unrolling factors.
+        Return `std::nullopt` if the op is not applicable for unrolling.
       }],
       /*retTy=*/"::std::optional<::llvm::SmallVector<int64_t, 4>>",
       /*methodName=*/"getShapeForUnroll",

>From 200773d78f4e57baf5d02b9531d97a289012399a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 6 Nov 2025 15:23:07 -0500
Subject: [PATCH 08/11] Fix rebase

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index acfa578a184b8..a1c5298629e58 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2057,7 +2057,7 @@ def Vector_GatherOp :
   Vector_Op<"gather", [
     DeclareOpInterfaceMethods<MaskableOpInterface>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
-    DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
     DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
   ]>,
     Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,

>From aa4906a085fe94bc31d88ff9d0ac12131434ccae Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 17:30:43 -0500
Subject: [PATCH 09/11] [mlir][vector] to_elements implements
 VectorUnrollOpInterface

---
 .../SPIRV/Transforms/SPIRVConversion.h        |  3 +
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  8 ++
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 11 ++-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  3 +
 .../Vector/Transforms/VectorUnroll.cpp        | 96 ++++++++++++++++++-
 5 files changed, 117 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 03ae54a8ae30a..f202c0ea88bd0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -198,6 +198,9 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
 // the target shape.
 int getComputeVectorSize(int64_t size);
 
+// GetNativeVectorShape implementation for to_elements ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::ToElementsOp op);
+
 // GetNativeVectorShape implementation for reduction ops.
 SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
 
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index a1c5298629e58..51e9a9b986315 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -762,6 +762,7 @@ def Vector_FMAOp :
 
 def Vector_ToElementsOp : Vector_Op<"to_elements", [
     InferTypeOpAdaptor, Pure,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
   let summary = "operation that decomposes a vector into all its scalar elements";
   let description = [{
@@ -808,6 +809,13 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
   let assemblyFormat = "$source attr-dict `:` type($source)";
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let extraClassDeclaration = [{
+
+    VectorType getSourceVectorType() {
+      return ::llvm::cast<VectorType>(getSource().getType());
+    }
+
+  }];
 }
 
 def Vector_FromElementsOp : Vector_Op<"from_elements", [
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index cb9b7f6ec2fd2..22097f5f2cdc6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1435,6 +1435,15 @@ int mlir::spirv::getComputeVectorSize(int64_t size) {
   return 1;
 }
 
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::ToElementsOp op) {
+  // to_elements accepts any rank; n-D sources report their own shape (no split).
+  VectorType srcVectorType = op.getSourceVectorType();
+  if (srcVectorType.getRank() != 1)
+    return llvm::to_vector(srcVectorType.getShape());
+  return {mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0))};
+}
+
 SmallVector<int64_t>
 mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
   VectorType srcVectorType = op.getSourceVectorType();
@@ -1465,7 +1474,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
   }
 
   return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
-      .Case<vector::ReductionOp, vector::TransposeOp>(
+      .Case<vector::ReductionOp, vector::TransposeOp, vector::ToElementsOp>(
           [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
       .Default(std::nullopt);
 }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b030b060c6ba0..4fe3b99f7fd6a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2377,6 +2377,9 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
 //===----------------------------------------------------------------------===//
 // ToElementsOp
 //===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t, 4>> ToElementsOp::getShapeForUnroll() {
+  return SmallVector<int64_t, 4>(getSourceVectorType().getShape());
+}
 
 /// Returns true if all the `operands` are defined by `defOp`.
 /// Otherwise, returns false.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..c49718e0902a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -834,11 +834,100 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Splits a 1-D `vector.to_elements` op whose source is larger than the
+/// native (target) shape into several `vector.to_elements` ops over
+/// `vector.extract_strided_slice`s of the source. For example, with a native
+/// shape of [4]:
+///
+/// ```
+/// // In SPIR-V's default environment, vectors of size 8
+/// // are not allowed.
+/// %elements:8 = vector.to_elements %v : vector<8xf32>
+///
+/// ===>
+///
+/// %v_0_to_3 = vector.extract_strided_slice %v {offsets = [0], sizes = [4],
+///   strides = [1]} : vector<8xf32> to vector<4xf32>
+/// %v_4_to_7 = vector.extract_strided_slice %v {offsets = [4], sizes = [4],
+///   strides = [1]} : vector<8xf32> to vector<4xf32>
+/// %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32>
+/// %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32>
+/// ```
+///
+/// Uses of `%elements#0..3` are replaced by `%elements_0#0..3` and uses of
+/// `%elements#4..7` by `%elements_1#0..3`.
+///
+/// The pattern fails to match unless the source is a fixed-size 1-D vector
+/// whose size is a proper multiple of the 1-D native size. Multi-dimensional
+/// sources are expected to be unrolled first by `UnrollToElements`.
+struct ToElementsToTargetShape final
+    : public OpRewritePattern<vector::ToElementsOp> {
+  ToElementsToTargetShape(MLIRContext *context,
+                          const vector::UnrollVectorOptions &options,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check the rank before querying the native shape: n-D sources are left
+    // to `UnrollToElements`, and native-shape callbacks may assume 1-D input.
+    VectorType srcType = op.getSourceVectorType();
+    if (srcType.getRank() != 1 || srcType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "expected a fixed-size 1-D source");
+
+    auto targetShape = getTargetShape(options, op);
+    if (!targetShape || targetShape->size() != 1)
+      return rewriter.notifyMatchFailure(op, "no 1-D target shape");
+
+    // We need srcSize == numSlices * targetSize with numSlices > 1. Check it
+    // explicitly: a size that is not a multiple of the target would drop
+    // trailing elements (and mismatch the result count in `replaceOp`), and a
+    // target that is not smaller than the source would either produce no
+    // slices or could rewrite the op into itself over and over.
+    int64_t srcSize = srcType.getDimSize(0);
+    int64_t targetSize = targetShape->front();
+    if (targetSize <= 0 || targetSize >= srcSize || srcSize % targetSize != 0)
+      return rewriter.notifyMatchFailure(
+          op, "source size is not a proper multiple of the target size");
+    int64_t numSlices = srcSize / targetSize;
+
+    // Slice the source into `numSlices` consecutive chunks of the target
+    // shape. All slices are created before any new to_elements op, so the
+    // emitted IR is ordered slices-then-decompositions.
+    Location loc = op.getLoc();
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    SmallVector<Value> subVectors;
+    subVectors.reserve(numSlices);
+    for (int64_t i = 0; i < numSlices; ++i) {
+      SmallVector<int64_t> offsets = {i * targetSize};
+      subVectors.push_back(rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, op.getSource(), offsets, *targetShape, strides));
+    }
+
+    // Decompose each chunk; the concatenation of their results, in slice
+    // order, replaces the original results one-to-one.
+    SmallVector<Value> elements;
+    elements.reserve(op.getNumResults());
+    for (Value subVector : subVectors) {
+      auto elementsOp = vector::ToElementsOp::create(rewriter, loc, subVector);
+      llvm::append_range(elements, elementsOp.getResults());
+    }
+
+    rewriter.replaceOp(op, elements);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 /// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
 /// outermost dimension of the operand. For example:
 ///
 /// ```
-/// %0:4 = vector.to_elements %v : vector<2x2xf32>
+/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
 ///
 /// ==>
 ///
@@ -865,6 +954,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
     FailureOr<SmallVector<Value>> result =
         vector::unrollVectorValue(source, rewriter);
     if (failed(result)) {
+      // Only fails if operand is 1-dimensional.
       return failure();
     }
     SmallVector<Value> vectors = *result;
@@ -1013,8 +1103,8 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
-               UnrollToElements, UnrollStepPattern>(patterns.getContext(),
-                                                    options, benefit);
+               UnrollToElements, UnrollStepPattern, ToElementsToTargetShape>(
+      patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(

>From 228d0b142b14f95ef2dae0030fd39f48f14584b4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 5 Nov 2025 15:14:09 -0500
Subject: [PATCH 10/11] [mlir] Test vector.to_elements to spirv conversion.

---
 .../ConvertToSPIRV/vector-sizes.mlir          | 67 +++++++++++++++++++
 1 file changed, 67 insertions(+)
 create mode 100644 mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir

diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir
new file mode 100644
index 0000000000000..402c539a77093
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -test-convert-to-spirv="run-signature-conversion=false run-vector-unrolling=true" -split-input-file %s | FileCheck %s
+
+// COM: This file tests the current behaviour of the SignatureConversion
+// COM: and the unrolling of vector.to_elements to vectors of valid SPIR-V
+// COM: sizes.
+
+// COM: Vectors of rank 1 and size 1 will be changed
+// COM: to scalars. Since vector.to_elements will also produce
+// COM: a scalar, we expect the vector.to_elements to be folded
+// COM: away. Note that even with run-signature-conversion=false,
+// COM: the FuncOpConversion pattern still runs and converts parameters
+// COM: that fit this constraint.
+
+// CHECK-LABEL: spirv.func @vec_size_1
+// CHECK-SAME: (%[[ARG0:.+]]: f32)
+func.func @vec_size_1(%arg0: vector<1xf32>) -> (f32) {
+  // CHECK-NEXT: spirv.ReturnValue %[[ARG0]] : f32
+  %0:1 = vector.to_elements %arg0 : vector<1xf32>
+  return %0#0 : f32
+}
+
+// -----
+
+// COM: Vectors of size 2, 3 and 4 are allowed by SPIR-V,
+// COM: so they remain unchanged. FuncOpConversion still runs,
+// COM: but the signature converter does not convert these vectors.
+
+// CHECK-LABEL: spirv.func @vec_size_2
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+func.func @vec_size_2(%arg0: vector<2xf32>) -> (f32) {
+  // COM: A single result type is enforced by the semantics
+
+  // CHECK-NEXT: %[[VAL:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+
+  // CHECK-NEXT: spirv.ReturnValue %[[VAL]]
+  return %0#0 : f32
+}
+
+// -----
+
+// COM: A vector of size 5 is the first one that doesn't fit
+// COM: into SPIR-V's vectors.
+
+// COM: run-signature-conversion=false means that
+// COM: this vector will not be unrolled.
+
+// CHECK-LABEL: func.func @vec_size_5
+// CHECK-SAME: (%[[ARG0:.+]]: vector<5xf32>)
+func.func @vec_size_5(%arg0: vector<5xf32>) -> (f32) {
+
+  // CHECK-NEXT: %[[VAL:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xf32> to vector<1xf32>
+
+  // COM: We have the following comment in VectorConvertToElementOp
+  // COM:
+  // COM:     // Input vectors of size 1 are converted to scalars by the type converter.
+  // COM:     // We cannot use `spirv::CompositeExtractOp` directly in this case.
+  // COM:     // For a scalar source, the result is just the scalar itself.
+  // COM:
+  // COM: Which in this case means an unrealized conversion cast.
+
+  // CHECK-NEXT: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<1xf32> to f32
+  %0:5 = vector.to_elements %arg0 : vector<5xf32>
+
+  // CHECK-NEXT: spirv.ReturnValue %[[RETVAL]] : f32
+  return %0#0 : f32
+}

>From 8fe386a4edfe8148e6bb57e7cbd84f2a82e02b78 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 5 Nov 2025 17:03:39 -0500
Subject: [PATCH 11/11] [mlir] Update unrollToElements tests

---
 .../Vector/Transforms/VectorUnroll.cpp        |  5 ++--
 .../ConvertToSPIRV/vector-unroll.mlir         | 16 +++++++++++++
 .../Vector/vector-to-elements-lowering.mlir   | 23 +++++++++++++++++++
 3 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index c49718e0902a5..fd5a8f7c89d7d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1109,8 +1109,9 @@ void mlir::vector::populateVectorUnrollPatterns(
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
-                                 benefit);
+  auto options = UnrollVectorOptions().setNativeShape(SmallVector<int64_t>{4});
+  patterns.add<UnrollToElements, ToElementsToTargetShape>(patterns.getContext(),
+                                                          options, benefit);
 }
 
 void mlir::vector::populateVectorFromElementsUnrollPatterns(
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index 0957f67690b97..dcc55a7868978 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -120,6 +120,22 @@ func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) {
 
 // -----
 
+// CHECK-LABEL: @unroll_to_elements_8xf32
+func.func @unroll_to_elements_8xf32() -> (f32, f32) {
+
+  // CHECK: %[[VEC:.+]] = "test.op"
+  // CHECK: %[[V0:.+]] = vector.extract_strided_slice %[[VEC]] {offsets = [0]
+  // CHECK: %[[V1:.+]] = vector.extract_strided_slice %[[VEC]] {offsets = [4]
+  // CHECK: %[[ELEMS0:.+]]:4 = vector.to_elements %[[V0]]
+  // CHECK: %[[ELEMS1:.+]]:4 = vector.to_elements %[[V1]]
+  // CHECK: return %[[ELEMS0]]#3, %[[ELEMS1]]#0
+  %0 = "test.op"() : () -> (vector<8xf32>)
+  %1:8 = vector.to_elements %0 : vector<8xf32>
+  return %1#3, %1#4 : f32, f32
+}
+
+// -----
+
 // In order to verify that the pattern is applied,
 // we need to make sure that the the 2d vector is used
 // by an operation and that extracts are not folded away.
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
index c521bf0138f98..d448377143249 100644
--- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -29,3 +29,26 @@ func.func @unroll_to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32)
   %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
   return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
 }
+
+// -----
+
+// COM: Here we are testing the pattern ToElementsToTargetShape
+// COM: The pattern has a native shape of [4], which means
+// COM: that vectors whose size is a larger multiple of 4 are split. In this
+// COM: case, the split happens in the function's body, not the argument.
+
+// CHECK-LABEL: func.func @unroll_vector_8xf32
+// CHECK-SAME: (%[[ARG0:.+]]: vector<8xf32>)
+func.func @unroll_vector_8xf32(%arg0: vector<8xf32>) -> (f32, f32) {
+  %0:8 = vector.to_elements %arg0 : vector<8xf32>
+
+  // COM: We only return two elements, one from each of the
+  // COM: vectors.
+  return %0#3, %0#4: f32, f32
+
+  // CHECK: %[[V0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
+  // CHECK-NEXT: %[[V1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
+  // CHECK-NEXT: %[[ELEMS_0:.+]]:4 = vector.to_elements %[[V0]]
+  // CHECK-NEXT: %[[ELEMS_1:.+]]:4 = vector.to_elements %[[V1]]
+  // CHECK-NEXT: return %[[ELEMS_0]]#3, %[[ELEMS_1]]#0
+}



More information about the Mlir-commits mailing list