[Mlir-commits] [mlir] [mlir][vector][spirv] Lower `vector.to_elements` to SPIR-V (PR #146618)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 1 18:53:48 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Eric Feng (efric)

<details>
<summary>Changes</summary>

Implement `vector.to_elements` lowering to SPIR-V.

Addresses [145929](https://github.com/llvm/llvm-project/issues/145929).

---
Full diff: https://github.com/llvm/llvm-project/pull/146618.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+46-1) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+35) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index de2af69eba9ec..475fd76c667e6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1022,6 +1022,51 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
   }
 };
 
+/// Lowers `vector.to_elements` to SPIR-V.
+///
+/// Every result that has uses is produced by a `spirv.CompositeExtract` at the
+/// matching index of the converted source; dead results are not materialized.
+/// A single-element source vector is converted to a scalar by the type
+/// converter, in which case that scalar is forwarded as the only result.
+struct VectorToElementOpConvert final
+    : public OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All type checks happen before any IR is created: a conversion pattern
+    // must not report failure after it has already modified the IR.
+    VectorType srcVectorType =
+        cast<VectorType>(toElementsOp.getSource().getType());
+    if (!getTypeConverter()->convertType(srcVectorType))
+      return rewriter.notifyMatchFailure(toElementsOp,
+                                         "unsupported source vector type");
+
+    SmallVector<Value> results(toElementsOp->getNumResults());
+    Location loc = toElementsOp.getLoc();
+    Value source = adaptor.getSource();
+
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // `spirv::CompositeExtractOp` cannot be applied to a scalar, so the
+    // converted source itself is the single result.
+    if (isa<spirv::ScalarType>(source.getType())) {
+      results[0] = source;
+      rewriter.replaceOp(toElementsOp, results);
+      return success();
+    }
+
+    // Every result has the source's element type, so convert it once here
+    // rather than per result inside the loop below (where a failure would
+    // leave already-created extracts behind).
+    Type elementType =
+        getTypeConverter()->convertType(srcVectorType.getElementType());
+    if (!elementType)
+      return rewriter.notifyMatchFailure(toElementsOp,
+                                         "unsupported vector element type");
+
+    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      // Create a CompositeExtract only for results that are not dead.
+      if (element.use_empty())
+        continue;
+
+      Value extracted = rewriter.create<spirv::CompositeExtractOp>(
+          loc, elementType, source,
+          rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
+      results[idx] = extracted;
+    }
+
+    rewriter.replaceOp(toElementsOp, results);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -1038,7 +1083,7 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorBitcastConvert, VectorBroadcastConvert,
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
-      VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, VectorToElementOpConvert,
       VectorInsertElementOpConvert, VectorInsertOpConvert,
       VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 4701ac5d96009..99ab0e1dc4eef 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -246,6 +246,41 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
 
 // -----
 
+// A single-element vector becomes a scalar after type conversion, so the
+// source is forwarded directly and no composite extraction is emitted.
+// CHECK-LABEL: func.func @to_elements_one_element
+// CHECK-SAME:     %[[A:.*]]: vector<1xf32>)
+//      CHECK:   %[[ELEM0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1xf32> to f32
+//  CHECK-NOT:   spirv.CompositeExtract
+//      CHECK:   return %[[ELEM0]] : f32
+func.func @to_elements_one_element(%a: vector<1xf32>) -> (f32) {
+  %0:1 = vector.to_elements %a : vector<1xf32>
+  return %0#0 : f32
+}
+
+// Every result is used, so each lane gets its own composite extraction.
+// CHECK-LABEL: func.func @to_elements_no_dead_elements
+// CHECK-SAME:     %[[SRC:.*]]: vector<4xf32>)
+//      CHECK:   %[[E0:.*]] = spirv.CompositeExtract %[[SRC]][0 : i32] : vector<4xf32>
+//      CHECK:   %[[E1:.*]] = spirv.CompositeExtract %[[SRC]][1 : i32] : vector<4xf32>
+//      CHECK:   %[[E2:.*]] = spirv.CompositeExtract %[[SRC]][2 : i32] : vector<4xf32>
+//      CHECK:   %[[E3:.*]] = spirv.CompositeExtract %[[SRC]][3 : i32] : vector<4xf32>
+//      CHECK:   return %[[E0]], %[[E1]], %[[E2]], %[[E3]] : f32, f32, f32, f32
+func.func @to_elements_no_dead_elements(%v: vector<4xf32>) -> (f32, f32, f32, f32) {
+  %elems:4 = vector.to_elements %v : vector<4xf32>
+  return %elems#0, %elems#1, %elems#2, %elems#3 : f32, f32, f32, f32
+}
+
+// Lanes 0 and 2 are unused: only lanes 1 and 3 may be extracted.
+// CHECK-LABEL: func.func @to_elements_dead_elements
+// CHECK-SAME:     %[[SRC:.*]]: vector<4xf32>)
+//  CHECK-NOT:   spirv.CompositeExtract %[[SRC]][0 : i32]
+//      CHECK:   %[[LANE1:.*]] = spirv.CompositeExtract %[[SRC]][1 : i32] : vector<4xf32>
+//  CHECK-NOT:   spirv.CompositeExtract %[[SRC]][2 : i32]
+//      CHECK:   %[[LANE3:.*]] = spirv.CompositeExtract %[[SRC]][3 : i32] : vector<4xf32>
+//      CHECK:   return %[[LANE1]], %[[LANE3]] : f32, f32
+func.func @to_elements_dead_elements(%v: vector<4xf32>) -> (f32, f32) {
+  %elems:4 = vector.to_elements %v : vector<4xf32>
+  return %elems#1, %elems#3 : f32, f32
+}
+
+// -----
+
 // CHECK-LABEL: @from_elements_0d
 //  CHECK-SAME: %[[ARG0:.+]]: f32
 //       CHECK:   %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/146618


More information about the Mlir-commits mailing list