[Mlir-commits] [mlir] [mlir][SPIRV] Add sub-element-byte lowering support for atomic_rmw ori/andi ops (PR #179831)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Feb 23 21:07:14 PST 2026
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/179831
>From 2c9c00d4ca1e0690af9f639998162a3a0676520b Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 4 Feb 2026 16:14:51 -0800
Subject: [PATCH 1/3] [mlir][SPIRV] Add sub-element-byte lowering support for
atomic_rmw ori/andi ops
When the memref element type (e.g., i8) is narrower than the SPIR-V
storage type (e.g., i32 on Vulkan), ori and andi can be lowered with a
single wide atomic instruction, because OR with 0 bits and AND with 1 bits
are identity operations on the neighboring elements. The examples below
were generated by an LLM.
Like `IntStoreOpPattern`, this revision computes the bit offset and the
adjusted access chain via the `getOffsetForBitwidth` and
`adjustAccessChainForBitwidth` helpers. Unlike `IntStoreOpPattern`, it must
also produce a result: the atomic op returns the old value of the whole
storage word, and the old sub-element value is extracted from it.
There are refactoring opportunities, but they are not taken in this
revision because the current implementation is already complicated. The
refactoring can happen in a follow-up patch, which keeps this revision
easier to review.
Example: atomic_rmw ori on byte 2 (index=2) with value 0xAB
Memory (i32 word):
      byte 3    byte 2    byte 1    byte 0
   +---------+---------+---------+---------+
   |  0x12   |  0x34   |  0x56   |  0x78   |  <- original word
   +---------+---------+---------+---------+
offset = (index % 4) * 8 = 16
1) Mask and shift the operand into position:
storeVal = (0xAB & 0xFF) << 16 = 0x00AB0000
2) AtomicOr the whole word (zeros are no-ops for OR):
0x12345678 | 0x00AB0000 = 0x12BF5678
Only byte 2 is affected: 0x34 | 0xAB = 0xBF
3) Extract old byte from the atomic's return value:
(0x12345678 >> 16) & 0xFF = 0x34
Example: atomic_rmw andi on byte 1 (index=1) with value 0xF0
offset = (index % 4) * 8 = 8
1) Build the AND mask (operand at target, 1s elsewhere):
shifted = (0xF0 & 0xFF) << 8 = 0x0000F000
inverted = ~(0xFF << 8) = 0xFFFF00FF
mask = 0x0000F000 | 0xFFFF00FF = 0xFFFFF0FF
2) AtomicAnd the whole word (ones are no-ops for AND):
0x12345678 & 0xFFFFF0FF = 0x12345078
Only byte 1 is affected: 0x56 & 0xF0 = 0x50
3) Extract old byte from the atomic's return value:
(0x12345678 >> 8) & 0xFF = 0x56
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../MemRefToSPIRV/MemRefToSPIRV.cpp | 132 ++++++++++++++++--
.../test/Conversion/MemRefToSPIRV/atomic.mlir | 72 ++++++++++
2 files changed, 193 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 42e082f69e475..3707ed7ca9cd2 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -427,6 +427,42 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
if (!ptr)
return failure();
+ // Determine the source and destination bitwidths. The source is the original
+ // memref element type and the destination is the SPIR-V storage type (e.g.,
+ // i32 for Vulkan).
+ int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+ auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
+ if (!pointerType)
+ return rewriter.notifyMatchFailure(atomicOp,
+ "failed to convert memref type");
+
+ Type pointeeType = pointerType.getPointeeType();
+ IntegerType dstType;
+ if (typeConverter.allows(spirv::Capability::Kernel)) {
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
+ dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+ else
+ dstType = dyn_cast<IntegerType>(pointeeType);
+ } else {
+ Type structElemType =
+ cast<spirv::StructType>(pointeeType).getElementType(0);
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
+ dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+ else
+ dstType = dyn_cast<IntegerType>(
+ cast<spirv::RuntimeArrayType>(structElemType).getElementType());
+ }
+
+ if (!dstType)
+ return rewriter.notifyMatchFailure(
+ atomicOp, "failed to determine destination element type");
+
+ int dstBits = static_cast<int>(dstType.getWidth());
+ assert(dstBits % srcBits == 0);
+
+ // When the source and destination bitwidths match, emit the atomic operation
+ // directly.
+ if (srcBits == dstBits) {
#define ATOMIC_CASE(kind, spirvOp) \
case arith::AtomicRMWKind::kind: \
rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
@@ -434,20 +470,94 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
break
- switch (atomicOp.getKind()) {
- ATOMIC_CASE(addi, AtomicIAddOp);
- ATOMIC_CASE(maxs, AtomicSMaxOp);
- ATOMIC_CASE(maxu, AtomicUMaxOp);
- ATOMIC_CASE(mins, AtomicSMinOp);
- ATOMIC_CASE(minu, AtomicUMinOp);
- ATOMIC_CASE(ori, AtomicOrOp);
- ATOMIC_CASE(andi, AtomicAndOp);
- default:
- return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
- }
+ switch (atomicOp.getKind()) {
+ ATOMIC_CASE(addi, AtomicIAddOp);
+ ATOMIC_CASE(maxs, AtomicSMaxOp);
+ ATOMIC_CASE(maxu, AtomicUMaxOp);
+ ATOMIC_CASE(mins, AtomicSMinOp);
+ ATOMIC_CASE(minu, AtomicUMinOp);
+ ATOMIC_CASE(ori, AtomicOrOp);
+ ATOMIC_CASE(andi, AtomicAndOp);
+ default:
+ return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
+ }
#undef ATOMIC_CASE
+ return success();
+ }
+
+ // Sub-element-width atomic: the element type (e.g., i8) is narrower than the
+ // storage type (e.g., i32). We need to adjust the index and shift/mask the
+ // value to operate on the correct bits within the wider storage element.
+ //
+ // Only ori and andi can be emulated because they operate bitwise and don't
+ // carry across byte boundaries. Other kinds (addi, max, min) would require
+ // CAS loops.
+ if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
+ atomicOp.getKind() != arith::AtomicRMWKind::andi) {
+ return rewriter.notifyMatchFailure(
+ atomicOp,
+ "atomic op on sub-element-width types is only supported for ori/andi");
+ }
+
+ // Bitcasting is currently unsupported for Kernel capability /
+ // spirv.PtrAccessChain.
+ if (typeConverter.allows(spirv::Capability::Kernel))
+ return rewriter.notifyMatchFailure(
+ atomicOp,
+ "sub-element-width atomic ops unsupported with Kernel capability");
+
+ auto accessChainOp = ptr.getDefiningOp<spirv::AccessChainOp>();
+ if (!accessChainOp)
+ return failure();
+
+ assert(accessChainOp.getIndices().size() == 2);
+ Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
+ Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+ Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+ srcBits, dstBits, rewriter);
+ Value result;
+ if (atomicOp.getKind() == arith::AtomicRMWKind::ori) {
+ // OR only sets bits, so shifting the value to the target position and
+ // ORing with zeros in other positions preserves the unaffected bits.
+ Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
+ loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+ Value storeVal =
+ shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
+ result = spirv::AtomicOrOp::create(
+ rewriter, loc, dstType, adjustedPtr, *scope,
+ spirv::MemorySemantics::AcquireRelease, storeVal);
+ } else {
+ assert(atomicOp.getKind() == arith::AtomicRMWKind::andi);
+ // AND: build a mask that preserves all bits outside the target element
+ // and applies the operand mask to the target element.
+ // mask = (operand << offset) | ~(elemMask << offset)
+ Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
+ loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+ Value storeVal =
+ shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
+ Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
+ loc, dstType, elemMask, offset);
+ Value invertedElemMask =
+ rewriter.createOrFold<spirv::NotOp>(loc, dstType, shiftedElemMask);
+ Value mask = rewriter.createOrFold<spirv::BitwiseOrOp>(loc, storeVal,
+ invertedElemMask);
+ result = spirv::AtomicAndOp::create(
+ rewriter, loc, dstType, adjustedPtr, *scope,
+ spirv::MemorySemantics::AcquireRelease, mask);
+ }
+
+ // The atomic op returns the old value of the full storage element (e.g.,
+ // i32). Extract the original sub-element value from the correct position.
+ result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
+ result, offset);
+ Value mask = rewriter.createOrFold<spirv::ConstantOp>(
+ loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+ result =
+ rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+ rewriter.replaceOp(atomicOp, result);
+
return success();
}
diff --git a/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
index 4729cccfb6228..92f98637a1939 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
@@ -74,3 +74,75 @@ func.func @atomic_andi_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #s
}
+// -----
+
+// Check sub-element-width atomic ori on i8 memref (stored as i32 in SPIR-V).
+// The byte index must be divided by 4 to get the i32 index, and the value
+// must be shifted to the correct byte position within the i32.
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
+
+// CHECK-LABEL: func.func @atomic_ori_i8_storage_buffer
+func.func @atomic_ori_i8_storage_buffer(%value: i8, %memref: memref<16xi8, #spirv.storage_class<StorageBuffer>>, %i0: index) -> i8 {
+ // CHECK: %[[IDX:.+]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32
+ // CHECK: %[[MEM:.+]] = builtin.unrealized_conversion_cast %{{.*}} : memref{{.*}} to !spirv.ptr
+ // CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.*}} : i8 to i32
+ // Compute bit offset: (idx % 4) * 8
+ // CHECK-DAG: %[[C4:.+]] = spirv.Constant 4 : i32
+ // CHECK-DAG: %[[C8:.+]] = spirv.Constant 8 : i32
+ // CHECK: %[[MOD:.+]] = spirv.UMod %[[IDX]], %[[C4]]
+ // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[MOD]], %[[C8]]
+ // Adjust the access chain index: idx / 4
+ // CHECK: %[[DIV:.+]] = spirv.SDiv %[[IDX]], %{{.*}}
+ // CHECK: %[[AC:.+]] = spirv.AccessChain %[[MEM]][%{{.*}}, %[[DIV]]]
+ // Mask and shift the value
+ // CHECK: %[[C255:.+]] = spirv.Constant 255 : i32
+ // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[VAL]], %[[C255]]
+ // CHECK: %[[SHIFTED:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[OFFSET]]
+ // Atomic OR
+ // CHECK: %[[ATOMIC:.+]] = spirv.AtomicOr <Device> <AcquireRelease> %[[AC]], %[[SHIFTED]]
+ // Extract old value from result
+ // CHECK: spirv.ShiftRightLogical %[[ATOMIC]], %[[OFFSET]]
+ // CHECK: spirv.BitwiseAnd
+ // CHECK: return
+ %0 = memref.atomic_rmw "ori" %value, %memref[%i0] : (i8, memref<16xi8, #spirv.storage_class<StorageBuffer>>) -> i8
+ return %0: i8
+}
+
+}
+
+// -----
+
+// Check sub-element-width atomic andi on i8 memref.
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
+
+// CHECK-LABEL: func.func @atomic_andi_i8_storage_buffer
+func.func @atomic_andi_i8_storage_buffer(%value: i8, %memref: memref<16xi8, #spirv.storage_class<StorageBuffer>>, %i0: index) -> i8 {
+ // CHECK: %[[IDX:.+]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32
+ // CHECK: %[[MEM:.+]] = builtin.unrealized_conversion_cast %{{.*}} : memref{{.*}} to !spirv.ptr
+ // CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.*}} : i8 to i32
+ // CHECK-DAG: %[[C4:.+]] = spirv.Constant 4 : i32
+ // CHECK-DAG: %[[C8:.+]] = spirv.Constant 8 : i32
+ // CHECK: %[[MOD:.+]] = spirv.UMod %[[IDX]], %[[C4]]
+ // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[MOD]], %[[C8]]
+ // CHECK: %[[DIV:.+]] = spirv.SDiv %[[IDX]], %{{.*}}
+ // CHECK: %[[AC:.+]] = spirv.AccessChain %[[MEM]][%{{.*}}, %[[DIV]]]
+ // Build the AND mask: (val << offset) | ~(0xFF << offset)
+ // CHECK: %[[C255:.+]] = spirv.Constant 255 : i32
+ // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[VAL]], %[[C255]]
+ // CHECK: %[[SHIFTED:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[OFFSET]]
+ // CHECK: %[[ELEM_SHIFTED:.+]] = spirv.ShiftLeftLogical %[[C255]], %[[OFFSET]]
+ // CHECK: %[[NOT_ELEM:.+]] = spirv.Not %[[ELEM_SHIFTED]]
+ // CHECK: %[[MASK:.+]] = spirv.BitwiseOr %[[SHIFTED]], %[[NOT_ELEM]]
+ // Atomic AND
+ // CHECK: %[[ATOMIC:.+]] = spirv.AtomicAnd <Device> <AcquireRelease> %[[AC]], %[[MASK]]
+ // Extract old value
+ // CHECK: spirv.ShiftRightLogical %[[ATOMIC]], %[[OFFSET]]
+ // CHECK: spirv.BitwiseAnd
+ // CHECK: return
+ %0 = memref.atomic_rmw "andi" %value, %memref[%i0] : (i8, memref<16xi8, #spirv.storage_class<StorageBuffer>>) -> i8
+ return %0: i8
+}
+
+}
>From ec35efcfd53f4aad028b590f67bc5930ed2917c3 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 23 Feb 2026 16:59:55 -0800
Subject: [PATCH 2/3] address comments
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../MemRefToSPIRV/MemRefToSPIRV.cpp | 20 ++++++++++++-------
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 3707ed7ca9cd2..cee6da5b4eabf 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -518,23 +518,25 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
Value result;
- if (atomicOp.getKind() == arith::AtomicRMWKind::ori) {
+ switch (atomicOp.getKind()) {
+ case arith::AtomicRMWKind::ori: {
// OR only sets bits, so shifting the value to the target position and
// ORing with zeros in other positions preserves the unaffected bits.
Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
- loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+ loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
result = spirv::AtomicOrOp::create(
rewriter, loc, dstType, adjustedPtr, *scope,
spirv::MemorySemantics::AcquireRelease, storeVal);
- } else {
- assert(atomicOp.getKind() == arith::AtomicRMWKind::andi);
- // AND: build a mask that preserves all bits outside the target element
+ break;
+ }
+ case arith::AtomicRMWKind::andi: {
+ // Build a mask that preserves all bits outside the target element
// and applies the operand mask to the target element.
// mask = (operand << offset) | ~(elemMask << offset)
Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
- loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+ loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
Value storeVal =
shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
@@ -546,6 +548,10 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
result = spirv::AtomicAndOp::create(
rewriter, loc, dstType, adjustedPtr, *scope,
spirv::MemorySemantics::AcquireRelease, mask);
+ break;
+ }
+ default:
+ return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
}
// The atomic op returns the old value of the full storage element (e.g.,
@@ -553,7 +559,7 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
result, offset);
Value mask = rewriter.createOrFold<spirv::ConstantOp>(
- loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+ loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
result =
rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
rewriter.replaceOp(atomicOp, result);
>From ce9276111869c2790e282353a698050b47778ed8 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 23 Feb 2026 21:06:45 -0800
Subject: [PATCH 3/3] Add one more comment
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index cee6da5b4eabf..81116cf6f13ad 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -512,6 +512,8 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
if (!accessChainOp)
return failure();
+ // Compute the bit offset within the storage element and adjust the pointer
+ // to address the containing storage element.
assert(accessChainOp.getIndices().size() == 2);
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
More information about the Mlir-commits
mailing list