[Mlir-commits] [mlir] [mlir][vector] Adds ToElementsToTargetShape pattern. (PR #166476)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Thu Nov 6 12:32:31 PST 2025
https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/166476
>From fbbf0e4113818f7ace97e4804679d579f8144a27 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:22:58 -0500
Subject: [PATCH 01/11] [mlir][vector] Use getShapeForUnroll's default
implementation.
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ----
2 files changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43172ff2082df..ccea764cfc579 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -361,7 +361,7 @@ def Vector_MultiDimReductionOp :
def Vector_BroadcastOp :
Vector_Op<"broadcast", [Pure,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index daef0ba02100a..3e125e5c1f37b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2782,10 +2782,6 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
-std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
- return llvm::to_vector<4>(getResultVectorType().getShape());
-}
-
/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
>From 1964d161457e71208189065fc3cf82f2341e26e7 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:33:14 -0500
Subject: [PATCH 02/11] [mlir][vector] Use getShapeForUnroll's default
implementation.
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ----
2 files changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ccea764cfc579..1d3f70a9813f7 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2758,7 +2758,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
def Vector_TransposeOp :
Vector_Op<"transpose", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]> {
let summary = "vector transpose operation";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e125e5c1f37b..2d5580ec0ff81 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6716,10 +6716,6 @@ LogicalResult vector::TransposeOp::verify() {
return success();
}
-std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
- return llvm::to_vector<4>(getResultVectorType().getShape());
-}
-
void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
>From a0c6e4f90d38ab2609ebfce99fc1b28c623aeb11 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:39:13 -0500
Subject: [PATCH 03/11] [mlir][vector] Use getShapeForUnroll's default
implementation.
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ----
2 files changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1d3f70a9813f7..fd6196a156d0f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2057,7 +2057,7 @@ def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2d5580ec0ff81..cac8defb4d078 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5982,10 +5982,6 @@ Type GatherOp::getExpectedMaskType() {
vecType.getScalableDims());
}
-std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
- return llvm::to_vector<4>(getVectorType().getShape());
-}
-
/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
auto vecType = dyn_cast<VectorType>(indexVec.getType());
>From a6cbe0b42db5de0609455d3b1b575c006f6d3e4d Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:43:37 -0500
Subject: [PATCH 04/11] [mlir][vector] Use getShapeForUnroll's default
implementation.
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 --------
2 files changed, 1 insertion(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fd6196a156d0f..fa613a86ad793 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -732,7 +732,7 @@ def Vector_ExtractOp :
def Vector_FMAOp :
Op<Vector_Dialect, "fma", [
Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>
] # ElementwiseMappable.traits>,
Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
VectorOfAnyRankOf<[AnyFloat]>:$rhs,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cac8defb4d078..b56e98dd6b595 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2374,14 +2374,6 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}
-//===----------------------------------------------------------------------===//
-// FmaOp
-//===----------------------------------------------------------------------===//
-
-std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
- return llvm::to_vector<4>(getVectorType().getShape());
-}
-
//===----------------------------------------------------------------------===//
// ToElementsOp
//===----------------------------------------------------------------------===//
>From cd648dac74e3d607e4bf13c3e8bc7c65b0d5c698 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 16:47:12 -0500
Subject: [PATCH 05/11] [mlir][vector] Use getShapeForUnroll's default
implementation.
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ----
2 files changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fa613a86ad793..a85ea2e128e1f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1245,7 +1245,7 @@ def Vector_ExtractStridedSliceOp :
def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b56e98dd6b595..f126f8dd6c4dd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5088,10 +5088,6 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) {
return OpFoldResult();
}
-std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
- return llvm::to_vector<4>(getVectorType().getShape());
-}
-
void TransferReadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
>From 5103187a4f7b4676bc2125297a632b1d8419f9be Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 17:12:57 -0500
Subject: [PATCH 06/11] [mlir][vector] Use getShapeForUnroll's default
implementation.
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ----
2 files changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index a85ea2e128e1f..acfa578a184b8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1653,7 +1653,7 @@ def Vector_TransferWriteOp :
}
def Vector_LoadOp : Vector_Op<"load", [
- DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]> {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f126f8dd6c4dd..b030b060c6ba0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5762,10 +5762,6 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
return OpFoldResult();
}
-std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
- return llvm::to_vector<4>(getVectorType().getShape());
-}
-
FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
>From 71e53e7f294286f280b012367515f53a81b2cdb9 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 17:21:35 -0500
Subject: [PATCH 07/11] Fix documentation
---
mlir/include/mlir/Interfaces/VectorInterfaces.td | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 6838c16fdf0fe..1223f5c0704ab 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -24,9 +24,8 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
let methods = [
InterfaceMethod<
/*desc=*/[{
- Return the shape ratio of unrolling to the target vector shape
- `targetShape`. Return `std::nullopt` if the op cannot be unrolled to the
- target vector shape.
+ Return the shape of the vector of this operation, which may be used to decide unrolling factors.
+ Return std::nullopt if the op is not applicable for unrolling.
}],
/*retTy=*/"::std::optional<::llvm::SmallVector<int64_t, 4>>",
/*methodName=*/"getShapeForUnroll",
>From 200773d78f4e57baf5d02b9531d97a289012399a Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 6 Nov 2025 15:23:07 -0500
Subject: [PATCH 08/11] Fix rebase
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index acfa578a184b8..a1c5298629e58 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2057,7 +2057,7 @@ def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
>From aa4906a085fe94bc31d88ff9d0ac12131434ccae Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 4 Nov 2025 17:30:43 -0500
Subject: [PATCH 09/11] [mlir][vector] to_elements implements
VectorUnrollOpInterface
---
.../SPIRV/Transforms/SPIRVConversion.h | 3 +
.../mlir/Dialect/Vector/IR/VectorOps.td | 8 ++
.../SPIRV/Transforms/SPIRVConversion.cpp | 11 ++-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +
.../Vector/Transforms/VectorUnroll.cpp | 96 ++++++++++++++++++-
5 files changed, 117 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 03ae54a8ae30a..f202c0ea88bd0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -198,6 +198,9 @@ Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
// the target shape.
int getComputeVectorSize(int64_t size);
+// GetNativeVectorShape implementation for to_elements ops.
+SmallVector<int64_t> getNativeVectorShapeImpl(vector::ToElementsOp op);
+
// GetNativeVectorShape implementation for reduction ops.
SmallVector<int64_t> getNativeVectorShapeImpl(vector::ReductionOp op);
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index a1c5298629e58..51e9a9b986315 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -762,6 +762,7 @@ def Vector_FMAOp :
def Vector_ToElementsOp : Vector_Op<"to_elements", [
InferTypeOpAdaptor, Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
let summary = "operation that decomposes a vector into all its scalar elements";
let description = [{
@@ -808,6 +809,13 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
let assemblyFormat = "$source attr-dict `:` type($source)";
let hasFolder = 1;
let hasCanonicalizer = 1;
+ let extraClassDeclaration = [{
+
+ VectorType getSourceVectorType() {
+ return ::llvm::cast<VectorType>(getSource().getType());
+ }
+
+ }];
}
def Vector_FromElementsOp : Vector_Op<"from_elements", [
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index cb9b7f6ec2fd2..22097f5f2cdc6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1435,6 +1435,15 @@ int mlir::spirv::getComputeVectorSize(int64_t size) {
return 1;
}
+SmallVector<int64_t>
+mlir::spirv::getNativeVectorShapeImpl(vector::ToElementsOp op) {
+ VectorType srcVectorType = op.getSourceVectorType();
+ assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
+ int64_t vectorSize =
+ mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
+ return {vectorSize};
+}
+
SmallVector<int64_t>
mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
VectorType srcVectorType = op.getSourceVectorType();
@@ -1465,7 +1474,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
}
return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
- .Case<vector::ReductionOp, vector::TransposeOp>(
+ .Case<vector::ReductionOp, vector::TransposeOp, vector::ToElementsOp>(
[](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
.Default(std::nullopt);
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b030b060c6ba0..4fe3b99f7fd6a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2377,6 +2377,9 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
//===----------------------------------------------------------------------===//
// ToElementsOp
//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t, 4>> ToElementsOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getSourceVectorType().getShape());
+}
/// Returns true if all the `operands` are defined by `defOp`.
/// Otherwise, returns false.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..c49718e0902a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -834,11 +834,100 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};
+/// Takes a 1-dimensional `vector.to_elements` op and attempts to unroll it
+/// to the target shape.
+///
+/// ```
+/// // In SPIR-V's default environment, vectors of size 8
+/// // are not allowed.
+/// %elements:8 = vector.to_elements %v : vector<8xf32>
+///
+/// ===>
+///
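+/// %v_0_to_3 = vector.extract_strided_slice %v {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
+/// %v_4_to_7 = vector.extract_strided_slice %v {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>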
+/// %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32>
+/// %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32>
+/// ```
+///
+/// This pattern fails to match if the vector size is not divisible by the
+/// native vector size, or if the vector already has the target shape, in
+/// which case no rewrite is needed.
+struct ToElementsToTargetShape final
+ : public OpRewritePattern<vector::ToElementsOp> {
+ ToElementsToTargetShape(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ToElementsOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, op);
+ if (!targetShape)
+ return failure();
+
+ // We have
+ // source_rank = N * target_rank
+ int64_t source_rank = op.getSourceVectorType().getShape().front();
+ int64_t target_rank = targetShape->front();
+ int64_t N = source_rank / target_rank;
+
+ // Transformation where
+ // s = source_rank and
+ // t = target_rank
+ // ```
+ // %e:s = vector.to_elements %v : vector<sxf32>
+ //
+ // ===>
+ //
+ // // N vector.extract_strided_slice of size t
+ // %v0 = vector.extract_strided_slice %v
+ // {offsets = [0*t], sizes = [t], strides = [1]}
+ // : vector<sxf32> to vector<txf32>
+ // %v1 = vector.extract_strided_slice %v
+ // {offsets = [1*t], sizes = [t], strides = [1]}
+ // : vector<sxf32> to vector<txf32>
+ // ...
+ // %vNminus1 = vector.extract_strided_slice %v
+ // {offsets = [(N-1)*t], sizes = [t], strides = [1]}
+ // : vector<sxf32> to vector<txf32>
+ //
+ // // N vector.to_elements of size t vectors.
+ // %e0:t = vector.to_elements %v0 : vector<txf32>
+ // %e1:t = vector.to_elements %v1 : vector<txf32>
+ // ...
+ // %eNminus1:t = vector.to_elements %vNminus1 : vector<txf32>
+ // ```
+ SmallVector<Value> subVectors;
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ for (int64_t i = 0; i < N; i++) {
+ SmallVector<int64_t> elementOffsets = {i * target_rank};
+ Value subVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ op.getLoc(), op.getSource(), elementOffsets, *targetShape, strides);
+ subVectors.push_back(subVector);
+ }
+
+ SmallVector<Value> elements;
+ for (const Value subVector : subVectors) {
+ auto elementsOp =
+ vector::ToElementsOp::create(rewriter, op.getLoc(), subVector);
+ llvm::append_range(elements, elementsOp.getResults());
+ }
+
+ rewriter.replaceOp(op, elements);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand. For example:
///
/// ```
-/// %0:4 = vector.to_elements %v : vector<2x2xf32>
+/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
///
/// ==>
///
@@ -865,6 +954,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
FailureOr<SmallVector<Value>> result =
vector::unrollVectorValue(source, rewriter);
if (failed(result)) {
+ // Only fails if operand is 1-dimensional.
return failure();
}
SmallVector<Value> vectors = *result;
@@ -1013,8 +1103,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
- UnrollToElements, UnrollStepPattern>(patterns.getContext(),
- options, benefit);
+ UnrollToElements, UnrollStepPattern, ToElementsToTargetShape>(
+ patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
>From 228d0b142b14f95ef2dae0030fd39f48f14584b4 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 5 Nov 2025 15:14:09 -0500
Subject: [PATCH 10/11] [mlir] Test vector.to_elements to spirv conversion.
---
.../ConvertToSPIRV/vector-sizes.mlir | 67 +++++++++++++++++++
1 file changed, 67 insertions(+)
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir
new file mode 100644
index 0000000000000..402c539a77093
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-sizes.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -test-convert-to-spirv="run-signature-conversion=false run-vector-unrolling=true" -split-input-file %s | FileCheck %s
+
+// COM: This file tests the current behaviour of the SignatureConversion
+// COM: and the unrolling of vector.to_elements to vectors of valid SPIR-V
+// COM: sizes.
+
+// COM: Vectors of rank 1 and size 1 are converted
+// COM: to scalars. Since vector.to_elements will also produce
+// COM: a scalar, we expect the vector.to_elements to be folded
+// COM: away. Please note that even if run-signature-conversion=false,
+// COM: the pattern FuncOpConversion will still run and change parameters
+// COM: which fit this constraint.
+
+// CHECK-LABEL: spirv.func @vec_size_1
+// CHECK-SAME: (%[[ARG0:.+]]: f32)
+func.func @vec_size_1(%arg0: vector<1xf32>) -> (f32) {
+ // CHECK-NEXT: spirv.ReturnValue %[[ARG0]] : f32
+ %0:1 = vector.to_elements %arg0 : vector<1xf32>
+ return %0#0 : f32
+}
+
+// -----
+
+// COM: Vectors of size 2, 3, and 4 are allowed by SPIR-V,
+// COM: so they remain unchanged. FuncOpConversion will still
+// COM: run, but the signature converter will not convert these vectors.
+
+// CHECK-LABEL: spirv.func @vec_size_2
+// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
+func.func @vec_size_2(%arg0: vector<2xf32>) -> (f32) {
+ // COM: A single result type is enforced by the semantics
+
+ // CHECK-NEXT: %[[VAL:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+
+ // CHECK-NEXT: spirv.ReturnValue %[[VAL]]
+ return %0#0 : f32
+}
+
+// -----
+
+// COM: A vector of size 5 is the first one that doesn't fit
+// COM: into SPIR-V's vectors.
+
+// COM: run-signature-conversion=false means that
+// COM: this vector will not be unrolled.
+
+// CHECK-LABEL: func.func @vec_size_5
+// CHECK-SAME: (%[[ARG0:.+]]: vector<5xf32>)
+func.func @vec_size_5(%arg0: vector<5xf32>) -> (f32) {
+
+ // CHECK-NEXT: %[[VAL:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [1], strides = [1]} : vector<5xf32> to vector<1xf32>
+
+ // COM: We have the following comment in VectorConvertToElementOp
+ // COM:
+ // COM: // Input vectors of size 1 are converted to scalars by the type converter.
+ // COM: // We cannot use `spirv::CompositeExtractOp` directly in this case.
+ // COM: // For a scalar source, the result is just the scalar itself.
+ // COM:
+ // COM: Which in this case means an unrealized conversion cast.
+
+ // CHECK-NEXT: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<1xf32> to f32
+ %0:5 = vector.to_elements %arg0 : vector<5xf32>
+
+ // CHECK-NEXT: spirv.ReturnValue %[[RETVAL]] : f32
+ return %0#0 : f32
+}
>From 8fe386a4edfe8148e6bb57e7cbd84f2a82e02b78 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 5 Nov 2025 17:03:39 -0500
Subject: [PATCH 11/11] [mlir] Update unrollToElements tests
---
.../Vector/Transforms/VectorUnroll.cpp | 5 ++--
.../ConvertToSPIRV/vector-unroll.mlir | 16 +++++++++++++
.../Vector/vector-to-elements-lowering.mlir | 23 +++++++++++++++++++
3 files changed, 42 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index c49718e0902a5..fd5a8f7c89d7d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1109,8 +1109,9 @@ void mlir::vector::populateVectorUnrollPatterns(
void mlir::vector::populateVectorToElementsUnrollPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
- benefit);
+ auto options = UnrollVectorOptions().setNativeShape(SmallVector<int64_t>{4});
+ patterns.add<UnrollToElements, ToElementsToTargetShape>(patterns.getContext(),
+ options, benefit);
}
void mlir::vector::populateVectorFromElementsUnrollPatterns(
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index 0957f67690b97..dcc55a7868978 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -120,6 +120,22 @@ func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) {
// -----
+// CHECK-LABEL: @unroll_to_elements_8xf32
+func.func @unroll_to_elements_8xf32() -> (f32, f32) {
+
+ // CHECK: %[[VEC:.+]] = "test.op"
+ // CHECK: %[[V0:.+]] = vector.extract_strided_slice %[[VEC]] {offsets = [0]
+ // CHECK: %[[V1:.+]] = vector.extract_strided_slice %[[VEC]] {offsets = [4]
+ // CHECK: %[[ELEMS0:.+]]:4 = vector.to_elements %[[V0]]
+ // CHECK: %[[ELEMS1:.+]]:4 = vector.to_elements %[[V1]]
+ // CHECK: return %[[ELEMS0]]#3, %[[ELEMS1]]#0
+ %0 = "test.op"() : () -> (vector<8xf32>)
+ %1:8 = vector.to_elements %0 : vector<8xf32>
+ return %1#3, %1#4 : f32, f32
+}
+
+// -----
+
// In order to verify that the pattern is applied,
// we need to make sure that the the 2d vector is used
// by an operation and that extracts are not folded away.
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
index c521bf0138f98..d448377143249 100644
--- a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -29,3 +29,26 @@ func.func @unroll_to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32)
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
+
+// -----
+
+// COM: Here we are testing the pattern ToElementsToTargetShape.
+// COM: The pattern has a native shape of [4], which means
+// COM: that vectors whose size is a multiple of 4 will be split. In this
+// COM: case, that will happen in the function's body, not the argument.
+
+// CHECK-LABEL: func.func @unroll_vector_8xf32
+// CHECK-SAME: (%[[ARG0:.+]]: vector<8xf32>)
+func.func @unroll_vector_8xf32(%arg0: vector<8xf32>) -> (f32, f32) {
+ %0:8 = vector.to_elements %arg0 : vector<8xf32>
+
+ // COM: We only return two elements, one from each of the
+ // COM: vectors.
+ return %0#3, %0#4: f32, f32
+
+ // CHECK: %[[V0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
+ // CHECK-NEXT: %[[V1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
+ // CHECK-NEXT: %[[ELEMS_0:.+]]:4 = vector.to_elements %[[V0]]
+ // CHECK-NEXT: %[[ELEMS_1:.+]]:4 = vector.to_elements %[[V1]]
+ // CHECK-NEXT: return %[[ELEMS_0]]#3, %[[ELEMS_1]]#0
+}
More information about the Mlir-commits
mailing list