[Mlir-commits] [mlir] [mlir][SPIRV] Add sub-element-byte lowering support for atomic_rmw ori/andi ops (PR #179831)

Han-Chung Wang llvmlistbot at llvm.org
Mon Feb 23 21:07:14 PST 2026


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/179831

>From 2c9c00d4ca1e0690af9f639998162a3a0676520b Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 4 Feb 2026 16:14:51 -0800
Subject: [PATCH 1/3] [mlir][SPIRV] Add sub-element-byte lowering support for
 atomic_rmw ori/andi ops

When the memref element type (e.g., i8) is narrower than the SPIR-V
storage type (e.g., i32 on Vulkan), ori and andi can be lowered with a
single wide atomic instruction, because OR with 0 bits and AND with 1
bits leave the neighboring elements unchanged. The worked examples
below were generated by an LLM.

The revision follows `IntStoreOpPattern`: it computes the adjusted
access chain and the bit offset via `adjustAccessChainForBitwidth` and
`getOffsetForBitwidth`. Unlike `IntStoreOpPattern`, it also has to
handle the returned value, which is by definition the old value.
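
As a rough scalar model of that index arithmetic (inferred from the
"(idx % 4) * 8" and "idx / 4" comments in the tests, not taken from the
helpers themselves, which emit SPIR-V ops), the storage index and bit
offset could be sketched as follows. `SubElementAddress` and
`locateSubElement` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration type; it does not exist in MLIR.
struct SubElementAddress {
  uint32_t storageIndex; // index of the wide storage element
  uint32_t bitOffset;    // bit position of the narrow element inside it
};

// Hypothetical helper: maps a narrow-element index to the wide storage
// element that contains it, plus the bit offset within that element.
// Assumes dstBits is a non-zero multiple of srcBits.
inline SubElementAddress locateSubElement(uint32_t index, unsigned srcBits,
                                          unsigned dstBits) {
  assert(srcBits != 0 && dstBits % srcBits == 0);
  uint32_t ratio = dstBits / srcBits; // e.g., 4 for i8 stored in i32
  return {index / ratio, (index % ratio) * srcBits};
}
```

For an i8 element at index 6 stored in i32 words, this yields storage
index 1 and bit offset 16.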

There are refactoring opportunities, but they are not addressed in this
revision because the current implementation is already complicated. The
refactoring can happen in a follow-up patch, which keeps this revision
easier to review.

Example: atomic_rmw ori on byte 2 (index=2) with value 0xAB

  Memory (i32 word):
    byte 3    byte 2    byte 1    byte 0
  +---------+---------+---------+---------+
  | 0x12    | 0x34    | 0x56    | 0x78    |  <- original word
  +---------+---------+---------+---------+

  offset = (index % 4) * 8 = 16

  1) Mask and shift the operand into position:
       storeVal = (0xAB & 0xFF) << 16 = 0x00AB0000

  2) AtomicOr the whole word (zeros are no-ops for OR):
       0x12345678 | 0x00AB0000 = 0x12BF5678
       Only byte 2 is affected: 0x34 | 0xAB = 0xBF

  3) Extract old byte from the atomic's return value:
       (0x12345678 >> 16) & 0xFF = 0x34
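
The three ori steps above can be mimicked on the host with
`std::atomic<uint32_t>`. This is only a sketch of the arithmetic, not
the SPIR-V lowering itself; `atomicOrByte` is a hypothetical name, and
the word is passed in directly rather than located through an access
chain.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the patch): emulates an atomic OR on
// one byte of a 32-bit word and returns the previous value of that byte.
inline uint8_t atomicOrByte(std::atomic<uint32_t> &word, unsigned index,
                            uint8_t value) {
  unsigned offset = (index % 4) * 8;
  // 1) Mask and shift the operand into position; other bytes stay zero.
  uint32_t storeVal = (static_cast<uint32_t>(value) & 0xFFu) << offset;
  // 2) OR the whole word; the zero bits leave the other bytes untouched.
  uint32_t old = word.fetch_or(storeVal, std::memory_order_acq_rel);
  // 3) Extract the old byte from the returned word.
  return static_cast<uint8_t>((old >> offset) & 0xFFu);
}
```

With a word holding 0x12345678, `atomicOrByte(word, 2, 0xAB)` returns
0x34 and leaves 0x12BF5678 in the word, matching the numbers above.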

Example: atomic_rmw andi on byte 1 (index=1) with value 0xF0

  offset = (index % 4) * 8 = 8

  1) Build the AND mask (operand at target, 1s elsewhere):
       shifted  = (0xF0 & 0xFF) << 8   = 0x0000F000
       inverted = ~(0xFF << 8)          = 0xFFFF00FF
       mask     = 0x0000F000 | 0xFFFF00FF = 0xFFFFF0FF

  2) AtomicAnd the whole word (ones are no-ops for AND):
       0x12345678 & 0xFFFFF0FF = 0x12345078
       Only byte 1 is affected: 0x56 & 0xF0 = 0x50

  3) Extract old byte from the atomic's return value:
       (0x12345678 >> 8) & 0xFF = 0x56
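
The andi steps above can be mimicked the same way on the host. Again
this is a sketch of the arithmetic under the assumption of an
i8-in-i32 layout, not the SPIR-V lowering; `atomicAndByte` is a
hypothetical name.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the patch): emulates an atomic AND on
// one byte of a 32-bit word and returns the previous value of that byte.
inline uint8_t atomicAndByte(std::atomic<uint32_t> &word, unsigned index,
                             uint8_t value) {
  unsigned offset = (index % 4) * 8;
  // 1) Build the mask: operand at the target byte, all ones elsewhere.
  uint32_t shifted = (static_cast<uint32_t>(value) & 0xFFu) << offset;
  uint32_t inverted = ~(0xFFu << offset);
  uint32_t mask = shifted | inverted;
  // 2) AND the whole word; the one bits leave the other bytes untouched.
  uint32_t old = word.fetch_and(mask, std::memory_order_acq_rel);
  // 3) Extract the old byte from the returned word.
  return static_cast<uint8_t>((old >> offset) & 0xFFu);
}
```

With a word holding 0x12345678, `atomicAndByte(word, 1, 0xF0)` returns
0x56 and leaves 0x12345078 in the word, matching the numbers above.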

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           | 132 ++++++++++++++++--
 .../test/Conversion/MemRefToSPIRV/atomic.mlir |  72 ++++++++++
 2 files changed, 193 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 42e082f69e475..3707ed7ca9cd2 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -427,6 +427,42 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
   if (!ptr)
     return failure();
 
+  // Determine the source and destination bitwidths. The source is the original
+  // memref element type and the destination is the SPIR-V storage type (e.g.,
+  // i32 for Vulkan).
+  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
+  if (!pointerType)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "failed to convert memref type");
+
+  Type pointeeType = pointerType.getPointeeType();
+  IntegerType dstType;
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+    else
+      dstType = dyn_cast<IntegerType>(pointeeType);
+  } else {
+    Type structElemType =
+        cast<spirv::StructType>(pointeeType).getElementType(0);
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
+    else
+      dstType = dyn_cast<IntegerType>(
+          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
+  }
+
+  if (!dstType)
+    return rewriter.notifyMatchFailure(
+        atomicOp, "failed to determine destination element type");
+
+  int dstBits = static_cast<int>(dstType.getWidth());
+  assert(dstBits % srcBits == 0);
+
+  // When the source and destination bitwidths match, emit the atomic operation
+  // directly.
+  if (srcBits == dstBits) {
 #define ATOMIC_CASE(kind, spirvOp)                                             \
   case arith::AtomicRMWKind::kind:                                             \
     rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
@@ -434,20 +470,94 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
         spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
     break
 
-  switch (atomicOp.getKind()) {
-    ATOMIC_CASE(addi, AtomicIAddOp);
-    ATOMIC_CASE(maxs, AtomicSMaxOp);
-    ATOMIC_CASE(maxu, AtomicUMaxOp);
-    ATOMIC_CASE(mins, AtomicSMinOp);
-    ATOMIC_CASE(minu, AtomicUMinOp);
-    ATOMIC_CASE(ori, AtomicOrOp);
-    ATOMIC_CASE(andi, AtomicAndOp);
-  default:
-    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
-  }
+    switch (atomicOp.getKind()) {
+      ATOMIC_CASE(addi, AtomicIAddOp);
+      ATOMIC_CASE(maxs, AtomicSMaxOp);
+      ATOMIC_CASE(maxu, AtomicUMaxOp);
+      ATOMIC_CASE(mins, AtomicSMinOp);
+      ATOMIC_CASE(minu, AtomicUMinOp);
+      ATOMIC_CASE(ori, AtomicOrOp);
+      ATOMIC_CASE(andi, AtomicAndOp);
+    default:
+      return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
+    }
 
 #undef ATOMIC_CASE
 
+    return success();
+  }
+
+  // Sub-element-width atomic: the element type (e.g., i8) is narrower than the
+  // storage type (e.g., i32). We need to adjust the index and shift/mask the
+  // value to operate on the correct bits within the wider storage element.
+  //
+  // Only ori and andi can be emulated because they operate bitwise and don't
+  // carry across byte boundaries. Other kinds (addi, max, min) would require
+  // CAS loops.
+  if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
+      atomicOp.getKind() != arith::AtomicRMWKind::andi) {
+    return rewriter.notifyMatchFailure(
+        atomicOp,
+        "atomic op on sub-element-width types is only supported for ori/andi");
+  }
+
+  // Bitcasting is currently unsupported for Kernel capability /
+  // spirv.PtrAccessChain.
+  if (typeConverter.allows(spirv::Capability::Kernel))
+    return rewriter.notifyMatchFailure(
+        atomicOp,
+        "sub-element-width atomic ops unsupported with Kernel capability");
+
+  auto accessChainOp = ptr.getDefiningOp<spirv::AccessChainOp>();
+  if (!accessChainOp)
+    return failure();
+
+  assert(accessChainOp.getIndices().size() == 2);
+  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
+  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+                                                   srcBits, dstBits, rewriter);
+  Value result;
+  if (atomicOp.getKind() == arith::AtomicRMWKind::ori) {
+    // OR only sets bits, so shifting the value to the target position and
+    // ORing with zeros in other positions preserves the unaffected bits.
+    Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
+        loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+    Value storeVal =
+        shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
+    result = spirv::AtomicOrOp::create(
+        rewriter, loc, dstType, adjustedPtr, *scope,
+        spirv::MemorySemantics::AcquireRelease, storeVal);
+  } else {
+    assert(atomicOp.getKind() == arith::AtomicRMWKind::andi);
+    // AND: build a mask that preserves all bits outside the target element
+    // and applies the operand mask to the target element.
+    //   mask = (operand << offset) | ~(elemMask << offset)
+    Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
+        loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+    Value storeVal =
+        shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
+    Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
+        loc, dstType, elemMask, offset);
+    Value invertedElemMask =
+        rewriter.createOrFold<spirv::NotOp>(loc, dstType, shiftedElemMask);
+    Value mask = rewriter.createOrFold<spirv::BitwiseOrOp>(loc, storeVal,
+                                                           invertedElemMask);
+    result = spirv::AtomicAndOp::create(
+        rewriter, loc, dstType, adjustedPtr, *scope,
+        spirv::MemorySemantics::AcquireRelease, mask);
+  }
+
+  // The atomic op returns the old value of the full storage element (e.g.,
+  // i32). Extract the original sub-element value from the correct position.
+  result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
+                                                             result, offset);
+  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
+      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+  result =
+      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+  rewriter.replaceOp(atomicOp, result);
+
   return success();
 }
 
diff --git a/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
index 4729cccfb6228..92f98637a1939 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
@@ -74,3 +74,75 @@ func.func @atomic_andi_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #s
 
 }
 
+// -----
+
+// Check sub-element-width atomic ori on i8 memref (stored as i32 in SPIR-V).
+// The byte index must be divided by 4 to get the i32 index, and the value
+// must be shifted to the correct byte position within the i32.
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
+
+// CHECK-LABEL: func.func @atomic_ori_i8_storage_buffer
+func.func @atomic_ori_i8_storage_buffer(%value: i8, %memref: memref<16xi8, #spirv.storage_class<StorageBuffer>>, %i0: index) -> i8 {
+  //      CHECK:     %[[IDX:.+]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32
+  //      CHECK:     %[[MEM:.+]] = builtin.unrealized_conversion_cast %{{.*}} : memref{{.*}} to !spirv.ptr
+  //      CHECK:     %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.*}} : i8 to i32
+  // Compute bit offset: (idx % 4) * 8
+  //  CHECK-DAG:     %[[C4:.+]] = spirv.Constant 4 : i32
+  //  CHECK-DAG:     %[[C8:.+]] = spirv.Constant 8 : i32
+  //      CHECK:     %[[MOD:.+]] = spirv.UMod %[[IDX]], %[[C4]]
+  //      CHECK:     %[[OFFSET:.+]] = spirv.IMul %[[MOD]], %[[C8]]
+  // Adjust the access chain index: idx / 4
+  //      CHECK:     %[[DIV:.+]] = spirv.SDiv %[[IDX]], %{{.*}}
+  //      CHECK:     %[[AC:.+]] = spirv.AccessChain %[[MEM]][%{{.*}}, %[[DIV]]]
+  // Mask and shift the value
+  //      CHECK:     %[[C255:.+]] = spirv.Constant 255 : i32
+  //      CHECK:     %[[MASKED:.+]] = spirv.BitwiseAnd %[[VAL]], %[[C255]]
+  //      CHECK:     %[[SHIFTED:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[OFFSET]]
+  // Atomic OR
+  //      CHECK:     %[[ATOMIC:.+]] = spirv.AtomicOr <Device> <AcquireRelease> %[[AC]], %[[SHIFTED]]
+  // Extract old value from result
+  //      CHECK:     spirv.ShiftRightLogical %[[ATOMIC]], %[[OFFSET]]
+  //      CHECK:     spirv.BitwiseAnd
+  //      CHECK:     return
+  %0 = memref.atomic_rmw "ori" %value, %memref[%i0] : (i8, memref<16xi8, #spirv.storage_class<StorageBuffer>>) -> i8
+  return %0: i8
+}
+
+}
+
+// -----
+
+// Check sub-element-width atomic andi on i8 memref.
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
+
+// CHECK-LABEL: func.func @atomic_andi_i8_storage_buffer
+func.func @atomic_andi_i8_storage_buffer(%value: i8, %memref: memref<16xi8, #spirv.storage_class<StorageBuffer>>, %i0: index) -> i8 {
+  //      CHECK:     %[[IDX:.+]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32
+  //      CHECK:     %[[MEM:.+]] = builtin.unrealized_conversion_cast %{{.*}} : memref{{.*}} to !spirv.ptr
+  //      CHECK:     %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.*}} : i8 to i32
+  //  CHECK-DAG:     %[[C4:.+]] = spirv.Constant 4 : i32
+  //  CHECK-DAG:     %[[C8:.+]] = spirv.Constant 8 : i32
+  //      CHECK:     %[[MOD:.+]] = spirv.UMod %[[IDX]], %[[C4]]
+  //      CHECK:     %[[OFFSET:.+]] = spirv.IMul %[[MOD]], %[[C8]]
+  //      CHECK:     %[[DIV:.+]] = spirv.SDiv %[[IDX]], %{{.*}}
+  //      CHECK:     %[[AC:.+]] = spirv.AccessChain %[[MEM]][%{{.*}}, %[[DIV]]]
+  // Build the AND mask: (val << offset) | ~(0xFF << offset)
+  //      CHECK:     %[[C255:.+]] = spirv.Constant 255 : i32
+  //      CHECK:     %[[MASKED:.+]] = spirv.BitwiseAnd %[[VAL]], %[[C255]]
+  //      CHECK:     %[[SHIFTED:.+]] = spirv.ShiftLeftLogical %[[MASKED]], %[[OFFSET]]
+  //      CHECK:     %[[ELEM_SHIFTED:.+]] = spirv.ShiftLeftLogical %[[C255]], %[[OFFSET]]
+  //      CHECK:     %[[NOT_ELEM:.+]] = spirv.Not %[[ELEM_SHIFTED]]
+  //      CHECK:     %[[MASK:.+]] = spirv.BitwiseOr %[[SHIFTED]], %[[NOT_ELEM]]
+  // Atomic AND
+  //      CHECK:     %[[ATOMIC:.+]] = spirv.AtomicAnd <Device> <AcquireRelease> %[[AC]], %[[MASK]]
+  // Extract old value
+  //      CHECK:     spirv.ShiftRightLogical %[[ATOMIC]], %[[OFFSET]]
+  //      CHECK:     spirv.BitwiseAnd
+  //      CHECK:     return
+  %0 = memref.atomic_rmw "andi" %value, %memref[%i0] : (i8, memref<16xi8, #spirv.storage_class<StorageBuffer>>) -> i8
+  return %0: i8
+}
+
+}

>From ec35efcfd53f4aad028b590f67bc5930ed2917c3 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 23 Feb 2026 16:59:55 -0800
Subject: [PATCH 2/3] address comments

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           | 20 ++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 3707ed7ca9cd2..cee6da5b4eabf 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -518,23 +518,25 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
   Value result;
-  if (atomicOp.getKind() == arith::AtomicRMWKind::ori) {
+  switch (atomicOp.getKind()) {
+  case arith::AtomicRMWKind::ori: {
     // OR only sets bits, so shifting the value to the target position and
     // ORing with zeros in other positions preserves the unaffected bits.
     Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
-        loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+        loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
     Value storeVal =
         shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
     result = spirv::AtomicOrOp::create(
         rewriter, loc, dstType, adjustedPtr, *scope,
         spirv::MemorySemantics::AcquireRelease, storeVal);
-  } else {
-    assert(atomicOp.getKind() == arith::AtomicRMWKind::andi);
-    // AND: build a mask that preserves all bits outside the target element
+    break;
+  }
+  case arith::AtomicRMWKind::andi: {
+    // Build a mask that preserves all bits outside the target element
     // and applies the operand mask to the target element.
     //   mask = (operand << offset) | ~(elemMask << offset)
     Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
-        loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+        loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
     Value storeVal =
         shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
     Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
@@ -546,6 +548,10 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
     result = spirv::AtomicAndOp::create(
         rewriter, loc, dstType, adjustedPtr, *scope,
         spirv::MemorySemantics::AcquireRelease, mask);
+    break;
+  }
+  default:
+    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
   }
 
   // The atomic op returns the old value of the full storage element (e.g.,
@@ -553,7 +559,7 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
   result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
                                                              result, offset);
   Value mask = rewriter.createOrFold<spirv::ConstantOp>(
-      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+      loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
   result =
       rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
   rewriter.replaceOp(atomicOp, result);

>From ce9276111869c2790e282353a698050b47778ed8 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 23 Feb 2026 21:06:45 -0800
Subject: [PATCH 3/3] Add one more comment

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index cee6da5b4eabf..81116cf6f13ad 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -512,6 +512,8 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
   if (!accessChainOp)
     return failure();
 
+  // Compute the bit offset within the storage element and adjust the pointer
+  // to address the containing storage element.
   assert(accessChainOp.getIndices().size() == 2);
   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);


