[Mlir-commits] [mlir] [mlir][SPIRV] Add sub-element-byte lowering support for atomic_rmw ori/andi ops (PR #179831)

Han-Chung Wang llvmlistbot at llvm.org
Mon Feb 23 21:08:11 PST 2026


================
@@ -427,26 +427,142 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
   if (!ptr)
     return failure();
 
+  // Determine the source and destination bitwidths. The source is the original
+  // memref element type and the destination is the SPIR-V storage type (e.g.,
+  // i32 for Vulkan).
+  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
+  if (!pointerType)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "failed to convert memref type");
+
+  Type pointeeType = pointerType.getPointeeType();
+  IntegerType dstType;
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+    else
+      dstType = dyn_cast<IntegerType>(pointeeType);
+  } else {
+    Type structElemType =
+        cast<spirv::StructType>(pointeeType).getElementType(0);
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+    else
+      dstType = dyn_cast<IntegerType>(
+          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
+  }
+
+  if (!dstType)
+    return rewriter.notifyMatchFailure(
+        atomicOp, "failed to determine destination element type");
+
+  int dstBits = static_cast<int>(dstType.getWidth());
+  assert(dstBits % srcBits == 0);
+
+  // When the source and destination bitwidths match, emit the atomic operation
+  // directly.
+  if (srcBits == dstBits) {
 #define ATOMIC_CASE(kind, spirvOp)                                             \
   case arith::AtomicRMWKind::kind:                                             \
     rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
         atomicOp, resultType, ptr, *scope,                                     \
         spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
     break
 
+    switch (atomicOp.getKind()) {
+      ATOMIC_CASE(addi, AtomicIAddOp);
+      ATOMIC_CASE(maxs, AtomicSMaxOp);
+      ATOMIC_CASE(maxu, AtomicUMaxOp);
+      ATOMIC_CASE(mins, AtomicSMinOp);
+      ATOMIC_CASE(minu, AtomicUMinOp);
+      ATOMIC_CASE(ori, AtomicOrOp);
+      ATOMIC_CASE(andi, AtomicAndOp);
+    default:
+      return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
+    }
+
+#undef ATOMIC_CASE
+
+    return success();
+  }
+
+  // Sub-element-width atomic: the element type (e.g., i8) is narrower than the
+  // storage type (e.g., i32). We need to adjust the index and shift/mask the
+  // value to operate on the correct bits within the wider storage element.
+  //
+  // Only ori and andi can be emulated because they operate bitwise and don't
+  // carry across byte boundaries. Other kinds (addi, max, min) would require
+  // CAS loops.
+  if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
+      atomicOp.getKind() != arith::AtomicRMWKind::andi) {
+    return rewriter.notifyMatchFailure(
+        atomicOp,
+        "atomic op on sub-element-width types is only supported for ori/andi");
+  }
+
+  // Bitcasting is currently unsupported for Kernel capability /
+  // spirv.PtrAccessChain.
+  if (typeConverter.allows(spirv::Capability::Kernel))
+    return rewriter.notifyMatchFailure(
+        atomicOp,
+        "sub-element-width atomic ops unsupported with Kernel capability");
+
+  auto accessChainOp = ptr.getDefiningOp<spirv::AccessChainOp>();
+  if (!accessChainOp)
+    return failure();
+
+  assert(accessChainOp.getIndices().size() == 2);
+  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
+  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
----------------
hanhanW wrote:

Let me add a bit about how I manage the comments:

- Lines 490-496 explain the sub-element-width algorithm clearly, and lines 523-524, 535-537 explain the ori/andi logic.
- The function name is self-documenting, so we don't need many comments.
- I ended up adding a small comment about what this snippet is doing.

https://github.com/llvm/llvm-project/pull/179831


More information about the Mlir-commits mailing list