[Mlir-commits] [mlir] [mlir][AMDGPU] Add wrappers for in-memory barriers on gfx1250 (PR #180112)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Feb 5 20:01:24 PST 2026


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/180112

This commit introduces the `!amdgpu.ds_barrier_state` type and operations on that type, including ones that extract its components, and (more importantly) provides wrappers around the upcoming barrier-management instructions in gfx1250.
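
For readers new to the state word, here is a minimal host-side C++ sketch of how it decomposes. It uses the bit layout that the patch documents for the type and encodes in the AMDGPUToROCDL constants: init count in bits [63:32], phase in bits [31:28], and pending count in bits [27:0]. The helper names (`packState`, `statePhase`, and so on) are hypothetical and are not part of MLIR. Each accessor mirrors the shift/mask/truncate sequence that the corresponding `ds_barrier_state_*` op lowers to.

```cpp
#include <cstdint>

// Bit layout of the 64-bit barrier state, as documented for the type and
// used by the AMDGPUToROCDL lowering:
//   [63:32] init count, [31:28] phase, [27:0] pending count.
constexpr unsigned kPhasePos = 28;
constexpr unsigned kInitCountPos = 32;
constexpr uint32_t kPendingCountMask = (uint32_t{1} << 28) - 1;

// Hypothetical helper: assemble a state word from its three fields.
constexpr uint64_t packState(uint32_t initCount, uint32_t phase,
                             uint32_t pendingCount) {
  return (uint64_t{initCount} << kInitCountPos) |
         (uint64_t{phase & 0xFu} << kPhasePos) |
         uint64_t{pendingCount & kPendingCountMask};
}

// Like ds_barrier_state_phase: truncate to 32 bits, then shift right by 28.
constexpr uint32_t statePhase(uint64_t state) {
  return static_cast<uint32_t>(state) >> kPhasePos;
}

// Like ds_barrier_state_pending_count: truncate to 32 bits, then mask.
constexpr uint32_t statePendingCount(uint64_t state) {
  return static_cast<uint32_t>(state) & kPendingCountMask;
}

// Like ds_barrier_state_init_count: shift right by 32, then truncate.
constexpr uint32_t stateInitCount(uint64_t state) {
  return static_cast<uint32_t>(state >> kInitCountPos);
}

// Like ds_barrier_state_phase_parity: the low bit of the phase.
constexpr bool statePhaseParity(uint64_t state) {
  return (statePhase(state) & 1u) != 0;
}
```

For example, `packState(31, 3, 5)` yields a state with phase 3 (odd parity), pending count 5, and init count 31.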

This commit is loosely based on work done for Triton, but provides slightly lower-level primitives: namely, a known-atomic load for reading the barrier state instead of a `wait` operation that contains an entire spin-loop (though we could consider adding one if people want it). These operations allow LDS barriers to be manipulated in a more type-safe manner.
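
The `wait` operation left out here can be built from the poll primitive. The following is a host-side analog under stated assumptions, not GPU code: a `std::atomic<uint64_t>` stands in for the barrier word in LDS, an acquire load stands in for `ds_barrier_poll_state` (which lowers to an acquire load at workgroup scope), and a C++ thread stands in for another wave. `waitForPhaseChange` and `demoWait` are hypothetical names. On the GPU the loop would back off or sleep instead of yielding a thread.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <thread>

// The phase lives in bits [31:28] of the state; its parity is bit 28.
inline bool phaseParity(uint64_t state) { return ((state >> 28) & 1u) != 0; }

// Hypothetical "wait" built from the poll primitive: spin until the phase
// parity differs from the one the caller observed earlier. The acquire load
// stands in for ds_barrier_poll_state.
inline uint64_t waitForPhaseChange(const std::atomic<uint64_t> &barrier,
                                   bool observedParity) {
  for (;;) {
    uint64_t state = barrier.load(std::memory_order_acquire);
    if (phaseParity(state) != observedParity)
      return state;
    std::this_thread::yield(); // a wave would back off or sleep here instead
  }
}

// Demo: one thread publishes data and then changes the phase with a release
// store; the waiting thread is then guaranteed to see the data.
inline int demoWait() {
  std::atomic<uint64_t> barrier{0}; // phase 0, parity 0
  int payload = 0;
  std::thread producer([&] {
    payload = 42;
    barrier.store(uint64_t{1} << 28, std::memory_order_release); // phase 1
  });
  uint64_t seen = waitForPhaseChange(barrier, /*observedParity=*/false);
  assert(phaseParity(seen));
  int result = payload; // safe: the acquire load synchronized with the release
  producer.join();
  return result;
}
```

Comparing parities instead of exact phase values keeps the loop independent of how far the 4-bit phase has wrapped. This works as long as the barrier cannot advance two phases between polls.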

The types and operations use the Ds naming scheme to match the underlying instructions and to avoid confusion with the "LDS barrier" already present in the AMDGPU dialect that was a workaround for LLVM's memory fencing support.

(As an example of a potential usage pattern, a pair of these barriers can be used to communicate between the wave(s) in a workgroup that load data into memory and separate wave(s) that compute with that data.)
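
The two-barrier pattern can be sketched on the host as a single-slot producer/consumer pipeline. This is a model, not the MLIR API. `BarrierModel` is a hypothetical stand-in for one `!amdgpu.ds_barrier_state` word, and each method notes which op it imitates. Threads stand in for waves. The phase is assumed to advance by one on rollover, but only its parity is used. With one participant per barrier, every arrival changes the phase.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <thread>

// Hypothetical host-side stand-in for one `!amdgpu.ds_barrier_state` word.
class BarrierModel {
public:
  // Like ds_barrier_init: pending = init = participants - 1, phase = 0.
  explicit BarrierModel(uint32_t participants)
      : word_(initialState(participants)) {
    assert(participants >= 1 && "need at least one participant");
  }

  // Like ds_barrier_arrive with a count of 1 from a single lane: decrement the
  // pending count; on underflow, reload the init count and change the phase
  // (ASSUMED to be +1 modulo 16; only the parity is relied upon below).
  uint64_t arrive() {
    uint64_t old = word_.load(std::memory_order_relaxed);
    for (;;) {
      uint64_t next;
      if ((old & kPendingMask) != 0) {
        next = old - 1;
      } else {
        uint64_t init = (old >> 32) & kPendingMask;
        uint64_t phase = (((old >> 28) & 0xF) + 1) & 0xF;
        next = (init << 32) | (phase << 28) | init;
      }
      if (word_.compare_exchange_weak(old, next, std::memory_order_acq_rel,
                                      std::memory_order_relaxed))
        return old; // the old state, as the returning arrive does
    }
  }

  // Like ds_barrier_poll_state followed by ds_barrier_state_phase_parity.
  bool parity() const {
    return ((word_.load(std::memory_order_acquire) >> 28) & 1u) != 0;
  }

  // Spin while the phase parity still equals `p` (a wave would sleep instead).
  void waitWhileParityIs(bool p) const {
    while (parity() == p)
      std::this_thread::yield();
  }

private:
  static constexpr uint64_t kPendingMask = 0x0FFFFFFF; // bits [27:0]
  static uint64_t initialState(uint32_t participants) {
    uint64_t count = (participants - 1u) & kPendingMask;
    return (count << 32) | count;
  }
  std::atomic<uint64_t> word_;
};

// One producer thread fills a one-slot buffer and arrives at `full`; the
// consumer reads the slot and arrives at `empty` to hand it back.
inline long runPipeline(int iterations) {
  BarrierModel full(1), empty(1); // one participant: every arrival flips phase
  int slot = 0;
  long sum = 0;
  std::thread producer([&] {
    bool emptyParity = false;
    for (int i = 0; i < iterations; ++i) {
      if (i > 0) { // the slot starts out empty, so skip the first wait
        empty.waitWhileParityIs(emptyParity);
        emptyParity = !emptyParity;
      }
      slot = i + 1; // stands in for loading data into LDS
      full.arrive();
    }
  });
  bool fullParity = false;
  for (int i = 0; i < iterations; ++i) {
    full.waitWhileParityIs(fullParity);
    fullParity = !fullParity;
    sum += slot; // stands in for computing with the loaded data
    empty.arrive();
  }
  producer.join();
  return sum;
}
```

Because the producer cannot refill the slot until the consumer has arrived at `empty`, neither barrier can get two phases ahead of its waiter. That is what makes a parity check sufficient here.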

From 4a429ca44fb72ba0e7cadc47c13f3291838486cf Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 6 Feb 2026 03:44:35 +0000
Subject: [PATCH] [mlir][AMDGPU] Add wrappers for in-memory barriers on gfx1250

This commit introduces the `!amdgpu.ds_barrier_state` type and
operations on that type, including ones that extract its components,
and (more importantly) provides wrappers around the upcoming
barrier-management instructions in gfx1250.

This commit is loosely based on work done for Triton, but provides
slightly lower-level primitives: namely, a known-atomic load for
reading the barrier state instead of a `wait` operation that contains
an entire spin-loop (though we could consider adding one if people
want it). These operations allow LDS barriers to be manipulated in a
more type-safe manner.

The types and operations use the Ds naming scheme to match the
underlying instructions and to avoid confusion with the "LDS barrier"
already present in the AMDGPU dialect that was a workaround for LLVM's
memory fencing support.

(As an example of a potential usage pattern, a pair of these barriers
can be used to communicate between the wave(s) in a workgroup that load
data into memory and separate wave(s) that compute with that data.)

Co-authored-by: Claude Sonnet 4.5 <noreply at anthropic.com>
---
 .../mlir/Dialect/AMDGPU/IR/AMDGPUOps.td       | 216 ++++++++++++++
 .../mlir/Dialect/AMDGPU/IR/AMDGPUTypes.td     |  25 ++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 265 +++++++++++++++++-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp      |  29 ++
 .../Conversion/AMDGPUToROCDL/gfx1250.mlir     |  94 ++++++-
 mlir/test/Dialect/AMDGPU/invalid.mlir         |  24 ++
 mlir/test/Dialect/AMDGPU/ops.mlir             |  22 ++
 7 files changed, 670 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index 24e40f40c2031..6042a958a2b3b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1541,4 +1541,220 @@ def AMDGPU_TensorStoreFromLDSOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// In-LDS Barrier Operations
+//
+// General note: any of these operations that impact memory have read and write
+// effects as a crude model of their atomic nature - we don't want "reads"
+// being hoisted out of loops.
+//===----------------------------------------------------------------------===//
+
+def AMDGPU_DsBarrierInitOp :
+    AMDGPU_Op<"ds_barrier_init">,
+    Arguments<(ins Arg<MemRefOf<[AMDGPU_DsBarrierStateType]>, "barrier(s)",
+                       [MemRead, MemWrite]>:$base,
+                   Variadic<Index>:$indices,
+                   I32:$participants)> {
+  let summary = "Initialize an in-LDS barrier.";
+  let description = [{
+    Given the location of an `!amdgpu.ds_barrier_state` in LDS (as specified by `base` and `indices`),
+    initialize the barrier so that its pending and init counts are both equal to `participants - 1`
+    (masked to the 28-bit width of the pending count) and its phase is 0.
+
+    Note that we subtract 1 from `participants` when constructing the barrier state
+    to provide clearer high-level semantics.
+
+    Because of this subtraction, the phase changes when the `participants`-th arrival occurs.
+    In practical terms, this means that you can use (for example) the number of subgroups or
+    waves per workgroup as `participants` without having to subtract one manually.
+
+    While the write of the initial state will be performed atomically, no synchronization
+    between waves will be performed by this operation.
+
+    Example:
+    ```mlir
+    amdgpu.ds_barrier_init %barrier[], %c32 : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, i32
+    ```
+
+    This operation is only available on gfx1250+.
+  }];
+
+  let assemblyFormat = [{
+    $base `[` $indices `]` `,` $participants attr-dict `:` type($base) `,` type($participants)
+  }];
+
+  let hasVerifier = 1;
+}
+
+def AMDGPU_DsBarrierPollStateOp :
+    AMDGPU_Op<"ds_barrier_poll_state">,
+    Arguments<(ins Arg<MemRefOf<[AMDGPU_DsBarrierStateType]>, "barrier(s)",
+                       [MemRead, MemWrite]>:$base,
+                 Variadic<Index>:$indices)>,
+    Results<(outs AMDGPU_DsBarrierStateType:$out)> {
+  let summary = "Atomically read the state of an in-LDS barrier.";
+  let description = [{
+    Atomically read and return the state of the barrier at `base[indices...]`.
+
+    This behaves like a `memref.load`, but is lowered to an atomic load with `acquire`
+    ordering and the `workgroup` syncscope.
+
+    Example:
+    ```mlir
+    %state = amdgpu.ds_barrier_poll_state %barrier[] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>> -> !amdgpu.ds_barrier_state
+    ```
+
+    This operation is only available on gfx1250+.
+  }];
+
+  let assemblyFormat = [{
+    $base `[` $indices `]` attr-dict `:` type($base) `->` type($out)
+  }];
+
+  let hasVerifier = 1;
+}
+
+def AMDGPU_DsAsyncBarrierArriveOp :
+    AMDGPU_Op<"ds_async_barrier_arrive">,
+    Arguments<(ins Arg<MemRefOf<[AMDGPU_DsBarrierStateType]>, "barrier(s)",
+                       [MemRead, MemWrite]>:$base,
+                 Variadic<Index>:$indices)> {
+  let summary = "Asynchronously arrive at an in-LDS barrier.";
+  let description = [{
+    Add an arrival at the LDS barrier at `base[indices]` to the sequence of pending
+    asynchronous memory operations.
+
+    This adds an "asynchronous memory operation" to the in-order list of pending
+    asynchronous loads from global memory to LDS. Once all such operations issued
+    before this one have completed, the specified barrier is arrived at,
+    decrementing the pending count by 1 **per lane that executes it** and rolling
+    over the phase if applicable.
+
+    This operation does not return the old barrier state.
+
+    Example:
+    ```mlir
+    amdgpu.ds_async_barrier_arrive %barrier[] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>
+    ```
+
+    This operation is only available on gfx1250+.
+  }];
+
+  let assemblyFormat = [{
+    $base `[` $indices `]` attr-dict `:` type($base)
+  }];
+
+  let hasVerifier = 1;
+}
+
+def AMDGPU_DsBarrierArriveOp :
+    AMDGPU_Op<"ds_barrier_arrive">,
+    Arguments<(ins Arg<MemRefOf<[AMDGPU_DsBarrierStateType]>, "barrier(s)",
+                       [MemRead, MemWrite]>:$base,
+                 Variadic<Index>:$indices,
+                 I64:$count)>,
+    Results<(outs AMDGPU_DsBarrierStateType:$out)> {
+  let summary = "Arrive at an in-LDS barrier and return old state.";
+  let description = [{
+    Atomically arrive at the LDS barrier at `base[indices]`, decrementing its pending count
+    by `count`, rolling over the phase if needed, and returning the old barrier state.
+
+    `count` is the number of participants that should be subtracted from the barrier's
+    pending count **per lane that executes the operation**.
+
+    Example:
+    ```mlir
+    %old_state = amdgpu.ds_barrier_arrive %barrier[], %c1 : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, i64 -> !amdgpu.ds_barrier_state
+    ```
+
+    This operation is only available on gfx1250+.
+  }];
+
+  let assemblyFormat = [{
+    $base `[` $indices `]` `,` $count attr-dict `:` type($base) `,` type($count) `->` type($out)
+  }];
+
+  let hasVerifier = 1;
+}
+
+def AMDGPU_DsBarrierStatePhaseOp :
+    AMDGPU_Op<"ds_barrier_state_phase", [Pure]>,
+    Arguments<(ins AMDGPU_DsBarrierStateType:$state)>,
+    Results<(outs I32:$res)> {
+  let summary = "Extract the phase of a barrier state.";
+  let description = [{
+    Extract the phase of the `!amdgpu.ds_barrier_state` `state` as a 32-bit value.
+
+    Example:
+    ```mlir
+    %phase = amdgpu.ds_barrier_state_phase %state : !amdgpu.ds_barrier_state -> i32
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $state attr-dict `:` type($state) `->` type($res)
+  }];
+}
+
+def AMDGPU_DsBarrierStatePendingCountOp :
+    AMDGPU_Op<"ds_barrier_state_pending_count", [Pure]>,
+    Arguments<(ins AMDGPU_DsBarrierStateType:$state)>,
+    Results<(outs I32:$res)> {
+  let summary = "Extract the pending count of a barrier state.";
+  let description = [{
+    Extract the pending count of the `!amdgpu.ds_barrier_state` `state` as a 32-bit value.
+
+    Example:
+    ```mlir
+    %pending = amdgpu.ds_barrier_state_pending_count %state : !amdgpu.ds_barrier_state -> i32
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $state attr-dict `:` type($state) `->` type($res)
+  }];
+}
+
+def AMDGPU_DsBarrierStateInitCountOp :
+    AMDGPU_Op<"ds_barrier_state_init_count", [Pure]>,
+    Arguments<(ins AMDGPU_DsBarrierStateType:$state)>,
+    Results<(outs I32:$res)> {
+  let summary = "Extract the init count of a barrier state.";
+  let description = [{
+    Extract the init count of the `!amdgpu.ds_barrier_state` `state` as a 32-bit value.
+
+    Example:
+    ```mlir
+    %init = amdgpu.ds_barrier_state_init_count %state : !amdgpu.ds_barrier_state -> i32
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $state attr-dict `:` type($state) `->` type($res)
+  }];
+}
+
+def AMDGPU_DsBarrierStatePhaseParity :
+    AMDGPU_Op<"ds_barrier_state_phase_parity", [Pure]>,
+    Arguments<(ins AMDGPU_DsBarrierStateType:$state)>,
+    Results<(outs I1:$res)> {
+  let summary = "Extract the phase parity of a barrier state.";
+  let description = [{
+    Return the parity of the phase of the `!amdgpu.ds_barrier_state` `state`.
+
+    This is intended to simplify the case where the barrier is being used to repeatedly
+    track completion of a task where the precise value of the phase doesn't matter, only that
+    it changed since (or as a result of) the arrival.
+
+    Example:
+    ```mlir
+    %parity = amdgpu.ds_barrier_state_phase_parity %state : !amdgpu.ds_barrier_state -> i1
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $state attr-dict `:` type($state) `->` type($res)
+  }];
+}
+
 #endif // MLIR_DIALECT_AMDGPU_IR_AMDGPUOPS_TD
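
To illustrate why `ds_barrier_init` subtracts 1 from `participants`, here is a constexpr C++ model of initialization followed by single-lane arrivals with a count of 1. This is a sketch, not the hardware definition. It assumes the phase advances by one when the pending count would underflow, and the names (`initState`, `arriveOnce`, `arriveN`, `phaseOf`) are hypothetical.

```cpp
#include <cstdint>

constexpr uint64_t kPendingMask = 0x0FFFFFFF; // bits [27:0]

// Like ds_barrier_init: pending = init = (participants - 1) masked to 28 bits,
// phase = 0.
constexpr uint64_t initState(uint32_t participants) {
  uint64_t count = (participants - 1u) & kPendingMask;
  return (count << 32) | count;
}

constexpr uint32_t phaseOf(uint64_t state) {
  return static_cast<uint32_t>((state >> 28) & 0xF);
}

// One arrival with a count of 1. If the pending count would underflow, it is
// reset to the init count and the phase changes (ASSUMED: +1 modulo 16).
constexpr uint64_t arriveOnce(uint64_t state) {
  if ((state & kPendingMask) != 0)
    return state - 1;
  uint64_t init = (state >> 32) & kPendingMask;
  uint64_t phase = (phaseOf(state) + 1u) & 0xF;
  return (init << 32) | (phase << 28) | init;
}

constexpr uint64_t arriveN(uint64_t state, uint32_t n) {
  for (uint32_t i = 0; i < n; ++i)
    state = arriveOnce(state);
  return state;
}
```

With `participants = 4`, `arriveN(initState(4), 3)` is still in phase 0. `arriveN(initState(4), 4)` is in phase 1 with the pending count reloaded to 3, so the phase changes on exactly the fourth arrival.
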
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUTypes.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUTypes.td
index 3ea1ba35815c3..aa76bff8ded80 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUTypes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUTypes.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/AMDGPU/IR/AMDGPUBase.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
 
 //===----------------------------------------------------------------------===//
 // AMDGPU Type definitions
@@ -69,4 +70,28 @@ def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> {
   }];
 }
 
+def AMDGPU_DsBarrierStateType : AMDGPU_Type<"DsBarrierState", "ds_barrier_state",
+    [MemRefElementTypeInterface]> {
+  let summary = "State of an in-LDS barrier.";
+  let description = [{
+    Type that encodes the state of an in-LDS barrier as used by the atomic barrier
+    instructions introduced on gfx1250.
+
+    It consists of a 28-bit count of the number of pending arrivals at the barrier (the
+    *pending count*) in bits [27:0], a 4-bit *phase* in bits [31:28], and the 32-bit count
+    to re-initialize the pending count to on phase change (the *init count*) in bits [63:32].
+
+    When an instruction (either one of the explicit arrival primitives or tensor data
+    movement) *arrives* at such a barrier, the pending count is decremented. If this
+    decrement would cause the pending count to underflow, the count is instead reset
+    to the init count and the phase is incremented (wrapping back to 0). When the
+    phase changes, sleeping waves are woken up so they can check the barrier.
+
+    The barrier state resides in LDS, but an old barrier state can be returned from atomic
+    arrival instructions or through atomic loads.
+
+    This feature is not available prior to gfx1250.
+  }];
+}
+
 #endif // MLIR_DIALECT_AMDGPU_IR_AMDGPUTYPES_TD
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index e7dd8ea6f9149..1a2cbeeae7d0f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2602,6 +2602,257 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// In-LDS Barrier Operations
+//===----------------------------------------------------------------------===//
+
+// Bit layout of ds_barrier_state (as i64):
+// [63:32] init count (32 bits)
+// [31:28] phase (4 bits)
+// [27:0] pending count (28 bits)
+constexpr int32_t kDsBarrierPendingCountBitWidth = 28;
+constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
+constexpr int32_t kDsBarrierInitCountPos = 32;
+constexpr int32_t kDsBarrierPendingCountMask =
+    (1 << kDsBarrierPendingCountBitWidth) - 1;
+
+struct DsBarrierInitOpLowering
+    : public ConvertOpToLLVMPattern<DsBarrierInitOp> {
+  Chipset chipset;
+
+  DsBarrierInitOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx1250)
+      return op->emitOpError("only supported on gfx1250+");
+
+    Location loc = op.getLoc();
+    Type i64 = rewriter.getI64Type();
+
+    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
+    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
+                                     adaptor.getBase(), adaptor.getIndices());
+
+    // Note: We give participants as the number of arrivals that have to occur
+    // before the phase changes. Hardware changes the phase when the count
+    // actually wraps around, so we subtract 1 to get the behavior we're looking
+    // for.
+    Value initCount =
+        LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
+                            createI32Constant(rewriter, loc, 1));
+
+    // Just a bit of paranoia, but this also allows for configurable width if
+    // that becomes a thing.
+    Value countMask =
+        createI32Constant(rewriter, loc, kDsBarrierPendingCountMask);
+    Value maskedCount32 =
+        LLVM::AndOp::create(rewriter, loc, initCount, countMask);
+    Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);
+
+    Value initCountShifted = LLVM::ShlOp::create(
+        rewriter, loc, maskedCount,
+        createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
+    Value barrierState =
+        LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);
+
+    LLVM::StoreOp::create(
+        rewriter, loc, barrierState, ptr, /*alignment=*/8, /*isVolatile=*/false,
+        /*isNonTemporal=*/false,
+        /*isInvariantGroup=*/false, LLVM::AtomicOrdering::release,
+        /*syncscope=*/"workgroup");
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct DsBarrierPollStateOpLowering
+    : public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
+  Chipset chipset;
+
+  DsBarrierPollStateOpLowering(const LLVMTypeConverter &converter,
+                               Chipset chipset)
+      : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
+        chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx1250)
+      return op->emitOpError("only supported on gfx1250+");
+
+    Location loc = op.getLoc();
+    Type i64 = rewriter.getI64Type();
+
+    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
+    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
+                                     adaptor.getBase(), adaptor.getIndices());
+
+    // Atomic load with workgroup scope and acquire ordering should be what
+    // we're looking for.
+    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
+        op, i64, ptr, /*alignment=*/8, /*volatile_=*/false,
+        /*nontemporal=*/false, /*invariant=*/false,
+        /*invariantGroup=*/false, LLVM::AtomicOrdering::acquire,
+        /*syncscope=*/"workgroup");
+    return success();
+  }
+};
+
+struct DsAsyncBarrierArriveOpLowering
+    : public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
+  Chipset chipset;
+
+  DsAsyncBarrierArriveOpLowering(const LLVMTypeConverter &converter,
+                                 Chipset chipset)
+      : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
+        chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx1250)
+      return op->emitOpError("only supported on gfx1250+");
+
+    Location loc = op.getLoc();
+
+    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
+    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
+                                     adaptor.getBase(), adaptor.getIndices());
+
+    rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
+        op, ptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
+        /*tbaa=*/nullptr);
+    return success();
+  }
+};
+
+struct DsBarrierArriveOpLowering
+    : public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
+  Chipset chipset;
+
+  DsBarrierArriveOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
+  }
+
+  LogicalResult
+  matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx1250)
+      return op->emitOpError("only supported on gfx1250+");
+
+    Location loc = op.getLoc();
+    Type i64 = rewriter.getI64Type();
+
+    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
+    Value ptr = getStridedElementPtr(rewriter, loc, memrefType,
+                                     adaptor.getBase(), adaptor.getIndices());
+
+    rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
+        op, i64, ptr, adaptor.getCount(), /*alias_scopes=*/nullptr,
+        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+    return success();
+  }
+};
+
+struct DsBarrierStatePhaseOpLowering
+    : public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    Value state = adaptor.getState();
+
+    Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
+    Value phase = LLVM::LShrOp::create(
+        rewriter, loc, noInitCount,
+        createI32Constant(rewriter, loc, kDsBarrierPhasePos));
+
+    rewriter.replaceOp(op, phase);
+    return success();
+  }
+};
+
+struct DsBarrierStatePendingCountOpLowering
+    : public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    Value state = adaptor.getState();
+
+    Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
+    Value pendingCount = LLVM::AndOp::create(
+        rewriter, loc, noInitCount,
+        createI32Constant(rewriter, loc,
+                          static_cast<uint32_t>(kDsBarrierPendingCountMask)));
+
+    rewriter.replaceOp(op, pendingCount);
+    return success();
+  }
+};
+
+struct DsBarrierStateInitCountOpLowering
+    : public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    Value state = adaptor.getState();
+
+    Value initCountI64 = LLVM::LShrOp::create(
+        rewriter, loc, state,
+        createI64Constant(rewriter, loc, kDsBarrierInitCountPos));
+    Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
+
+    rewriter.replaceOp(op, initCount);
+    return success();
+  }
+};
+
+struct DsBarrierStatePhaseParityLowering
+    : public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type i1 = rewriter.getI1Type();
+
+    Value state = adaptor.getState();
+
+    Value noInitCount =
+        LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
+    Value phase = LLVM::LShrOp::create(
+        rewriter, loc, noInitCount,
+        createI32Constant(rewriter, loc, kDsBarrierPhasePos));
+    Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
+
+    rewriter.replaceOp(op, parity);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Tensor Data Mover (TDM)
+//===----------------------------------------------------------------------===//
+
 static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
                               Value accumulator, Value value, int64_t shift) {
   shift = shift % 32;
@@ -3505,6 +3756,9 @@ void mlir::populateAMDGPUTypeAndAttributeConversions(
         }
         return TypeConverter::AttributeConversionResult::abort();
       });
+  typeConverter.addConversion([&](DsBarrierStateType type) -> Type {
+    return IntegerType::get(type.getContext(), 64);
+  });
   typeConverter.addConversion([&](TDMBaseType type) -> Type {
     Type i32 = IntegerType::get(type.getContext(), 32);
     return typeConverter.convertType(VectorType::get(4, i32));
@@ -3574,7 +3828,12 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
                                            ROCDL::TensorLoadToLDSOp>,
            AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
-                                           ROCDL::TensorStoreFromLDSOp>>(
-          converter, chipset);
-  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
+                                           ROCDL::TensorStoreFromLDSOp>,
+           DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
+           DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering>(converter,
+                                                                      chipset);
+  patterns.add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
+               DsBarrierStatePendingCountOpLowering,
+               DsBarrierStateInitCountOpLowering,
+               DsBarrierStatePhaseParityLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index 87a813a31608d..4ed62cacd006f 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -1208,5 +1208,34 @@ void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<PackScales>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// In-LDS Barrier Operations (gfx1250+)
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+static LogicalResult verifyDsBarrierOpCommon(T &op) {
+  MemRefType memrefType = llvm::cast<MemRefType>(op.getBase().getType());
+  if (!hasWorkgroupMemorySpace(memrefType.getMemorySpace()))
+    return op.emitOpError("barrier must be in workgroup (LDS) memory");
+
+  return success();
+}
+
+LogicalResult DsBarrierInitOp::verify() {
+  return verifyDsBarrierOpCommon(*this);
+}
+
+LogicalResult DsBarrierPollStateOp::verify() {
+  return verifyDsBarrierOpCommon(*this);
+}
+
+LogicalResult DsAsyncBarrierArriveOp::verify() {
+  return verifyDsBarrierOpCommon(*this);
+}
+
+LogicalResult DsBarrierArriveOp::verify() {
+  return verifyDsBarrierOpCommon(*this);
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
index 7eb7ddeab13a4..af12cfdd9a633 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir
@@ -760,8 +760,6 @@ func.func @tensor_store_from_lds(%desc: !amdgpu.tdm_descriptor) {
   func.return
 }
 
-// -----
-
 // CHECK-LABEL: func @make_gather_dma_descriptor
 // CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_gather_base<i32, i16>, %[[INDICES:.+]]: vector<13xi16>)
 func.func @make_gather_dma_descriptor(%base: !amdgpu.tdm_gather_base<i32, i16>, %indices: vector<13xi16>) -> !amdgpu.tdm_descriptor {
@@ -868,3 +866,95 @@ func.func @make_gather_dma_descriptor(%base: !amdgpu.tdm_gather_base<i32, i16>,
   func.return %descriptor : !amdgpu.tdm_descriptor
 }
 
+/// LDS barriers
+
+// CHECK-LABEL: func @ds_barrier_init
+func.func @ds_barrier_init(%barrier: memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, %participants: i32) {
+  // CHECK: [[CAST:%.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK: [[PTR:%.*]] = llvm.extractvalue [[CAST]][1]
+  // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32)
+  // CHECK: [[SUB:%.*]] = llvm.sub %arg1, [[C1]]
+  // CHECK: [[MASK:%.*]] = llvm.mlir.constant(268435455 : i32)
+  // CHECK: [[MASKED:%.*]] = llvm.and [[SUB]], [[MASK]]
+  // CHECK: [[ZEXT:%.*]] = llvm.zext [[MASKED]] : i32 to i64
+  // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i64)
+  // CHECK: [[INIT_SHIFT:%.*]] = llvm.shl [[ZEXT]], [[C32]]
+  // CHECK: [[STATE:%.*]] = llvm.or [[INIT_SHIFT]], [[ZEXT]]
+  // CHECK: llvm.store [[STATE]], [[PTR]] atomic syncscope("workgroup") release
+  amdgpu.ds_barrier_init %barrier[], %participants : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @ds_barrier_poll_state
+func.func @ds_barrier_poll_state(%barrier: memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>) -> !amdgpu.ds_barrier_state {
+  // CHECK: [[CAST:%.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK: [[PTR:%.*]] = llvm.extractvalue [[CAST]][1]
+  // CHECK: [[LOADED:%.*]] = llvm.load [[PTR]] atomic syncscope("workgroup") acquire
+  // CHECK: builtin.unrealized_conversion_cast [[LOADED]]
+  %state = amdgpu.ds_barrier_poll_state %barrier[] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>> -> !amdgpu.ds_barrier_state
+  func.return %state : !amdgpu.ds_barrier_state
+}
+
+// CHECK-LABEL: func @ds_async_barrier_arrive
+func.func @ds_async_barrier_arrive(%barrier: memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>) {
+  // CHECK: [[CAST:%.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK: [[PTR:%.*]] = llvm.extractvalue [[CAST]][1]
+  // CHECK: rocdl.ds.atomic.async.barrier.arrive.b64 [[PTR]] : !llvm.ptr<3>
+  amdgpu.ds_async_barrier_arrive %barrier[] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>
+  func.return
+}
+
+// CHECK-LABEL: func @ds_barrier_arrive
+func.func @ds_barrier_arrive(%barrier: memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, %count: i64) -> !amdgpu.ds_barrier_state {
+  // CHECK: [[CAST:%.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK: [[PTR:%.*]] = llvm.extractvalue [[CAST]][1]
+  // CHECK: [[OLD:%.*]] = rocdl.ds.atomic.barrier.arrive.rtn.b64 [[PTR]], %arg1 : !llvm.ptr<3>, i64 -> i64
+  // CHECK: builtin.unrealized_conversion_cast [[OLD]]
+  %old_state = amdgpu.ds_barrier_arrive %barrier[], %count : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, i64 -> !amdgpu.ds_barrier_state
+  func.return %old_state : !amdgpu.ds_barrier_state
+}
+
+// CHECK-LABEL: func @ds_barrier_state_phase
+func.func @ds_barrier_state_phase(%state: !amdgpu.ds_barrier_state) -> i32 {
+  // CHECK: [[CAST:%.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK: [[TRUNC:%.*]] = llvm.trunc [[CAST]] : i64 to i32
+  // CHECK: [[C28:%.*]] = llvm.mlir.constant(28 : i32)
+  // CHECK: [[PHASE:%.*]] = llvm.lshr [[TRUNC]], [[C28]]
+  // CHECK: return [[PHASE]]
+  %phase = amdgpu.ds_barrier_state_phase %state : !amdgpu.ds_barrier_state -> i32
+  func.return %phase : i32
+}
+
+// CHECK-LABEL: func @ds_barrier_state_pending_count
+func.func @ds_barrier_state_pending_count(%state: !amdgpu.ds_barrier_state) -> i32 {
+  // CHECK: [[CAST:%.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK: [[TRUNC:%.*]] = llvm.trunc [[CAST]] : i64 to i32
+  // CHECK: [[MASK:%.*]] = llvm.mlir.constant(268435455 : i32)
+  // CHECK: [[COUNT:%.*]] = llvm.and [[TRUNC]], [[MASK]]
+  // CHECK: return [[COUNT]]
+  %pending = amdgpu.ds_barrier_state_pending_count %state : !amdgpu.ds_barrier_state -> i32
+  func.return %pending : i32
+}
+
+// CHECK-LABEL: func @ds_barrier_state_init_count
+func.func @ds_barrier_state_init_count(%state: !amdgpu.ds_barrier_state) -> i32 {
+  // CHECK: [[CAST:%.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i64)
+  // CHECK: [[SHIFTED:%.*]] = llvm.lshr [[CAST]], [[C32]]
+  // CHECK: [[COUNT:%.*]] = llvm.trunc [[SHIFTED]] : i64 to i32
+  // CHECK: return [[COUNT]]
+  %init = amdgpu.ds_barrier_state_init_count %state : !amdgpu.ds_barrier_state -> i32
+  func.return %init : i32
+}
+
+// CHECK-LABEL: func @ds_barrier_state_phase_parity
+func.func @ds_barrier_state_phase_parity(%state: !amdgpu.ds_barrier_state) -> i1 {
+  // CHECK: [[CAST:%.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK: [[TRUNC:%.*]] = llvm.trunc [[CAST]] : i64 to i32
+  // CHECK: [[C28:%.*]] = llvm.mlir.constant(28 : i32)
+  // CHECK: [[SHIFTED:%.*]] = llvm.lshr [[TRUNC]], [[C28]]
+  // CHECK: [[PARITY:%.*]] = llvm.trunc [[SHIFTED]] : i32 to i1
+  // CHECK: return [[PARITY]]
+  %parity = amdgpu.ds_barrier_state_phase_parity %state : !amdgpu.ds_barrier_state -> i1
+  func.return %parity : i1
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 1299f3b14b14f..474fca1157118 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -516,3 +516,27 @@ func.func @sparse_mfma_wrong_dest_count(%a: vector<4xf16>, %b: vector<8xf16>, %c
   %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<16xf32>
   func.return %d : vector<16xf32>
 }
+
+// -----
+
+func.func @ds_barrier_init_non_workgroup(%barrier: memref<!amdgpu.ds_barrier_state>, %participants: i32) {
+  // expected-error@+1 {{'amdgpu.ds_barrier_init' op barrier must be in workgroup (LDS) memory}}
+  amdgpu.ds_barrier_init %barrier[], %participants : memref<!amdgpu.ds_barrier_state>, i32
+  func.return
+}
+
+// -----
+
+func.func @ds_barrier_poll_state_non_workgroup(%barrier: memref<!amdgpu.ds_barrier_state, #gpu.address_space<global>>) -> !amdgpu.ds_barrier_state {
+  // expected-error@+1 {{'amdgpu.ds_barrier_poll_state' op barrier must be in workgroup (LDS) memory}}
+  %state = amdgpu.ds_barrier_poll_state %barrier[] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<global>> -> !amdgpu.ds_barrier_state
+  func.return %state : !amdgpu.ds_barrier_state
+}
+
+// -----
+
+func.func @ds_barrier_arrive_non_workgroup(%barrier: memref<!amdgpu.ds_barrier_state, #amdgpu.address_space<fat_raw_buffer>>, %count: i64) -> !amdgpu.ds_barrier_state {
+  // expected-error@+1 {{'amdgpu.ds_barrier_arrive' op barrier must be in workgroup (LDS) memory}}
+  %old_state = amdgpu.ds_barrier_arrive %barrier[], %count : memref<!amdgpu.ds_barrier_state, #amdgpu.address_space<fat_raw_buffer>>, i64 -> !amdgpu.ds_barrier_state
+  func.return %old_state : !amdgpu.ds_barrier_state
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 2b3234ef8510d..c3e7c8b70f4ee 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -801,3 +801,25 @@ func.func @wmma_scale(%fp8_src: vector<64xf8E4M3FN>, %fp6_alt_src: vector<64xf6E
   %5 = amdgpu.scaled_wmma 32x16x128 (%scale_vec4_e4m3 * %fp4_src_a) * (%scale_vec4_e4m3 * %fp4_src_b) + %dst1 {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
   func.return
 }
+
+// CHECK-LABEL: func @ds_barrier_ops
+// CHECK-SAME: ([[BARRIER:%.*]]: memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, [[COUNT:%.*]]: i64, [[PARTICIPANTS:%.*]]: i32)
+func.func @ds_barrier_ops(%barrier: memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, %count: i64, %participants: i32) {
+  // CHECK: amdgpu.ds_barrier_init [[BARRIER]][], [[PARTICIPANTS]] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, i32
+  amdgpu.ds_barrier_init %barrier[], %participants : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, i32
+  // CHECK: [[STATE:%.*]] = amdgpu.ds_barrier_poll_state [[BARRIER]][] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>> -> !amdgpu.ds_barrier_state
+  %state = amdgpu.ds_barrier_poll_state %barrier[] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>> -> !amdgpu.ds_barrier_state
+  // CHECK: amdgpu.ds_async_barrier_arrive [[BARRIER]][] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>
+  amdgpu.ds_async_barrier_arrive %barrier[] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>
+  // CHECK: [[OLD_STATE:%.*]] = amdgpu.ds_barrier_arrive [[BARRIER]][], [[COUNT]] : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, i64 -> !amdgpu.ds_barrier_state
+  %old_state = amdgpu.ds_barrier_arrive %barrier[], %count : memref<!amdgpu.ds_barrier_state, #gpu.address_space<workgroup>>, i64 -> !amdgpu.ds_barrier_state
+  // CHECK: [[PHASE:%.*]] = amdgpu.ds_barrier_state_phase [[STATE]] : !amdgpu.ds_barrier_state -> i32
+  %phase = amdgpu.ds_barrier_state_phase %state : !amdgpu.ds_barrier_state -> i32
+  // CHECK: [[PENDING:%.*]] = amdgpu.ds_barrier_state_pending_count [[STATE]] : !amdgpu.ds_barrier_state -> i32
+  %pending = amdgpu.ds_barrier_state_pending_count %state : !amdgpu.ds_barrier_state -> i32
+  // CHECK: [[INIT:%.*]] = amdgpu.ds_barrier_state_init_count [[STATE]] : !amdgpu.ds_barrier_state -> i32
+  %init = amdgpu.ds_barrier_state_init_count %state : !amdgpu.ds_barrier_state -> i32
+  // CHECK: [[PARITY:%.*]] = amdgpu.ds_barrier_state_phase_parity [[STATE]] : !amdgpu.ds_barrier_state -> i1
+  %parity = amdgpu.ds_barrier_state_phase_parity %state : !amdgpu.ds_barrier_state -> i1
+  func.return
+}
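The conversion tests above pin down how each accessor slices the 64-bit barrier state word. Below is a minimal C++ sketch of the same arithmetic. It assumes the layout implied by those CHECK lines: init count in bits 63:32, phase in bits 31:28, pending count in bits 27:0, and parity as bit 28. That layout is read off the tests, not taken from hardware documentation. `DsBarrierFields` and `decodeDsBarrierState` are hypothetical names, not part of MLIR.

```cpp
#include <cstdint>

// Hypothetical decoded view of an !amdgpu.ds_barrier_state value.
// Field layout is an assumption inferred from the lowering tests above.
struct DsBarrierFields {
  std::uint32_t initCount;    // bits [63:32]
  std::uint32_t phase;        // bits [31:28]
  std::uint32_t pendingCount; // bits [27:0]
  bool phaseParity;           // bit 28 (lowest bit of the phase field)
};

// Hypothetical helper mirroring the shift/mask sequence each accessor op
// is checked to lower to.
constexpr DsBarrierFields decodeDsBarrierState(std::uint64_t state) {
  return DsBarrierFields{
      // ds_barrier_state_init_count: lshr i64 by 32, then trunc to i32.
      static_cast<std::uint32_t>(state >> 32),
      // ds_barrier_state_phase: trunc to i32, then lshr by 28.
      static_cast<std::uint32_t>(state) >> 28,
      // ds_barrier_state_pending_count: trunc to i32, then and with
      // 268435455 (0x0FFFFFFF, the low 28 bits).
      static_cast<std::uint32_t>(state) & 0x0FFFFFFFu,
      // ds_barrier_state_phase_parity: trunc to i32, lshr by 28, then
      // trunc to i1, which keeps only bit 0 of the shifted value.
      ((static_cast<std::uint32_t>(state) >> 28) & 1u) != 0,
  };
}

// Compile-time worked example with a made-up state value:
// init count 4, phase 3, pending count 2.
constexpr std::uint64_t kExampleState =
    (std::uint64_t{4} << 32) | (std::uint64_t{3} << 28) | 2u;
static_assert(decodeDsBarrierState(kExampleState).initCount == 4, "init");
static_assert(decodeDsBarrierState(kExampleState).phase == 3, "phase");
static_assert(decodeDsBarrierState(kExampleState).pendingCount == 2, "pending");
static_assert(decodeDsBarrierState(kExampleState).phaseParity, "odd phase");
```

As in the tests, `init_count` shifts the full 64-bit value before truncating, while the other three accessors truncate to 32 bits first. The order matters for `phase`, because shifting the untruncated word would pull init-count bits into the result.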