[Mlir-commits] [mlir] [mlir][SPIRV] Add sub-element-byte lowering support for atomic_rmw ori/andi ops (PR #179831)

Jakub Kuderski llvmlistbot at llvm.org
Mon Feb 23 08:12:14 PST 2026


================
@@ -427,27 +427,137 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
   if (!ptr)
     return failure();
 
+  // Determine the source and destination bitwidths. The source is the original
+  // memref element type and the destination is the SPIR-V storage type (e.g.,
+  // i32 for Vulkan).
+  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
+  if (!pointerType)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "failed to convert memref type");
+
+  Type pointeeType = pointerType.getPointeeType();
+  IntegerType dstType;
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+    else
+      dstType = dyn_cast<IntegerType>(pointeeType);
+  } else {
+    Type structElemType =
+        cast<spirv::StructType>(pointeeType).getElementType(0);
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+    else
+      dstType = dyn_cast<IntegerType>(
+          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
+  }
+
+  if (!dstType)
+    return rewriter.notifyMatchFailure(
+        atomicOp, "failed to determine destination element type");
+
+  int dstBits = static_cast<int>(dstType.getWidth());
+  assert(dstBits % srcBits == 0);
+
+  // When the source and destination bitwidths match, emit the atomic operation
+  // directly.
+  if (srcBits == dstBits) {
 #define ATOMIC_CASE(kind, spirvOp)                                             \
   case arith::AtomicRMWKind::kind:                                             \
     rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
         atomicOp, resultType, ptr, *scope,                                     \
         spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
     break
 
-  switch (atomicOp.getKind()) {
-    ATOMIC_CASE(addi, AtomicIAddOp);
-    ATOMIC_CASE(maxs, AtomicSMaxOp);
-    ATOMIC_CASE(maxu, AtomicUMaxOp);
-    ATOMIC_CASE(mins, AtomicSMinOp);
-    ATOMIC_CASE(minu, AtomicUMinOp);
-    ATOMIC_CASE(ori, AtomicOrOp);
-    ATOMIC_CASE(andi, AtomicAndOp);
-  default:
-    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
-  }
+    switch (atomicOp.getKind()) {
+      ATOMIC_CASE(addi, AtomicIAddOp);
+      ATOMIC_CASE(maxs, AtomicSMaxOp);
+      ATOMIC_CASE(maxu, AtomicUMaxOp);
+      ATOMIC_CASE(mins, AtomicSMinOp);
+      ATOMIC_CASE(minu, AtomicUMinOp);
+      ATOMIC_CASE(ori, AtomicOrOp);
+      ATOMIC_CASE(andi, AtomicAndOp);
+    default:
+      return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
+    }
 
 #undef ATOMIC_CASE
 
+    return success();
+  }
+
+  // Sub-element-width atomic: the element type (e.g., i8) is narrower than the
+  // storage type (e.g., i32). We need to adjust the index and shift/mask the
+  // value to operate on the correct bits within the wider storage element.
+  //
+  // Only ori and andi can be emulated because they operate bitwise and don't
+  // carry across byte boundaries. Other kinds (addi, max, min) would require
+  // CAS loops.
+  if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
+      atomicOp.getKind() != arith::AtomicRMWKind::andi) {
+    return rewriter.notifyMatchFailure(
+        atomicOp,
+        "atomic op on sub-element-width types is only supported for ori/andi");
+  }
+
+  // Bitcasting is currently unsupported for Kernel capability /
+  // spirv.PtrAccessChain.
+  if (typeConverter.allows(spirv::Capability::Kernel))
+    return rewriter.notifyMatchFailure(
+        atomicOp,
+        "sub-element-width atomic ops unsupported with Kernel capability");
+
+  auto accessChainOp = ptr.getDefiningOp<spirv::AccessChainOp>();
+  if (!accessChainOp)
+    return failure();
+
+  assert(accessChainOp.getIndices().size() == 2);
+  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
+  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+                                                   srcBits, dstBits, rewriter);
+  Value result;
+  if (atomicOp.getKind() == arith::AtomicRMWKind::ori) {
+    // OR only sets bits, so shifting the value to the target position and
+    // ORing with zeros in other positions preserves the unaffected bits.
+    Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
+        loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+    Value storeVal =
+        shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
+    result = spirv::AtomicOrOp::create(
+        rewriter, loc, dstType, adjustedPtr, *scope,
+        spirv::MemorySemantics::AcquireRelease, storeVal);
+  } else {
+    assert(atomicOp.getKind() == arith::AtomicRMWKind::andi);
----------------
kuhar wrote:

Use a switch instead and assert in the default branch?

https://github.com/llvm/llvm-project/pull/179831


More information about the Mlir-commits mailing list