[Mlir-commits] [mlir] 274152c - [mlir][vector][spirv] Lower `vector.to_elements` to SPIR-V (#146618)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 2 12:56:22 PDT 2025


Author: Eric Feng
Date: 2025-07-02T15:56:18-04:00
New Revision: 274152c5fa9f642d5ce6317ca24c0f2f27a53576

URL: https://github.com/llvm/llvm-project/commit/274152c5fa9f642d5ce6317ca24c0f2f27a53576
DIFF: https://github.com/llvm/llvm-project/commit/274152c5fa9f642d5ce6317ca24c0f2f27a53576.diff

LOG: [mlir][vector][spirv] Lower `vector.to_elements` to SPIR-V (#146618)

Implement `vector.to_elements` lowering to SPIR-V.

Fixes: https://github.com/llvm/llvm-project/issues/145929

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index de2af69eba9ec..21d8e1d9f1156 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1022,6 +1022,51 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
   }
 };
 
+struct VectorToElementOpConvert final
+    : OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+  /// Replaces each used `vector.to_elements` result with a SPIR-V value.
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // One slot per result; slots of results that are never used stay null.
+    SmallVector<Value> results(toElementsOp->getNumResults());
+    Location loc = toElementsOp.getLoc();
+
+    // The type converter turns single-element vectors into scalars, so
+    // `spirv::CompositeExtractOp` cannot be used directly on the source here.
+    // In that case the only result is the converted scalar source itself.
+    if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
+      results[0] = adaptor.getSource();
+      rewriter.replaceOp(toElementsOp, results);
+      return success();
+    }
+    // The first result's type is converted once and reused for every extract.
+    Type srcElementType = toElementsOp.getElements().getType().front();
+    Type elementType = getTypeConverter()->convertType(srcElementType);
+    if (!elementType)
+      return rewriter.notifyMatchFailure(
+          toElementsOp,
+          llvm::formatv("failed to convert element type '{0}' to SPIR-V",
+                        srcElementType));
+
+    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      // Create a CompositeExtract operation only for results that have uses;
+      // dead results are skipped and keep their null slot in `results`.
+      if (element.use_empty())
+        continue;
+
+      Value result = rewriter.create<spirv::CompositeExtractOp>(
+          loc, elementType, adaptor.getSource(),
+          rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
+      results[idx] = result;
+    }
+
+    rewriter.replaceOp(toElementsOp, results);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -1039,8 +1084,8 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
       VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
-      VectorInsertElementOpConvert, VectorInsertOpConvert,
-      VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorToElementOpConvert, VectorInsertElementOpConvert,
+      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 4701ac5d96009..99ab0e1dc4eef 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -246,6 +246,41 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func.func @to_elements_one_element 
+// CHECK-SAME:     %[[A:.*]]: vector<1xf32>)
+//      CHECK:   %[[ELEM0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1xf32> to f32
+//      CHECK:   return %[[ELEM0]] : f32
+func.func @to_elements_one_element(%a: vector<1xf32>) -> (f32) {
+  %0:1 = vector.to_elements %a : vector<1xf32>
+  return %0#0 : f32
+}
+
+// CHECK-LABEL: func.func @to_elements_no_dead_elements
+// CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+//      CHECK:   %[[ELEM0:.*]] = spirv.CompositeExtract %[[A]][0 : i32] : vector<4xf32>
+//      CHECK:   %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
+//      CHECK:   %[[ELEM2:.*]] = spirv.CompositeExtract %[[A]][2 : i32] : vector<4xf32>
+//      CHECK:   %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
+//      CHECK:   return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
+func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// CHECK-LABEL: func.func @to_elements_dead_elements
+// CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+//  CHECK-NOT:   spirv.CompositeExtract %[[A]][0 : i32]
+//      CHECK:   %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
+//  CHECK-NOT:   spirv.CompositeExtract %[[A]][2 : i32]
+//      CHECK:   %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
+//      CHECK:   return %[[ELEM1]], %[[ELEM3]] : f32, f32
+func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#1, %0#3 : f32, f32
+}
+
+// -----
+
 // CHECK-LABEL: @from_elements_0d
 //  CHECK-SAME: %[[ARG0:.+]]: f32
 //       CHECK:   %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]


        


More information about the Mlir-commits mailing list