[Mlir-commits] [mlir] [mlir][Vector] Add `vector.to_elements` op (PR #141457)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun May 25 23:33:00 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Diego Caballero (dcaballe)
<details>
<summary>Changes</summary>
This PR introduces the `vector.to_elements` op, which decomposes a vector into its scalar elements. This operation is symmetrical to the existing `vector.from_elements`.
Examples:
```
// Decompose a 0-D vector.
%0 = vector.to_elements %v0 : vector<f32>
// %0 = %v0[0]
// Decompose a 1-D vector.
%0:2 = vector.to_elements %v1 : vector<2xf32>
// %0#0 = %v1[0]
// %0#1 = %v1[1]
// Decompose a 2-D vector.
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
// %0#0 = %v2[0, 0]
// %0#1 = %v2[0, 1]
// %0#2 = %v2[0, 2]
// %0#3 = %v2[1, 0]
// %0#4 = %v2[1, 1]
// %0#5 = %v2[1, 2]
```
This op is aimed at reducing code size when modeling "structured" vector extractions and simplifying canonicalizations of large sequences of `vector.extract` and `vector.insert` ops into `vector.shuffle` and other sophisticated ops that can re-arrange vector elements.
More related PRs to come!
---
Full diff: https://github.com/llvm/llvm-project/pull/141457.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+67-12)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+19-3)
- (modified) mlir/test/Dialect/Vector/ops.mlir (+19)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 5e8421ed67d66..3da47d8e612e2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -789,6 +789,57 @@ def Vector_FMAOp :
}];
}
+def Vector_ToElementsOp : Vector_Op<"to_elements", [
+ Pure,
+ TypesMatchWith<"operand element type matches result types",
+ "input", "elements", "SmallVector<Type>("
+ "::llvm::cast<VectorType>($_self).getNumElements(), "
+ "::llvm::cast<VectorType>($_self).getElementType())">]> {
+ let summary = "operation that decomposes a vector into all its scalar elements";
+ let description = [{
+ This operation decomposes a vector into all its scalar elements. The
+ decomposed scalar elements are returned in row-major order. The number of
+ scalar results must match the number of elements in the input vector type.
+ All the result elements have the same result type, which must match the
+ element type of the input vector. Scalable vectors are not supported.
+
+ Examples:
+
+ ```mlir
+ // Decompose a 0-D vector.
+ %0 = vector.to_elements %v0 : vector<f32>
+ // %0 = %v0[0]
+
+ // Decompose a 1-D vector.
+ %0:2 = vector.to_elements %v1 : vector<2xf32>
+ // %0#0 = %v1[0]
+ // %0#1 = %v1[1]
+
+ // Decompose a 2-D vector.
+ %0:6 = vector.to_elements %v2 : vector<2x3xf32>
+ // %0#0 = %v2[0, 0]
+ // %0#1 = %v2[0, 1]
+ // %0#2 = %v2[0, 2]
+ // %0#3 = %v2[1, 0]
+ // %0#4 = %v2[1, 1]
+ // %0#5 = %v2[1, 2]
+
+ // Decompose a 3-D vector.
+ %0:6 = vector.to_elements %v3 : vector<3x1x2xf32>
+ // %0#0 = %v3[0, 0, 0]
+ // %0#1 = %v3[0, 0, 1]
+ // %0#2 = %v3[1, 0, 0]
+ // %0#3 = %v3[1, 0, 1]
+ // %0#4 = %v3[2, 0, 0]
+ // %0#5 = %v3[2, 0, 1]
+ ```
+ }];
+
+ let arguments = (ins AnyVectorOfAnyRank:$input);
+ let results = (outs Variadic<AnyType>:$elements);
+ let assemblyFormat = "$input attr-dict `:` type($input)";
+}
+
def Vector_FromElementsOp : Vector_Op<"from_elements", [
Pure,
TypesMatchWith<"operand types match result element type",
@@ -798,26 +849,30 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
let summary = "operation that defines a vector from scalar elements";
let description = [{
This operation defines a vector from one or multiple scalar elements. The
- number of elements must match the number of elements in the result type.
- All elements must have the same type, which must match the element type of
- the result vector type.
-
- `elements` are a flattened version of the result vector in row-major order.
+ scalar elements are arranged in row-major order within the vector. The number of
+ elements must match the number of elements in the result type. All elements
+ must have the same type, which must match the element type of the result
+ vector type. Scalable vectors are not supported.
- Example:
+ Examples:
```mlir
- // %f1
+ // Define a 0-D vector.
%0 = vector.from_elements %f1 : vector<f32>
- // [%f1, %f2]
+ // [%f1]
+
+ // Define a 1-D vector.
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
- // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+ // [%f1, %f2]
+
+ // Define a 2-D vector.
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
- // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
+ // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+
+ // Define a 3-D vector.
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
+ // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
```
-
- Note, scalable vectors are not supported.
}];
let arguments = (ins Variadic<AnyType>:$elements);
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..70a7274182442 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1896,7 +1896,24 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
// -----
-func.func @invalid_from_elements(%a: f32) {
+func.func @to_elements_wrong_num_results(%a: vector<1x1x2xf32>) {
+ // expected-error @+1 {{operation defines 2 results but was provided 4 to bind}}
+ %0:4 = vector.to_elements %a : vector<1x1x2xf32>
+ return
+}
+
+// -----
+
+func.func @to_elements_wrong_result_type(%a: vector<2xf32>) -> i32 {
+ // expected-error @+3 {{use of value '%0' expects different type than prior uses: 'i32'}}
+ // expected-note @+1 {{prior use here}}
+ %0:2 = vector.to_elements %a : vector<2xf32>
+ return %0#0 : i32
+}
+
+// -----
+
+func.func @from_elements_wrong_num_operands(%a: f32) {
// expected-error @+1 {{'vector.from_elements' number of operands and types do not match: got 1 operands and 2 types}}
vector.from_elements %a : vector<2xf32>
return
@@ -1905,12 +1922,11 @@ func.func @invalid_from_elements(%a: f32) {
// -----
// expected-note @+1 {{prior use here}}
-func.func @invalid_from_elements(%a: f32, %b: i32) {
+func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
vector.from_elements %a, %b : vector<2xf32>
return
}
-
// -----
func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f3220aed4360c..7cfe4e89d6e2f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1175,6 +1175,25 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
}
+// CHECK-LABEL: func @to_elements(
+// CHECK-SAME: %[[A_VEC:.*]]: vector<f32>, %[[B_VEC:.*]]: vector<4xf32>,
+// CHECK-SAME: %[[C_VEC:.*]]: vector<1xf32>, %[[D_VEC:.*]]: vector<2x2xf32>)
+func.func @to_elements(%a_vec : vector<f32>, %b_vec : vector<4xf32>, %c_vec : vector<1xf32>, %d_vec : vector<2x2xf32>)
+ -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+ // CHECK: %[[A_ELEMS:.*]] = vector.to_elements %[[A_VEC]] : vector<f32>
+ %0 = vector.to_elements %a_vec : vector<f32>
+ // CHECK: %[[B_ELEMS:.*]]:4 = vector.to_elements %[[B_VEC]] : vector<4xf32>
+ %1:4 = vector.to_elements %b_vec : vector<4xf32>
+ // CHECK: %[[C_ELEMS:.*]] = vector.to_elements %[[C_VEC]] : vector<1xf32>
+ %2 = vector.to_elements %c_vec : vector<1xf32>
+ // CHECK: %[[D_ELEMS:.*]]:4 = vector.to_elements %[[D_VEC]] : vector<2x2xf32>
+ %3:4 = vector.to_elements %d_vec : vector<2x2xf32>
+ // CHECK: return %[[A_ELEMS]], %[[B_ELEMS]]#0, %[[B_ELEMS]]#1, %[[B_ELEMS]]#2,
+ // CHECK-SAME: %[[B_ELEMS]]#3, %[[C_ELEMS]], %[[D_ELEMS]]#0, %[[D_ELEMS]]#1,
+ // CHECK-SAME: %[[D_ELEMS]]#2, %[[D_ELEMS]]#3
+ return %0, %1#0, %1#1, %1#2, %1#3, %2, %3#0, %3#1, %3#2, %3#3 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
+}
+
// CHECK-LABEL: func @from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/141457
More information about the Mlir-commits
mailing list