[Mlir-commits] [mlir] a8198bd - [mlir][SPIRV][NFC] Refactor getElementTypeForStoragePointer for load/store lowering. (#183459)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 26 07:04:41 PST 2026
Author: Han-Chung Wang
Date: 2026-02-26T07:04:36-08:00
New Revision: a8198bdd91f5daf751627c8b759680470979c26b
URL: https://github.com/llvm/llvm-project/commit/a8198bdd91f5daf751627c8b759680470979c26b
DIFF: https://github.com/llvm/llvm-project/commit/a8198bdd91f5daf751627c8b759680470979c26b.diff
LOG: [mlir][SPIRV][NFC] Refactor getElementTypeForStoragePointer for load/store lowering. (#183459)
The revision refactors the repeated logic for extracting the storage
element type from a SPIR-V pointer type into a shared helper function.
This pattern was duplicated in AtomicRMWOpPattern, IntLoadOpPattern, and
IntStoreOpPattern.
The helper handles both Kernel capability (direct array/scalar) and
Vulkan (struct-wrapped array/runtime array) cases.
It is a follow-up to
https://github.com/llvm/llvm-project/commit/f80205becd384e73e3dfc6ece97297ab1e8a35f9
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 81116cf6f13ad..565dee6f27589 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -163,6 +163,27 @@ static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
return {};
}
+/// Extracts the element type from a SPIR-V pointer type pointing to storage.
+///
+/// For Kernel capability, the pointer points directly to the element type
+/// (possibly wrapped in an array). For Vulkan, the pointer points to a struct
+/// containing an array or runtime array, and we need to unwrap to get the
+/// element type.
+static Type
+getElementTypeForStoragePointer(Type pointeeType,
+ const SPIRVTypeConverter &typeConverter) {
+ if (typeConverter.allows(spirv::Capability::Kernel)) {
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
+ return arrayType.getElementType();
+ return pointeeType;
+ }
+ // For Vulkan we need to extract element from wrapping struct and array.
+ Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
+ return arrayType.getElementType();
+ return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
+}
+
/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
if (srcInt.getType().isInteger(1))
@@ -437,22 +458,8 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
"failed to convert memref type");
Type pointeeType = pointerType.getPointeeType();
- IntegerType dstType;
- if (typeConverter.allows(spirv::Capability::Kernel)) {
- if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
- dstType = dyn_cast<IntegerType>(arrayType.getElementType());
- else
- dstType = dyn_cast<IntegerType>(pointeeType);
- } else {
- Type structElemType =
- cast<spirv::StructType>(pointeeType).getElementType(0);
- if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
- dstType = dyn_cast<IntegerType>(arrayType.getElementType());
- else
- dstType = dyn_cast<IntegerType>(
- cast<spirv::RuntimeArrayType>(structElemType).getElementType());
- }
-
+ auto dstType = dyn_cast<IntegerType>(
+ getElementTypeForStoragePointer(pointeeType, typeConverter));
if (!dstType)
return rewriter.notifyMatchFailure(
atomicOp, "failed to determine destination element type");
@@ -691,21 +698,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
Type pointeeType = pointerType.getPointeeType();
- Type dstType;
- if (typeConverter.allows(spirv::Capability::Kernel)) {
- if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
- dstType = arrayType.getElementType();
- else
- dstType = pointeeType;
- } else {
- // For Vulkan we need to extract element from wrapping struct and array.
- Type structElemType =
- cast<spirv::StructType>(pointeeType).getElementType(0);
- if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
- dstType = arrayType.getElementType();
- else
- dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
- }
+ Type dstType = getElementTypeForStoragePointer(pointeeType, typeConverter);
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
@@ -963,23 +956,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
"failed to convert memref type");
Type pointeeType = pointerType.getPointeeType();
- IntegerType dstType;
- if (typeConverter.allows(spirv::Capability::Kernel)) {
- if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
- dstType = dyn_cast<IntegerType>(arrayType.getElementType());
- else
- dstType = dyn_cast<IntegerType>(pointeeType);
- } else {
- // For Vulkan we need to extract element from wrapping struct and array.
- Type structElemType =
- cast<spirv::StructType>(pointeeType).getElementType(0);
- if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
- dstType = dyn_cast<IntegerType>(arrayType.getElementType());
- else
- dstType = dyn_cast<IntegerType>(
- cast<spirv::RuntimeArrayType>(structElemType).getElementType());
- }
-
+ auto dstType = dyn_cast<IntegerType>(
+ getElementTypeForStoragePointer(pointeeType, typeConverter));
if (!dstType)
return rewriter.notifyMatchFailure(
storeOp, "failed to determine destination element type");
More information about the Mlir-commits
mailing list