[Mlir-commits] [mlir] [mlir][vector][spirv] Lower `vector.to_elements` to SPIR-V (PR #146618)

Eric Feng llvmlistbot at llvm.org
Wed Jul 2 11:40:27 PDT 2025


https://github.com/efric updated https://github.com/llvm/llvm-project/pull/146618

>From 2141816222b374c296f432e441334bd6ead2d5b1 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Mon, 30 Jun 2025 23:28:50 -0700
Subject: [PATCH 1/8] Add initial vector.to_elements to SPIR-V conversion pattern

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 38 +++++++++++++++++++
 .../VectorToSPIRV/vector-to-spirv.mlir        | 30 +++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index de2af69eba9ec..8e38b5280c527 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1022,6 +1022,44 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
   }
 };
 
+struct VectorToElementOpConvert final
+    : public OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type srcType =
+        getTypeConverter()->convertType(toElementsOp.getSource().getType());
+    if (!srcType)
+      return failure();
+
+    // If the input vector was size 1, then it would have been converted to a
+    // scalar. Replace with it directly
+    if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
+      rewriter.replaceOp(toElementsOp, adaptor.getSource());
+      return success();
+    }
+
+    Location loc = toElementsOp.getLoc();
+    SmallVector<Value> results(toElementsOp->getNumResults());
+
+    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      if (element.use_empty())
+        continue;
+
+      Value result = rewriter.create<spirv::CompositeExtractOp>(
+          loc, toElementsOp->getResult(idx).getType(), adaptor.getSource(),
+          rewriter.getI32ArrayAttr({static_cast<int>(idx)}));
+      results[idx] = result;
+    }
+
+    rewriter.replaceOp(toElementsOp, results);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 4701ac5d96009..c8d0c25f252cf 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -244,6 +244,36 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
   return %0: f32
 }
 
+// -----
+// we need a test for with dead, no dead, 1 element
+
+// CHECK-LABEL: @from_elements_0d
+//  CHECK-SAME: %[[ARG0:.+]]: f32
+//       CHECK:   %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+//       CHECK:   return %[[RETVAL]]
+func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
+  %0 = vector.from_elements %arg0 : vector<f32>
+  return %0: vector<f32>
+}
+
+// CHECK-LABEL: @from_elements_1x
+//  CHECK-SAME: %[[ARG0:.+]]: f32
+//       CHECK:   %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+//       CHECK:   return %[[RETVAL]]
+func.func @from_elements_1x(%arg0 : f32) -> vector<1xf32> {
+  %0 = vector.from_elements %arg0 : vector<1xf32>
+  return %0: vector<1xf32>
+}
+
+// CHECK-LABEL: @from_elements_3x
+//  CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
+//       CHECK:   %[[RETVAL:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
+//       CHECK:   return %[[RETVAL]]
+func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
+  %0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
+  return %0: vector<3xf32>
+}
+
 // -----
 
 // CHECK-LABEL: @from_elements_0d

>From 9fbc662254359835d5dccbbb264fcc03feb34088 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 1 Jul 2025 14:12:19 -0700
Subject: [PATCH 2/8] Use the converted element type for extracts and register the to_elements pattern

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 8e38b5280c527..b51553ad9d33f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -1049,9 +1050,10 @@ struct VectorToElementOpConvert final
       if (element.use_empty())
         continue;
 
+      auto spirvType = getTypeConverter()->convertType(element.getType());
       Value result = rewriter.create<spirv::CompositeExtractOp>(
-          loc, toElementsOp->getResult(idx).getType(), adaptor.getSource(),
-          rewriter.getI32ArrayAttr({static_cast<int>(idx)}));
+          loc, spirvType, adaptor.getSource(),
+          rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
       results[idx] = result;
     }
 
@@ -1076,7 +1078,7 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorBitcastConvert, VectorBroadcastConvert,
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
-      VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, VectorToElementOpConvert,
       VectorInsertElementOpConvert, VectorInsertOpConvert,
       VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,

>From 751cdd47c9fc15f9315de484e78ffda1a34d4624 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 1 Jul 2025 15:54:03 -0700
Subject: [PATCH 3/8] Fix single-element replacement and check the converted element type

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b51553ad9d33f..9c7a413db8d35 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -1036,23 +1035,27 @@ struct VectorToElementOpConvert final
     if (!srcType)
       return failure();
 
-    // If the input vector was size 1, then it would have been converted to a
-    // scalar. Replace with it directly
+    SmallVector<Value> results(toElementsOp->getNumResults());
+    Location loc = toElementsOp.getLoc();
+
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // We cannot use `spirv::CompositeExtractOp` directly in this case.
+    // For a scalar source, the result is just the scalar itself.
     if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
-      rewriter.replaceOp(toElementsOp, adaptor.getSource());
+      results[0] = adaptor.getSource();
+      rewriter.replaceOp(toElementsOp, results);
       return success();
     }
 
-    Location loc = toElementsOp.getLoc();
-    SmallVector<Value> results(toElementsOp->getNumResults());
-
     for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
       if (element.use_empty())
         continue;
 
-      auto spirvType = getTypeConverter()->convertType(element.getType());
+      auto elementType = getTypeConverter()->convertType(element.getType());
+      if (!elementType)
+        return failure();
       Value result = rewriter.create<spirv::CompositeExtractOp>(
-          loc, spirvType, adaptor.getSource(),
+          loc, elementType, adaptor.getSource(),
           rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
       results[idx] = result;
     }

>From 33b28dedfaa33d2e301c290ed351a7c79d444be3 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 1 Jul 2025 18:26:08 -0700
Subject: [PATCH 4/8] Add vector.to_elements lowering tests and document skipping of dead results

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           |  1 +
 .../VectorToSPIRV/vector-to-spirv.mlir        | 57 ++++++++++---------
 2 files changed, 32 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 9c7a413db8d35..b350ef7a42b95 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1048,6 +1048,7 @@ struct VectorToElementOpConvert final
     }
 
     for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      // Create an CompositeExtract operation only for results that are not dead.
       if (element.use_empty())
         continue;
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c8d0c25f252cf..99ab0e1dc4eef 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -245,33 +245,38 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
 }
 
 // -----
-// we need a test for with dead, no dead, 1 element
 
-// CHECK-LABEL: @from_elements_0d
-//  CHECK-SAME: %[[ARG0:.+]]: f32
-//       CHECK:   %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
-//       CHECK:   return %[[RETVAL]]
-func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
-  %0 = vector.from_elements %arg0 : vector<f32>
-  return %0: vector<f32>
-}
-
-// CHECK-LABEL: @from_elements_1x
-//  CHECK-SAME: %[[ARG0:.+]]: f32
-//       CHECK:   %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
-//       CHECK:   return %[[RETVAL]]
-func.func @from_elements_1x(%arg0 : f32) -> vector<1xf32> {
-  %0 = vector.from_elements %arg0 : vector<1xf32>
-  return %0: vector<1xf32>
-}
-
-// CHECK-LABEL: @from_elements_3x
-//  CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
-//       CHECK:   %[[RETVAL:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
-//       CHECK:   return %[[RETVAL]]
-func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
-  %0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
-  return %0: vector<3xf32>
+// CHECK-LABEL: func.func @to_elements_one_element 
+// CHECK-SAME:     %[[A:.*]]: vector<1xf32>)
+//      CHECK:   %[[ELEM0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1xf32> to f32
+//      CHECK:   return %[[ELEM0]] : f32
+func.func @to_elements_one_element(%a: vector<1xf32>) -> (f32) {
+  %0:1 = vector.to_elements %a : vector<1xf32>
+  return %0#0 : f32
+}
+
+// CHECK-LABEL: func.func @to_elements_no_dead_elements
+// CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+//      CHECK:   %[[ELEM0:.*]] = spirv.CompositeExtract %[[A]][0 : i32] : vector<4xf32>
+//      CHECK:   %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
+//      CHECK:   %[[ELEM2:.*]] = spirv.CompositeExtract %[[A]][2 : i32] : vector<4xf32>
+//      CHECK:   %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
+//      CHECK:   return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
+func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// CHECK-LABEL: func.func @to_elements_dead_elements
+// CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+//  CHECK-NOT:   spirv.CompositeExtract %[[A]][0 : i32]
+//      CHECK:   %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
+//  CHECK-NOT:   spirv.CompositeExtract %[[A]][2 : i32]
+//      CHECK:   %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
+//      CHECK:   return %[[ELEM1]], %[[ELEM3]] : f32, f32
+func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#1, %0#3 : f32, f32
 }
 
 // -----

>From b02c31ed0b164b7928362b71407f2c5967bd3e4d Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Tue, 1 Jul 2025 18:33:58 -0700
Subject: [PATCH 5/8] Apply clang-format to the dead-result comment

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b350ef7a42b95..475fd76c667e6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1048,7 +1048,8 @@ struct VectorToElementOpConvert final
     }
 
     for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
-      // Create an CompositeExtract operation only for results that are not dead.
+      // Create an CompositeExtract operation only for results that are not
+      // dead.
       if (element.use_empty())
         continue;
 

>From 98ab5f3100ce3ac436bdd217671e85683bced4b5 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 2 Jul 2025 10:49:57 -0700
Subject: [PATCH 6/8] Drop redundant source type conversion and hoist element type conversion out of the loop

---
 .../Conversion/VectorToSPIRV/VectorToSPIRV.cpp   | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 475fd76c667e6..c79b368b333e4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1023,18 +1023,13 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
 };
 
 struct VectorToElementOpConvert final
-    : public OpConversionPattern<vector::ToElementsOp> {
+    : OpConversionPattern<vector::ToElementsOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    Type srcType =
-        getTypeConverter()->convertType(toElementsOp.getSource().getType());
-    if (!srcType)
-      return failure();
-
     SmallVector<Value> results(toElementsOp->getNumResults());
     Location loc = toElementsOp.getLoc();
 
@@ -1047,15 +1042,18 @@ struct VectorToElementOpConvert final
       return success();
     }
 
+    Type srcElementType = toElementsOp.getElements().getType().front();
+    Type elementType = getTypeConverter()->convertType(srcElementType);
+    if (!elementType)
+      return rewriter.notifyMatchFailure(
+          toElementsOp, "unsupported element type in source vector");
+
     for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
       // Create an CompositeExtract operation only for results that are not
       // dead.
       if (element.use_empty())
         continue;
 
-      auto elementType = getTypeConverter()->convertType(element.getType());
-      if (!elementType)
-        return failure();
       Value result = rewriter.create<spirv::CompositeExtractOp>(
           loc, elementType, adaptor.getSource(),
           rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));

>From 18c09ce355f509d0f9339dd9e3d53fd8948090be Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 2 Jul 2025 11:32:43 -0700
Subject: [PATCH 7/8] Include the unconvertible element type in the match-failure message

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c79b368b333e4..abe0b55692d88 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1046,7 +1046,9 @@ struct VectorToElementOpConvert final
     Type elementType = getTypeConverter()->convertType(srcElementType);
     if (!elementType)
       return rewriter.notifyMatchFailure(
-          toElementsOp, "unsupported element type in source vector");
+          toElementsOp,
+          llvm::formatv("failed to convert element type '{0}' to SPIR-V",
+                        srcElementType));
 
     for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
       // Create an CompositeExtract operation only for results that are not

>From ba50736a874ad526d5527f13b95036c13b5e6e57 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Wed, 2 Jul 2025 11:41:21 -0700
Subject: [PATCH 8/8] Apply clang-format to the pattern registration list

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index abe0b55692d88..21d8e1d9f1156 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -1083,9 +1083,9 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorBitcastConvert, VectorBroadcastConvert,
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
-      VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, VectorToElementOpConvert,
-      VectorInsertElementOpConvert, VectorInsertOpConvert,
-      VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
+      VectorToElementOpConvert, VectorInsertElementOpConvert,
+      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,



More information about the Mlir-commits mailing list