[Mlir-commits] [mlir] a8198bd - [mlir][SPIRV][NFC] Refactor getElementTypeForStoragePointer for load/store lowering. (#183459)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 26 07:04:41 PST 2026


Author: Han-Chung Wang
Date: 2026-02-26T07:04:36-08:00
New Revision: a8198bdd91f5daf751627c8b759680470979c26b

URL: https://github.com/llvm/llvm-project/commit/a8198bdd91f5daf751627c8b759680470979c26b
DIFF: https://github.com/llvm/llvm-project/commit/a8198bdd91f5daf751627c8b759680470979c26b.diff

LOG: [mlir][SPIRV][NFC] Refactor getElementTypeForStoragePointer for load/store lowering. (#183459)

This revision factors the repeated logic for extracting the storage
element type from a SPIR-V pointer type out into a shared helper function.
The logic was duplicated in AtomicRMWOpPattern, IntLoadOpPattern, and
IntStoreOpPattern.

The helper handles both Kernel capability (direct array/scalar) and
Vulkan (struct-wrapped array/runtime array) cases.

It is a follow-up to
https://github.com/llvm/llvm-project/commit/f80205becd384e73e3dfc6ece97297ab1e8a35f9

Signed-off-by: hanhanW <hanhan0912 at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 81116cf6f13ad..565dee6f27589 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -163,6 +163,27 @@ static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
   return {};
 }
 
+/// Extracts the element type from a SPIR-V pointer type pointing to storage.
+///
+/// For Kernel capability, the pointer points directly to the element type
+/// (possibly wrapped in an array). For Vulkan, the pointer points to a struct
+/// containing an array or runtime array, and we need to unwrap to get the
+/// element type.
+static Type
+getElementTypeForStoragePointer(Type pointeeType,
+                                const SPIRVTypeConverter &typeConverter) {
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
+      return arrayType.getElementType();
+    return pointeeType;
+  }
+  // For Vulkan we need to extract element from wrapping struct and array.
+  Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
+  if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
+    return arrayType.getElementType();
+  return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
+}
+
 /// Casts the given `srcInt` into a boolean value.
 static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
   if (srcInt.getType().isInteger(1))
@@ -437,22 +458,8 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                        "failed to convert memref type");
 
   Type pointeeType = pointerType.getPointeeType();
-  IntegerType dstType;
-  if (typeConverter.allows(spirv::Capability::Kernel)) {
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
-      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
-    else
-      dstType = dyn_cast<IntegerType>(pointeeType);
-  } else {
-    Type structElemType =
-        cast<spirv::StructType>(pointeeType).getElementType(0);
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
-      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
-    else
-      dstType = dyn_cast<IntegerType>(
-          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
-  }
-
+  auto dstType = dyn_cast<IntegerType>(
+      getElementTypeForStoragePointer(pointeeType, typeConverter));
   if (!dstType)
     return rewriter.notifyMatchFailure(
         atomicOp, "failed to determine destination element type");
@@ -691,21 +698,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
     return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
 
   Type pointeeType = pointerType.getPointeeType();
-  Type dstType;
-  if (typeConverter.allows(spirv::Capability::Kernel)) {
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
-      dstType = arrayType.getElementType();
-    else
-      dstType = pointeeType;
-  } else {
-    // For Vulkan we need to extract element from wrapping struct and array.
-    Type structElemType =
-        cast<spirv::StructType>(pointeeType).getElementType(0);
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
-      dstType = arrayType.getElementType();
-    else
-      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
-  }
+  Type dstType = getElementTypeForStoragePointer(pointeeType, typeConverter);
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
 
@@ -963,23 +956,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                        "failed to convert memref type");
 
   Type pointeeType = pointerType.getPointeeType();
-  IntegerType dstType;
-  if (typeConverter.allows(spirv::Capability::Kernel)) {
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
-      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
-    else
-      dstType = dyn_cast<IntegerType>(pointeeType);
-  } else {
-    // For Vulkan we need to extract element from wrapping struct and array.
-    Type structElemType =
-        cast<spirv::StructType>(pointeeType).getElementType(0);
-    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
-      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
-    else
-      dstType = dyn_cast<IntegerType>(
-          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
-  }
-
+  auto dstType = dyn_cast<IntegerType>(
+      getElementTypeForStoragePointer(pointeeType, typeConverter));
   if (!dstType)
     return rewriter.notifyMatchFailure(
         storeOp, "failed to determine destination element type");


        


More information about the Mlir-commits mailing list