[Mlir-commits] [mlir] [mlir][SPIRV][NFC] Refactor getElementTypeForStoragePointer for load/store lowering. (PR #183459)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 25 22:48:57 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>

The revision refactors the repeated logic for extracting the storage element type from a SPIR-V pointer type into a shared helper function. This pattern was duplicated in AtomicRMWOpPattern, IntLoadOpPattern, and IntStoreOpPattern.

The helper handles both Kernel capability (direct array/scalar) and Vulkan (struct-wrapped array/runtime array) cases.

It is a follow-up to https://github.com/llvm/llvm-project/commit/f80205becd384e73e3dfc6ece97297ab1e8a35f9

---
Full diff: https://github.com/llvm/llvm-project/pull/183459.diff


1 Files Affected:

- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+26-48) 


``````````diff
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 81116cf6f13ad..565dee6f27589 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -163,6 +163,27 @@ static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
   return {};
 }
 
+/// Returns the element type stored behind a SPIR-V storage pointer.
+///
+/// When the Kernel capability is allowed, `pointeeType` is the element type
+/// itself or an array of it. Otherwise (e.g. Vulkan), `pointeeType` is
+/// expected to be a struct whose first member is an array or runtime array of
+/// the element type; the `cast`s below assume that layout.
+static Type
+getElementTypeForStoragePointer(Type pointeeType,
+                                const SPIRVTypeConverter &typeConverter) {
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
+      return arrayType.getElementType();
+    return pointeeType;
+  }
+  // For Vulkan we need to extract element from wrapping struct and array.
+  Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
+  if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
+    return arrayType.getElementType();
+  return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
+}
+
 /// Casts the given `srcInt` into a boolean value.
 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
   if (srcInt.getType().isInteger(1))
@@ -437,22 +458,8 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                        "failed to convert memref type");
 
   Type pointeeType = pointerType.getPointeeType();
-  IntegerType dstType;
-  if (typeConverter.allows(spirv::Capability::Kernel)) {
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
-      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
-    else
-      dstType = dyn_cast<IntegerType>(pointeeType);
-  } else {
-    Type structElemType =
-        cast<spirv::StructType>(pointeeType).getElementType(0);
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
-      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
-    else
-      dstType = dyn_cast<IntegerType>(
-          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
-  }
-
+  auto dstType = dyn_cast<IntegerType>(
+      getElementTypeForStoragePointer(pointeeType, typeConverter));
   if (!dstType)
     return rewriter.notifyMatchFailure(
         atomicOp, "failed to determine destination element type");
@@ -691,21 +698,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
     return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
 
   Type pointeeType = pointerType.getPointeeType();
-  Type dstType;
-  if (typeConverter.allows(spirv::Capability::Kernel)) {
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
-      dstType = arrayType.getElementType();
-    else
-      dstType = pointeeType;
-  } else {
-    // For Vulkan we need to extract element from wrapping struct and array.
-    Type structElemType =
-        cast<spirv::StructType>(pointeeType).getElementType(0);
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
-      dstType = arrayType.getElementType();
-    else
-      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
-  }
+  Type dstType = getElementTypeForStoragePointer(pointeeType, typeConverter);
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
 
@@ -963,23 +956,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                        "failed to convert memref type");
 
   Type pointeeType = pointerType.getPointeeType();
-  IntegerType dstType;
-  if (typeConverter.allows(spirv::Capability::Kernel)) {
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
-      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
-    else
-      dstType = dyn_cast<IntegerType>(pointeeType);
-  } else {
-    // For Vulkan we need to extract element from wrapping struct and array.
-    Type structElemType =
-        cast<spirv::StructType>(pointeeType).getElementType(0);
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
-      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
-    else
-      dstType = dyn_cast<IntegerType>(
-          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
-  }
-
+  auto dstType = dyn_cast<IntegerType>(
+      getElementTypeForStoragePointer(pointeeType, typeConverter));
   if (!dstType)
     return rewriter.notifyMatchFailure(
         storeOp, "failed to determine destination element type");

``````````

</details>


https://github.com/llvm/llvm-project/pull/183459


More information about the Mlir-commits mailing list