[Mlir-commits] [mlir] [mlir][Vector] Add `vector.to_elements` op (PR #141457)
Diego Caballero
llvmlistbot at llvm.org
Wed Jun 18 11:45:17 PDT 2025
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/141457
>From 0c522e4966d95860df559bf868379e9f8b69f2d1 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Mon, 26 May 2025 06:29:13 +0000
Subject: [PATCH 1/3] [mlir][Vector] Add `vector.to_elements` op
This PR introduces the `vector.to_elements` op, which decomposes a
vector into its scalar elements. This operation is symmetrical to
the existing `vector.from_elements`.
Examples:
```
// Decompose a 0-D vector.
%0 = vector.to_elements %v0 : vector<f32>
// %0 = %v0[0]
// Decompose a 1-D vector.
%0:2 = vector.to_elements %v1 : vector<2xf32>
// %0#0 = %v1[0]
// %0#1 = %v1[1]
// Decompose a 2-D vector.
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
// %0#0 = %v2[0, 0]
// %0#1 = %v2[0, 1]
// %0#2 = %v2[0, 2]
// %0#3 = %v2[1, 0]
// %0#4 = %v2[1, 1]
// %0#5 = %v2[1, 2]
```
This op is aimed at reducing code size when modeling "structured" vector
extractions and simplifying canonicalizations of large sequences of
`vector.extract` and `vector.insert` ops into `vector.shuffle` and
other sophisticated ops that can re-arrange vector elements.
More related PRs to come!
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 79 ++++++++++++++++---
mlir/test/Dialect/Vector/invalid.mlir | 22 +++++-
mlir/test/Dialect/Vector/ops.mlir | 19 +++++
3 files changed, 105 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..e1fabb9389b5c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -790,6 +790,57 @@ def Vector_FMAOp :
}];
}
+def Vector_ToElementsOp : Vector_Op<"to_elements", [
+ Pure,
+ TypesMatchWith<"operand element type matches result types",
+ "input", "elements", "SmallVector<Type>("
+ "::llvm::cast<VectorType>($_self).getNumElements(), "
+ "::llvm::cast<VectorType>($_self).getElementType())">]> {
+ let summary = "operation that decomposes a vector into all its scalar elements";
+ let description = [{
+ This operation decomposes all the scalar elements from a vector. The
+ decomposed scalar elements are returned in row-major order. The number of
+ scalar results must match the number of elements in the input vector type.
+ All the result elements have the same result type, which must match the
+ element type of the input vector. Scalable vectors are not supported.
+
+ Examples:
+
+ ```mlir
+ // Decompose a 0-D vector.
+ %0 = vector.to_elements %v0 : vector<f32>
+ // %0 = %v0[0]
+
+ // Decompose a 1-D vector.
+ %0:2 = vector.to_elements %v1 : vector<2xf32>
+ // %0#0 = %v1[0]
+ // %0#1 = %v1[1]
+
+ // Decompose a 2-D vector.
+ %0:6 = vector.to_elements %v2 : vector<2x3xf32>
+ // %0#0 = %v2[0, 0]
+ // %0#1 = %v2[0, 1]
+ // %0#2 = %v2[0, 2]
+ // %0#3 = %v2[1, 0]
+ // %0#4 = %v2[1, 1]
+ // %0#5 = %v2[1, 2]
+
+ // Decompose a 3-D vector.
+ %0:6 = vector.to_elements %v3 : vector<3x1x2xf32>
+ // %0#0 = %v3[0, 0, 0]
+ // %0#1 = %v3[0, 0, 1]
+ // %0#2 = %v3[1, 0, 0]
+ // %0#3 = %v3[1, 0, 1]
+ // %0#4 = %v3[2, 0, 0]
+ // %0#5 = %v3[2, 0, 1]
+ ```
+ }];
+
+ let arguments = (ins AnyVectorOfAnyRank:$input);
+ let results = (outs Variadic<AnyType>:$elements);
+ let assemblyFormat = "$input attr-dict `:` type($input)";
+}
+
def Vector_FromElementsOp : Vector_Op<"from_elements", [
Pure,
TypesMatchWith<"operand types match result element type",
@@ -799,26 +850,30 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
let summary = "operation that defines a vector from scalar elements";
let description = [{
This operation defines a vector from one or multiple scalar elements. The
- number of elements must match the number of elements in the result type.
- All elements must have the same type, which must match the element type of
- the result vector type.
-
- `elements` are a flattened version of the result vector in row-major order.
+ scalar elements are arranged in row-major order within the vector. The number of
+ elements must match the number of elements in the result type. All elements
+ must have the same type, which must match the element type of the result
+ vector type. Scalable vectors are not supported.
- Example:
+ Examples:
```mlir
- // %f1
+ // Define a 0-D vector.
%0 = vector.from_elements %f1 : vector<f32>
- // [%f1, %f2]
+ // [%f1]
+
+ // Define a 1-D vector.
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
- // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+ // [%f1, %f2]
+
+ // Define a 2-D vector.
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
- // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
+ // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+
+ // Define a 3-D vector.
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
+ // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
```
-
- Note, scalable vectors are not supported.
}];
let arguments = (ins Variadic<AnyType>:$elements);
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..70a7274182442 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1896,7 +1896,24 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
// -----
-func.func @invalid_from_elements(%a: f32) {
+func.func @to_elements_wrong_num_results(%a: vector<1x1x2xf32>) {
+ // expected-error @+1 {{operation defines 2 results but was provided 4 to bind}}
+ %0:4 = vector.to_elements %a : vector<1x1x2xf32>
+ return
+}
+
+// -----
+
+func.func @to_elements_wrong_result_type(%a: vector<2xf32>) -> i32 {
+ // expected-error @+3 {{use of value '%0' expects different type than prior uses: 'i32'}}
+ // expected-note @+1 {{prior use here}}
+ %0:2 = vector.to_elements %a : vector<2xf32>
+ return %0#0 : i32
+}
+
+// -----
+
+func.func @from_elements_wrong_num_operands(%a: f32) {
// expected-error @+1 {{'vector.from_elements' number of operands and types do not match: got 1 operands and 2 types}}
vector.from_elements %a : vector<2xf32>
return
@@ -1905,12 +1922,11 @@ func.func @invalid_from_elements(%a: f32) {
// -----
// expected-note @+1 {{prior use here}}
-func.func @invalid_from_elements(%a: f32, %b: i32) {
+func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
vector.from_elements %a, %b : vector<2xf32>
return
}
-
// -----
func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f3220aed4360c..7cfe4e89d6e2f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1175,6 +1175,25 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
}
+// CHECK-LABEL: func @to_elements(
+// CHECK-SAME: %[[A_VEC:.*]]: vector<f32>, %[[B_VEC:.*]]: vector<4xf32>,
+// CHECK-SAME: %[[C_VEC:.*]]: vector<1xf32>, %[[D_VEC:.*]]: vector<2x2xf32>)
+func.func @to_elements(%a_vec : vector<f32>, %b_vec : vector<4xf32>, %c_vec : vector<1xf32>, %d_vec : vector<2x2xf32>)
+ -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+ // CHECK: %[[A_ELEMS:.*]] = vector.to_elements %[[A_VEC]] : vector<f32>
+ %0 = vector.to_elements %a_vec : vector<f32>
+ // CHECK: %[[B_ELEMS:.*]]:4 = vector.to_elements %[[B_VEC]] : vector<4xf32>
+ %1:4 = vector.to_elements %b_vec : vector<4xf32>
+ // CHECK: %[[C_ELEMS:.*]] = vector.to_elements %[[C_VEC]] : vector<1xf32>
+ %2 = vector.to_elements %c_vec : vector<1xf32>
+ // CHECK: %[[D_ELEMS:.*]]:4 = vector.to_elements %[[D_VEC]] : vector<2x2xf32>
+ %3:4 = vector.to_elements %d_vec : vector<2x2xf32>
+ // CHECK: return %[[A_ELEMS]], %[[B_ELEMS]]#0, %[[B_ELEMS]]#1, %[[B_ELEMS]]#2,
+ // CHECK-SAME: %[[B_ELEMS]]#3, %[[C_ELEMS]], %[[D_ELEMS]]#0, %[[D_ELEMS]]#1,
+ // CHECK-SAME: %[[D_ELEMS]]#2, %[[D_ELEMS]]#3
+ return %0, %1#0, %1#1, %1#2, %1#3, %2, %3#0, %3#1, %3#2, %3#3 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
+}
+
// CHECK-LABEL: func @from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
>From 1aa54aeef932f707c148afdc72dac0753d08efb9 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Wed, 11 Jun 2025 18:43:56 +0000
Subject: [PATCH 2/3] Feedback
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 12 ++++++------
mlir/test/Dialect/Vector/invalid.mlir | 2 +-
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e1fabb9389b5c..4e1ee145f44ef 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -793,7 +793,7 @@ def Vector_FMAOp :
def Vector_ToElementsOp : Vector_Op<"to_elements", [
Pure,
TypesMatchWith<"operand element type matches result types",
- "input", "elements", "SmallVector<Type>("
+ "source", "elements", "SmallVector<Type>("
"::llvm::cast<VectorType>($_self).getNumElements(), "
"::llvm::cast<VectorType>($_self).getElementType())">]> {
let summary = "operation that decomposes a vector into all its scalar elements";
@@ -836,15 +836,15 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
```
}];
- let arguments = (ins AnyVectorOfAnyRank:$input);
+ let arguments = (ins AnyVectorOfAnyRank:$source);
let results = (outs Variadic<AnyType>:$elements);
- let assemblyFormat = "$input attr-dict `:` type($input)";
+ let assemblyFormat = "$source attr-dict `:` type($source)";
}
def Vector_FromElementsOp : Vector_Op<"from_elements", [
Pure,
TypesMatchWith<"operand types match result element type",
- "result", "elements", "SmallVector<Type>("
+ "dest", "elements", "SmallVector<Type>("
"::llvm::cast<VectorType>($_self).getNumElements(), "
"::llvm::cast<VectorType>($_self).getElementType())">]> {
let summary = "operation that defines a vector from scalar elements";
@@ -877,8 +877,8 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
}];
let arguments = (ins Variadic<AnyType>:$elements);
- let results = (outs AnyFixedVectorOfAnyRank:$result);
- let assemblyFormat = "$elements attr-dict `:` type($result)";
+ let results = (outs AnyFixedVectorOfAnyRank:$dest);
+ let assemblyFormat = "$elements attr-dict `:` type($dest)";
let hasCanonicalizer = 1;
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 70a7274182442..ec7cee7b2c641 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1930,7 +1930,7 @@ func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
// -----
func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
- // expected-error @+1 {{'result' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
+ // expected-error @+1 {{'dest' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
vector.from_elements %a, %b : vector<[2]xf32>
return
}
>From c71b42de4fa057c617a72c12380b5b22cf083314 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Wed, 18 Jun 2025 18:42:59 +0000
Subject: [PATCH 3/3] Add ShapedTypeMatchesElementCountAndTypes
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 10 ++----
mlir/include/mlir/IR/OpBase.td | 19 ++++++++++++
mlir/lib/TableGen/Operator.cpp | 31 +++++++++++++++++++
mlir/tools/mlir-tblgen/OpFormatGen.cpp | 26 ++++++++++++++++
4 files changed, 78 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 4e1ee145f44ef..125cd4645ccc2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -792,10 +792,7 @@ def Vector_FMAOp :
def Vector_ToElementsOp : Vector_Op<"to_elements", [
Pure,
- TypesMatchWith<"operand element type matches result types",
- "source", "elements", "SmallVector<Type>("
- "::llvm::cast<VectorType>($_self).getNumElements(), "
- "::llvm::cast<VectorType>($_self).getElementType())">]> {
+ ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
let summary = "operation that decomposes a vector into all its scalar elements";
let description = [{
This operation decomposes all the scalar elements from a vector. The
@@ -843,10 +840,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
def Vector_FromElementsOp : Vector_Op<"from_elements", [
Pure,
- TypesMatchWith<"operand types match result element type",
- "dest", "elements", "SmallVector<Type>("
- "::llvm::cast<VectorType>($_self).getNumElements(), "
- "::llvm::cast<VectorType>($_self).getElementType())">]> {
+ ShapedTypeMatchesElementCountAndTypes<"dest", "elements">]> {
let summary = "operation that defines a vector from scalar elements";
let description = [{
This operation defines a vector from one or multiple scalar elements. The
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 51b60972203e7..b3fabe409806f 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -556,6 +556,25 @@ class AllShapesMatch<list<string> names> :
class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
+// A type constraint that verifies that a shaped type matches the size and
+// element type of a container with element types. More specifically, it denotes
+// shapedArg.getType().getNumElements() == elementsArg.size() &&
+// shapedArg.getType().getElementType() == elementsArg[i].getType(), for i in
+// [0, elementsArg.size()).
+class ShapedTypeMatchesElementCountAndTypes<string shapedArg,
+ string elementsArg> :
+ PredOpTrait<"shaped type '" # shapedArg # "' matches '" # elementsArg # "' "
+ "element count and types",
+ And<[CPred<ElementCount<shapedArg>.result # " == "
+ "$" # elementsArg # ".getTypes().size()">,
+ CPred<"::llvm::all_of($" # elementsArg # ".getTypes(), "
+ "[&](::mlir::Type t) { return t == "
+ # ElementType<shapedArg>.result # "; })">]>> {
+
+ string shaped = shapedArg;
+ string elements = elementsArg;
+}
+
// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
// An optional comparator function may be provided that changes the above form
// into: `comparator(transform(lhs.getType()), rhs.getType())`.
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 2544f0a1b91b6..07520a2f94d77 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -468,6 +468,37 @@ void Operator::populateTypeInferenceInfo(
continue;
}
+ // The `ShapedTypeMatchesElementCountAndTypes` trait represents a 1 -> 1
+ // type inference edge where a shaped type matches element count and types
+ // of variadic elements.
+ if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
+ StringRef shapedArg = def.getValueAsString("shaped");
+ StringRef elementsArg = def.getValueAsString("elements");
+
+ int shapedIndex = argumentsAndResultsIndex.lookup(shapedArg);
+ int elementsIndex = argumentsAndResultsIndex.lookup(elementsArg);
+
+ // Handle result type inference from shaped type to variadic elements.
+ if (InferredResultType::isResultIndex(elementsIndex) &&
+ InferredResultType::isArgIndex(shapedIndex)) {
+ int resultIndex = InferredResultType::unmapResultIndex(elementsIndex);
+ ResultTypeInference &infer = inference[resultIndex];
+ if (!infer.inferred) {
+ infer.sources.emplace_back(
+ shapedIndex,
+ "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
+ "ShapedType>($_self).getNumElements(), "
+ "::llvm::cast<::mlir::ShapedType>($_self).getElementType())");
+ infer.inferred = true;
+ }
+ }
+
+ // Type inference in the opposite direction is not possible as the actual
+ // shaped type can't be inferred from the variadic elements.
+
+ continue;
+ }
+
if (!def.isSubClassOf("AllTypesMatch"))
continue;
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0a9d14d6603a8..ef3a18ba7df22 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2787,6 +2787,11 @@ class OpFormatParser : public FormatParser {
void handleTypesMatchConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
+ /// Check for inferable type resolution based on
+ /// `ShapedTypeMatchesElementCountAndTypes` constraint.
+ void handleShapedTypeMatchesElementCountAndTypesConstraint(
+ StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
+
/// Returns an argument or attribute with the given name that has been seen
/// within the format.
ConstArgument findSeenArg(StringRef name);
@@ -2850,6 +2855,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
handleTypesMatchConstraint(variableTyResolver, def);
+ } else if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
+ handleShapedTypeMatchesElementCountAndTypesConstraint(variableTyResolver,
+ def);
} else if (!op.allResultTypesKnown()) {
// This doesn't check the name directly to handle
// DeclareOpInterfaceMethods<InferTypeOpInterface>
@@ -3289,6 +3297,24 @@ void OpFormatParser::handleTypesMatchConstraint(
variableTyResolver[rhsName] = {arg, transformer};
}
+void OpFormatParser::handleShapedTypeMatchesElementCountAndTypesConstraint(
+ StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
+ StringRef shapedArg = def.getValueAsString("shaped");
+ StringRef elementsArg = def.getValueAsString("elements");
+
+ // Check if the 'shaped' argument is seen, then we can infer the 'elements'
+ // types.
+ if (ConstArgument arg = findSeenArg(shapedArg)) {
+ variableTyResolver[elementsArg] = {
+ arg, "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
+ "ShapedType>($_self).getNumElements(), "
+ "::llvm::cast<::mlir::ShapedType>($_self).getElementType())"};
+ }
+
+ // Type inference in the opposite direction is not possible as the actual
+ // shaped type can't be inferred from the variadic elements.
+}
+
ConstArgument OpFormatParser::findSeenArg(StringRef name) {
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
More information about the Mlir-commits
mailing list