[Mlir-commits] [mlir] [mlir][Vector] Add `vector.to_elements` op (PR #141457)

Diego Caballero llvmlistbot at llvm.org
Wed Jun 11 11:48:18 PDT 2025


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/141457

>From 64b06ba7903752495446627bd79bbed68792e2c3 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Mon, 26 May 2025 06:29:13 +0000
Subject: [PATCH 1/2] [mlir][Vector] Add `vector.to_elements` op

This PR introduces the `vector.to_elements` op, which decomposes a
vector into its scalar elements. This operation is symmetrical to
the existing `vector.from_elements`.

Examples:

```mlir
// Decompose a 0-D vector.
%0 = vector.to_elements %v0 : vector<f32>
// %0 = %v0[0]

// Decompose a 1-D vector.
%0:2 = vector.to_elements %v1 : vector<2xf32>
// %0#0 = %v1[0]
// %0#1 = %v1[1]

// Decompose a 2-D vector.
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
// %0#0 = %v2[0, 0]
// %0#1 = %v2[0, 1]
// %0#2 = %v2[0, 2]
// %0#3 = %v2[1, 0]
// %0#4 = %v2[1, 1]
// %0#5 = %v2[1, 2]
```

This op is aimed at reducing code size when modeling "structured" vector
extractions, and at simplifying canonicalizations of large sequences of
`vector.extract` and `vector.insert` ops into `vector.shuffle` and
other sophisticated ops that can rearrange vector elements.

More related PRs to come!
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 79 ++++++++++++++++---
 mlir/test/Dialect/Vector/invalid.mlir         | 22 +++++-
 mlir/test/Dialect/Vector/ops.mlir             | 19 +++++
 3 files changed, 105 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..e1fabb9389b5c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -790,6 +790,57 @@ def Vector_FMAOp :
   }];
 }
 
+def Vector_ToElementsOp : Vector_Op<"to_elements", [
+    Pure,
+    TypesMatchWith<"operand element type matches result types",
+                   "input", "elements", "SmallVector<Type>("
+                   "::llvm::cast<VectorType>($_self).getNumElements(), "
+                   "::llvm::cast<VectorType>($_self).getElementType())">]> {
+  let summary = "operation that decomposes a vector into all its scalar elements";
+  let description = [{
+    This operation decomposes a vector into all of its scalar elements. The
+    decomposed scalar elements are returned in row-major order. The number of
+    scalar results must match the number of elements in the input vector type.
+    All the result elements have the same type, which must match the element
+    type of the input vector. Scalable vectors are not supported.
+
+    Examples:
+
+    ```mlir
+    // Decompose a 0-D vector.
+    %0 = vector.to_elements %v0 : vector<f32>
+    // %0 = %v0[0]
+
+    // Decompose a 1-D vector.
+    %0:2 = vector.to_elements %v1 : vector<2xf32>
+    // %0#0 = %v1[0]
+    // %0#1 = %v1[1]
+
+    // Decompose a 2-D vector.
+    %0:6 = vector.to_elements %v2 : vector<2x3xf32>
+    // %0#0 = %v2[0, 0]
+    // %0#1 = %v2[0, 1]
+    // %0#2 = %v2[0, 2]
+    // %0#3 = %v2[1, 0]
+    // %0#4 = %v2[1, 1]
+    // %0#5 = %v2[1, 2]
+
+    // Decompose a 3-D vector.
+    %0:6 = vector.to_elements %v3 : vector<3x1x2xf32>
+    // %0#0 = %v3[0, 0, 0]
+    // %0#1 = %v3[0, 0, 1]
+    // %0#2 = %v3[1, 0, 0]
+    // %0#3 = %v3[1, 0, 1]
+    // %0#4 = %v3[2, 0, 0]
+    // %0#5 = %v3[2, 0, 1]
+    ```
+  }];
+
+  let arguments = (ins AnyVectorOfAnyRank:$input);
+  let results = (outs Variadic<AnyType>:$elements);
+  let assemblyFormat = "$input attr-dict `:` type($input)";
+}
+
 def Vector_FromElementsOp : Vector_Op<"from_elements", [
     Pure,
     TypesMatchWith<"operand types match result element type",
@@ -799,26 +850,30 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
   let summary = "operation that defines a vector from scalar elements";
   let description = [{
     This operation defines a vector from one or multiple scalar elements. The
-    number of elements must match the number of elements in the result type.
-    All elements must have the same type, which must match the element type of
-    the result vector type.
-
-    `elements` are a flattened version of the result vector in row-major order.
+    scalar elements are arranged in row-major within the vector. The number of
+    elements must match the number of elements in the result type. All elements
+    must have the same type, which must match the element type of the result
+    vector type. Scalable vectors are not supported.
 
-    Example:
+    Examples:
 
     ```mlir
-    // %f1
+    // Define a 0-D vector.
     %0 = vector.from_elements %f1 : vector<f32>
-    // [%f1, %f2]
+    // [%f1]
+
+    // Define a 1-D vector.
     %1 = vector.from_elements %f1, %f2 : vector<2xf32>
-    // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+    // [%f1, %f2]
+
+    // Define a 2-D vector.
     %2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
-    // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
+    // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+
+    // Define a 3-D vector.
     %3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
+    // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
     ```
-
-    Note, scalable vectors are not supported.
   }];
 
   let arguments = (ins Variadic<AnyType>:$elements);
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..70a7274182442 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1896,7 +1896,24 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
 
 // -----
 
-func.func @invalid_from_elements(%a: f32) {
+func.func @to_elements_wrong_num_results(%a: vector<1x1x2xf32>) {
+  // expected-error @+1 {{operation defines 2 results but was provided 4 to bind}}
+  %0:4 = vector.to_elements %a : vector<1x1x2xf32>
+  return
+}
+
+// -----
+
+func.func @to_elements_wrong_result_type(%a: vector<2xf32>) -> i32 {
+  // expected-error @+3 {{use of value '%0' expects different type than prior uses: 'i32'}}
+  // expected-note @+1 {{prior use here}}
+  %0:2 = vector.to_elements %a : vector<2xf32>
+  return %0#0 : i32
+}
+
+// -----
+
+func.func @from_elements_wrong_num_operands(%a: f32) {
   // expected-error @+1 {{'vector.from_elements' number of operands and types do not match: got 1 operands and 2 types}}
   vector.from_elements %a : vector<2xf32>
   return
@@ -1905,12 +1922,11 @@ func.func @invalid_from_elements(%a: f32) {
 // -----
 
 // expected-note @+1 {{prior use here}}
-func.func @invalid_from_elements(%a: f32, %b: i32) {
+func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
   // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
   vector.from_elements %a, %b : vector<2xf32>
   return
 }
-
 // -----
 
 func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f3220aed4360c..7cfe4e89d6e2f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1175,6 +1175,25 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
   return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
 }
 
+// CHECK-LABEL: func @to_elements(
+//  CHECK-SAME:     %[[A_VEC:.*]]: vector<f32>, %[[B_VEC:.*]]: vector<4xf32>,
+//  CHECK-SAME:     %[[C_VEC:.*]]: vector<1xf32>, %[[D_VEC:.*]]: vector<2x2xf32>)
+func.func @to_elements(%a_vec : vector<f32>, %b_vec : vector<4xf32>, %c_vec : vector<1xf32>,  %d_vec : vector<2x2xf32>)
+                   -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+  // CHECK: %[[A_ELEMS:.*]] = vector.to_elements %[[A_VEC]] : vector<f32>
+  %0 = vector.to_elements %a_vec : vector<f32>
+  // CHECK: %[[B_ELEMS:.*]]:4 = vector.to_elements %[[B_VEC]] : vector<4xf32>
+  %1:4 = vector.to_elements %b_vec : vector<4xf32>
+  // CHECK: %[[C_ELEMS:.*]] = vector.to_elements %[[C_VEC]] : vector<1xf32>
+  %2 = vector.to_elements %c_vec : vector<1xf32>
+  // CHECK: %[[D_ELEMS:.*]]:4 = vector.to_elements %[[D_VEC]] : vector<2x2xf32>
+  %3:4 = vector.to_elements %d_vec : vector<2x2xf32>
+  //      CHECK: return %[[A_ELEMS]], %[[B_ELEMS]]#0, %[[B_ELEMS]]#1, %[[B_ELEMS]]#2,
+  // CHECK-SAME:   %[[B_ELEMS]]#3, %[[C_ELEMS]], %[[D_ELEMS]]#0, %[[D_ELEMS]]#1,
+  // CHECK-SAME:   %[[D_ELEMS]]#2, %[[D_ELEMS]]#3
+  return %0, %1#0, %1#1, %1#2, %1#3, %2, %3#0, %3#1, %3#2, %3#3 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
+}
+
 // CHECK-LABEL: func @from_elements(
 //  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
 func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {

>From ea17e128f05c9f072371ec850d02891ab0d9d5e5 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Wed, 11 Jun 2025 18:43:56 +0000
Subject: [PATCH 2/2] Feedback

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 12 ++++++------
 mlir/test/Dialect/Vector/invalid.mlir            |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e1fabb9389b5c..4e1ee145f44ef 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -793,7 +793,7 @@ def Vector_FMAOp :
 def Vector_ToElementsOp : Vector_Op<"to_elements", [
     Pure,
     TypesMatchWith<"operand element type matches result types",
-                   "input", "elements", "SmallVector<Type>("
+                   "source", "elements", "SmallVector<Type>("
                    "::llvm::cast<VectorType>($_self).getNumElements(), "
                    "::llvm::cast<VectorType>($_self).getElementType())">]> {
   let summary = "operation that decomposes a vector into all its scalar elements";
@@ -836,15 +836,15 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
     ```
   }];
 
-  let arguments = (ins AnyVectorOfAnyRank:$input);
+  let arguments = (ins AnyVectorOfAnyRank:$source);
   let results = (outs Variadic<AnyType>:$elements);
-  let assemblyFormat = "$input attr-dict `:` type($input)";
+  let assemblyFormat = "$source attr-dict `:` type($source)";
 }
 
 def Vector_FromElementsOp : Vector_Op<"from_elements", [
     Pure,
     TypesMatchWith<"operand types match result element type",
-                   "result", "elements", "SmallVector<Type>("
+                   "dest", "elements", "SmallVector<Type>("
                    "::llvm::cast<VectorType>($_self).getNumElements(), "
                    "::llvm::cast<VectorType>($_self).getElementType())">]> {
   let summary = "operation that defines a vector from scalar elements";
@@ -877,8 +877,8 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
   }];
 
   let arguments = (ins Variadic<AnyType>:$elements);
-  let results = (outs AnyFixedVectorOfAnyRank:$result);
-  let assemblyFormat = "$elements attr-dict `:` type($result)";
+  let results = (outs AnyFixedVectorOfAnyRank:$dest);
+  let assemblyFormat = "$elements attr-dict `:` type($dest)";
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 70a7274182442..ec7cee7b2c641 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1930,7 +1930,7 @@ func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
 // -----
 
 func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
-  // expected-error @+1 {{'result' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
+  // expected-error @+1 {{'dest' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
   vector.from_elements %a, %b : vector<[2]xf32>
   return
 }



More information about the Mlir-commits mailing list