[Mlir-commits] [mlir] [mlir][SPIRV] Add sub-element-byte lowering support for atomic_rmw ori/andi ops (PR #179831)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Feb 23 08:12:14 PST 2026
================
@@ -427,27 +427,137 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
if (!ptr)
return failure();
+ // Determine the source and destination bitwidths. The source is the original
+ // memref element type and the destination is the SPIR-V storage type (e.g.,
+ // i32 for Vulkan).
+ int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+ auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
+ if (!pointerType)
+ return rewriter.notifyMatchFailure(atomicOp,
+ "failed to convert memref type");
+
+ Type pointeeType = pointerType.getPointeeType();
+ IntegerType dstType;
+ if (typeConverter.allows(spirv::Capability::Kernel)) {
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
+ dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+ else
+ dstType = dyn_cast<IntegerType>(pointeeType);
+ } else {
+ Type structElemType =
+ cast<spirv::StructType>(pointeeType).getElementType(0);
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
+ dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+ else
+ dstType = dyn_cast<IntegerType>(
+ cast<spirv::RuntimeArrayType>(structElemType).getElementType());
+ }
+
+ if (!dstType)
+ return rewriter.notifyMatchFailure(
+ atomicOp, "failed to determine destination element type");
+
+ int dstBits = static_cast<int>(dstType.getWidth());
+ assert(dstBits % srcBits == 0);
+
+ // When the source and destination bitwidths match, emit the atomic operation
+ // directly.
+ if (srcBits == dstBits) {
#define ATOMIC_CASE(kind, spirvOp) \
case arith::AtomicRMWKind::kind: \
rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
atomicOp, resultType, ptr, *scope, \
spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
break
- switch (atomicOp.getKind()) {
- ATOMIC_CASE(addi, AtomicIAddOp);
- ATOMIC_CASE(maxs, AtomicSMaxOp);
- ATOMIC_CASE(maxu, AtomicUMaxOp);
- ATOMIC_CASE(mins, AtomicSMinOp);
- ATOMIC_CASE(minu, AtomicUMinOp);
- ATOMIC_CASE(ori, AtomicOrOp);
- ATOMIC_CASE(andi, AtomicAndOp);
- default:
- return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
- }
+ switch (atomicOp.getKind()) {
+ ATOMIC_CASE(addi, AtomicIAddOp);
+ ATOMIC_CASE(maxs, AtomicSMaxOp);
+ ATOMIC_CASE(maxu, AtomicUMaxOp);
+ ATOMIC_CASE(mins, AtomicSMinOp);
+ ATOMIC_CASE(minu, AtomicUMinOp);
+ ATOMIC_CASE(ori, AtomicOrOp);
+ ATOMIC_CASE(andi, AtomicAndOp);
+ default:
+ return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
+ }
#undef ATOMIC_CASE
+ return success();
+ }
+
+ // Sub-element-width atomic: the element type (e.g., i8) is narrower than the
+ // storage type (e.g., i32). We need to adjust the index and shift/mask the
+ // value to operate on the correct bits within the wider storage element.
+ //
+ // Only ori and andi can be emulated because they operate bitwise and don't
+ // carry across byte boundaries. Other kinds (addi, max, min) would require
+ // CAS loops.
+ if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
+ atomicOp.getKind() != arith::AtomicRMWKind::andi) {
+ return rewriter.notifyMatchFailure(
+ atomicOp,
+ "atomic op on sub-element-width types is only supported for ori/andi");
+ }
+
+ // Bitcasting is currently unsupported for Kernel capability /
+ // spirv.PtrAccessChain.
+ if (typeConverter.allows(spirv::Capability::Kernel))
+ return rewriter.notifyMatchFailure(
+ atomicOp,
+ "sub-element-width atomic ops unsupported with Kernel capability");
+
+ auto accessChainOp = ptr.getDefiningOp<spirv::AccessChainOp>();
+ if (!accessChainOp)
+ return failure();
+
+ assert(accessChainOp.getIndices().size() == 2);
+ Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
+ Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+ Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+ srcBits, dstBits, rewriter);
+ Value result;
+ if (atomicOp.getKind() == arith::AtomicRMWKind::ori) {
+ // OR only sets bits, so shifting the value to the target position and
+ // ORing with zeros in other positions preserves the unaffected bits.
+ Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
+ loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+ Value storeVal =
+ shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
+ result = spirv::AtomicOrOp::create(
+ rewriter, loc, dstType, adjustedPtr, *scope,
+ spirv::MemorySemantics::AcquireRelease, storeVal);
+ } else {
+ assert(atomicOp.getKind() == arith::AtomicRMWKind::andi);
----------------
kuhar wrote:
Use a switch instead and assert in the default branch?
https://github.com/llvm/llvm-project/pull/179831
More information about the Mlir-commits
mailing list