[Mlir-commits] [mlir] [mlir][nvgpu] separate ops, types, attributes definitions in NVGPU dialect. (PR #129846)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 5 00:01:00 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

The goal is to define the Ops, Types, and Attributes of the NVGPU dialect in separate files. Downstream projects that extend NVGPU with their own Ops need the NVGPU types and attributes. This PR lets them use those types and attributes without also having to include the NVGPU Op definitions.

---

Patch is 67.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129846.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt (+9-4) 
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+1-720) 
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+2-2) 
- (added) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUOps.td (+633) 
- (added) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUTypes.td (+112) 
- (modified) mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt (+1-1) 
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+4-4) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
index 13d754ca06316..ecdaae7f24d93 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
@@ -1,5 +1,10 @@
 add_mlir_dialect(NVGPU nvgpu)
-add_mlir_doc(NVGPU NVGPU Dialects/ -gen-dialect-doc)
+add_mlir_doc(NVGPUOps NVGPU Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS NVGPUOps.td)
+mlir_tablegen(NVGPUOps.h.inc -gen-op-decls)
+mlir_tablegen(NVGPUOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRNVGPUOpsIncGen)
 
 set(LLVM_TARGET_DEFINITIONS NVGPU.td)
 mlir_tablegen(NVGPUEnums.h.inc -gen-enum-decls)
@@ -11,7 +16,7 @@ mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls)
 mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(MLIRNVGPUAttributesIncGen)
 
-set(LLVM_TARGET_DEFINITIONS NVGPU.td)
-mlir_tablegen(NVGPUAttrTypes.h.inc -gen-typedef-decls)
-mlir_tablegen(NVGPUAttrTypes.cpp.inc -gen-typedef-decls)
+set(LLVM_TARGET_DEFINITIONS NVGPUTypes.td)
+mlir_tablegen(NVGPUTypeDefs.h.inc -gen-typedef-decls)
+mlir_tablegen(NVGPUTypeDefs.cpp.inc -gen-typedef-defs)
 add_public_tablegen_target(MLIRNVGPUTypesIncGen)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index f48fa9976da12..6b5470310e4a1 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -1,21 +1,10 @@
-//===-- NVGPU.td - NVGPU dialect operation definitions *- tablegen -*------===//
+//===-- NVGPU.td - Attribute defs for NVGPU dialect *- tablegen -*---------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
-// This file defines the basic operations for the NVGPU dialect.
-//
-// This NVGPU provides a bridge between the target agnostic GPU and Vector
-// dialects and lower level NVVM dialect. This allow representing PTX specific
-// operations while using MLIR high level concepts like memref and 2-D vector.
-//
-// Ops semantic are going to be based on vendor specific PTX defintion:
-// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
-//
-//===----------------------------------------------------------------------===//
 
 #ifndef NVGPU
 #define NVGPU
@@ -127,712 +116,4 @@ def TensorMapOOBAttr : EnumAttr<NVGPU_Dialect, TensorMapOOBKind, "oob">;
 def TensorMapInterleaveAttr : EnumAttr<NVGPU_Dialect, TensorMapInterleaveKind, "interleave">;
 def RcpRoundingModeAttr : EnumAttr<NVGPU_Dialect, RcpRoundingMode, "rcp_rounding_mode">;
 
-//===----------------------------------------------------------------------===//
-// NVGPU Type Definitions
-//===----------------------------------------------------------------------===//
-
-class NVGPU_Type<string name, string typeMnemonic,
-        list<Trait> traits = []> : TypeDef<NVGPU_Dialect, name, traits> {
-  let mnemonic = typeMnemonic;
-}
-
-def NVGPU_DeviceAsyncToken : NVGPU_Type<"DeviceAsyncToken",
-                                        "device.async.token", []> {
-  let summary = "device async token type";
-  let description = [{
-    `nvgpu.device.async.token` is a type returned by an asynchronous operation
-    that runs on the GPU (device). It is used to establish an SSA-based link
-    between the async operation (e.g. DeviceAsyncCopy) and operations that
-    group or synchronize the async operations (e.g. DeviceAsyncCreateGroupOp,
-    DeviceAsyncWaitOp).
-  }];
-}
-
-def NVGPU_MBarrierGroup : NVGPU_Type<"MBarrierGroup", "mbarrier.group", []> {
-  let summary = "mbarrier barrier type";
-  let description = [{
-    This is the type for one or more mbarrier object in shared memory that is 
-    used to synchronize a variable number of threads.
-
-    If `num_barriers` is not set, the number of mbarrier objects is 1.
-
-    A mbarrier object is 64 bit with 8 byte alignment. The mbarrier object 
-    can be initiated and invalidated.
-
-    [See for more details in PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#size-and-alignment-of-mbarrier-object)
-  }];    
-  let parameters = (ins "Attribute":$memorySpace, DefaultValuedParameter<"unsigned", "1">:$num_barriers);
-  let assemblyFormat = "`<` struct(params) `>`";
-  let builders = [
-    TypeBuilder<(ins "Attribute":$memorySpace), [{
-      return $_get($_ctxt, memorySpace, 1);
-    }]>
-  ];
-}
-
-def NVGPU_MBarrierToken : NVGPU_Type<"MBarrierToken", "mbarrier.token", []> { }
-
-// https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-map
-def NVGPU_TensorMapDescriptor : NVGPU_Type<"TensorMapDescriptor", "tensormap.descriptor", []> {
-  let summary = "TensorMap descriptor";
-  let parameters = (ins "MemRefType":$tensor,
-                        EnumParameter<TensorMapSwizzleKind>:$swizzle,
-                        EnumParameter<TensorMapL2PromoKind>:$l2promo,
-                        EnumParameter<TensorMapOOBKind>:$oob,
-                        EnumParameter<TensorMapInterleaveKind>:$interleave);
-  let description = [{
-    `nvgpu.tma.descriptor` is a type that represents a TMA descriptor. It is 
-    128-byte object either in constant space or kernel paramater.    
-  }];
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
-def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "warpgroup.descriptor", []> {
-  let summary = "Warpgroup matrix descriptor type";
-  let description = [{
-  The descriptor specifies the properties of the matrix in shared memory that 
-  is a multiplicand in the matrix multiply and accumulate operation. 
-  
-  The descriptor is a 64-bit value contained in a register with the following:
-  ```
-  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-  |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
-  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-  |  14bits |2bits|   14bits  |2bits|   14bits  |2bits|3bits|   10bits  |2bits|
-  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-  | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
-  +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-  ```
-   
-  [See for more details in PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor) 
-  
-  }];  
-  let parameters = (ins "MemRefType":$tensor);
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
-def NVGPU_WarpgroupAccumulator : NVGPU_Type<"WarpgroupAccumulator", "warpgroup.accumulator", []> {
-  let parameters = (ins "VectorType":$fragmented);
-  let assemblyFormat = "`<` struct(params) `>`";
-  let description = [{
-    This type represents the result matrix obtained from `nvgpu.warpgroup.mma`. 
-    The `$fragmented` type signifies the distributed or fragmented result 
-    vector that is collectively owned by all the threads in the warp-group 
-    that executed `nvgpu.warpgroup.mma`.
-    [See the details of register fragment layout for accumulator matrix D]
-    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
-  }];
-}
-
-//===----------------------------------------------------------------------===//
-// NVGPU Op Definitions
-//===----------------------------------------------------------------------===//
-
-class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
-  Op<NVGPU_Dialect, mnemonic, traits> {}
-
-def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", [
-                                MemoryEffects<[MemRead]>,
-                                PredOpTrait<"srcMemref and res have same element type",
-                                            TCresVTEtIsSameAsOp<0, 0>>]> {
-  let description = [{
-    The `nvgpu.ldmatrix` op represents loading a matrix fragment from
-    memory to registers. The source and result type must be compatible
-    with lowering to the `nvvm.ldmatrix` instruction. This op represents
-    the distributed version of a `vector.transfer_read` as an intermediate
-    step between lowering from `vector.transfer_read` to `nvvm.ldmatrix`.
-
-    This operation is meant to follow the semantic of described here:
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
-
-    Example:
-    ```mlir
-    %0 = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false} :
-      memref<?x?xf16, 3> -> vector<4x2xf16>
-    ```
-  }];
-
-  let arguments = (ins Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$srcMemref,
-                           Variadic<Index>:$indices, BoolAttr:$transpose,
-                           I32Attr:$numTiles);
-  let results = (outs AnyVectorOfNonZeroRank:$res);
-  let assemblyFormat = [{
-    $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
-  }];
-
-  let hasVerifier = 1;
-}
-
-class NVGPU_MmaSyncOp<string mnemonic> :
-        NVGPU_Op<mnemonic,  [Pure,
-                             PredOpTrait<"matrixA and matrixB have same element type",
-                                         TCopVTEtIsSameAs<0, 1>>]> {
-  code extraBaseClassDeclaration = [{
-    std::array<int64_t, 3> getMmaShapeAsArray() {
-      ArrayAttr mmaShape = this->getMmaShape();
-      assert(mmaShape.size() == 3 && "mmaShape should be three integers");
-      return {::llvm::cast<IntegerAttr>(mmaShape[0]).getInt(),
-              ::llvm::cast<IntegerAttr>(mmaShape[1]).getInt(),
-              ::llvm::cast<IntegerAttr>(mmaShape[2]).getInt()};
-    }
-  }];
-
-  let hasVerifier = 1;
-}
-
-def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
-  let description = [{
-    The `nvgpu.mma.sync` op represents the warp-level matrix-multiply-and-
-    accumulate (mma) operation that is compatible with `nvvm.mma.sync`.
-    The operands and results vector sizes are thread-level onwership to
-    the warp-level mma operation shape. `mmaShape` attribute holds the
-    warp-level matrix-multiply shape.
-
-    The `nvgpu.mma.sync` op serves as an intermediate point between lowering from
-    `vector.contract` to `nvvm.mma.sync`.
-
-    This operation is meant to follow the semantic of described here:
-      https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma
-
-    Example:
-
-    ```mlir
-    %res = nvgpu.mma.sync (%matrixA, %matrixB, %matrixC) {mmaShape = [16, 8, 16]} :
-        (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
-    ```
-  }];
-  let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
-                       AnyVectorOfNonZeroRank:$matrixB,
-                       AnyVectorOfNonZeroRank:$matrixC,
-                       I64ArrayAttr:$mmaShape,
-                       OptionalAttr<UnitAttr>:$tf32Enabled);
-
-  let results = (outs AnyVectorOfNonZeroRank:$res);
-
-  let builders = [
-    OpBuilder<(ins "Value":$matrixA,
-                   "Value":$matrixB,
-                   "Value":$matrixC,
-                   "ArrayAttr":$mmaShape)>,
-    OpBuilder<(ins "Value":$matrixA,
-                   "Value":$matrixB,
-                   "Value":$matrixC,
-                   "ArrayRef<int64_t>":$mmaShape,
-                   CArg<"bool", "false">:$tf32Enabled)>
-  ];
-
-  let assemblyFormat = [{
-    `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
-    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
-  }];
-
-  let extraClassDeclaration = extraBaseClassDeclaration;
-}
-
-def NVGPU_MmaSparseSyncMetadataType : FixedVectorOfLengthAndType<[2], [I16]>,
-                        BuildableType<"::mlir::VectorType::get("
-                          "{2},$_builder.getI16Type())">;
-
-def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
-  let description = [{
-  The `nvgu.mma.sp.sync` operation performs a warp-distributed MMA operation
-  where operand A is "structured sparse". In this case, the `matrixA` operand
-  represents the (warp-distributed) non-zero values of operand A, and the
-  `sparse_metadata` operand provides the indices.
-
-  The full description of the sparsity storage format and distribution scheme is
-  described in the PTX docs. This operation is meant to follow the semantic
-  described in the PTX documentation here:
-  https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma
-
-  The way the indices are distributed among the threads in a warp is controlled
-  by the optional `sparsity_selector` operand, which is `0` by default. For
-  more information, please consult the PTX documentation linked above.
-
-  Example (targetingthe f16 16x8x32 `mma.sp` PTX instruction):
-
-  ```mlir
-  nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
-    (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
-  ```
-  }];
-
-  let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
-                       AnyVectorOfNonZeroRank:$matrixB,
-                       AnyVectorOfNonZeroRank:$matrixC,
-                       NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
-                       I64ArrayAttr:$mmaShape,
-                       DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
-                       OptionalAttr<UnitAttr>:$tf32Enabled
-                       );
-
-  let results = (outs AnyVectorOfNonZeroRank:$res);
-
-  let builders = [
-    OpBuilder<(ins "Value":$matrixA,
-                   "Value":$matrixB,
-                   "Value":$matrixC,
-                   "Value":$sparseMetadata,
-                   "ArrayRef<int64_t>":$mmaShape)>
-  ];
-
-  let assemblyFormat = [{
-    `(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
-    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
-  }];
-
-  let extraClassDeclaration = extraBaseClassDeclaration;
-}
-
-def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
-                                       AttrSizedOperandSegments]> {
-  let summary = "device-side asynchronous copy";
-  let description = [{
-    The `nvgpu.device_async_copy` op initiates an asynchronous copy operation of
-    elements from source (global memory) to the destination (shared memory)
-    without blocking the thread. The async copy is added to a group.
-
-    This op is meant to be used with `nvgpu.device_async_create_group` and
-    `nvgpu.device_async_wait` to synchronize copies as explained in those ops
-    descriptions.
-
-    `bypassL1` attribute is hint to the hardware to bypass the L1 cache during
-    async copy, this hint may be ignored by the hardware.
-
-    `dstElements` attribute is the total number of elements written to
-    destination (shared memory).
-
-    `srcElements` argument is the total number of elements read from
-    source (global memory).
-
-    `srcElements` is an optional argument and when present the op only reads
-    `srcElements` number of elements from the source (global memory) and zero fills
-    the rest of the elements in the destination (shared memory).
-
-    In order to do a copy and wait for the result we need the following
-    combination:
-    ```
-    // copy 1.
-    %cp1 = nvgpu.device_async_copy %A[%c0], %B[%c0], 4 :memref<16xf32> to memref<16xf32, 3>
-    // copy 2.
-    %cp2 = nvgpu.device_async_copy %C[%c0], %D[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
-    // group 1 contains copy 1 and copy 2.
-    %token1 = nvgpu.device_async_create_group %cp1, %cp2
-    // copy 3.
-    %cp3 = nvgpu.device_async_copy %E[%c0], %F[%c0], 4 : memref<16xf32> to memref<16xf32, 3>
-    // group 2 contains copy 3.
-    %token2 = nvgpu.device_async_create_group %cp3
-    // after the wait copy 1 and copy 2 are complete.
-    nvgpu.device_async_wait %token1
-    // after the wait copy 3 is complete.
-    nvgpu.device_async_wait %token2
-    ```
-
-    Example:
-
-    ```mlir
-    %0 = nvgpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 4 :
-      memref<4x5xf32> to memref<2x7x5xf32, 3>
-    ```
-  }];
-  let results = (outs NVGPU_DeviceAsyncToken:$asyncToken);
-  let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$dst,
-                       Variadic<Index>:$dstIndices,
-                       Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
-                       Variadic<Index>:$srcIndices,
-                       IndexAttr:$dstElements,
-                       Optional<Index>:$srcElements,
-                       OptionalAttr<UnitAttr>:$bypassL1);
-  let assemblyFormat = [{
-    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements (`,` $srcElements^)?
-      attr-dict `:` type($src) `to` type($dst)
-  }];
-  let hasVerifier = 1;
-}
-
-def NVGPU_DeviceAsyncCreateGroupOp : NVGPU_Op<"device_async_create_group", []> {
-  let summary = "device side asynchronous create group operation";
-  let description = [{
-    The `nvgpu.device_async_create_group` op creates a group of memory accesses
-    containing all the pending `device_async_copy` operations associated with
-    argument tokens. Each token can only be part of one group.
-
-    It returns a token that can be use to wait until the group fully completes.
-
-    This is meant to be used with `nvgpu.device_async_wait` to synchronize copies
-    as explained in those ops descriptions.
-
-    Groups are executed in the order they are created.
-
-    Example:
-
-    ```mlir
-    %0 = nvgpu.device_async_create_group
-  ```
-  }];
-  let results = (outs NVGPU_DeviceAsyncToken:$asyncToken);
-  let arguments = (ins Variadic<NVGPU_DeviceAsyncToken>:$inputTokens);
-  let assemblyFormat = [{
-    $inputTokens attr-dict
-  }];
-}
-
-def NVGPU_DeviceAsyncWaitOp : NVGPU_Op<"device_async_wait", []> {
-  let summary = "Wait for async gpu ops to complete.";
-  let description = [{
-    The `nvgpu.device_async_wait` op will block the execution thread until the group
-    associated with the source token is fully completed.
-
-    The optional `$numGroups` attribute gives an upper bound of the number of
-    groups uncompleted when the wait can unblock the thread. For example,  if
-    16 async groups are pushe and `$numGroups` is set to 12, then the thread
-    will unblock when 12 groups or fewer are in flight (4 groups have
-    completed).
-
-    Example:
-
-    ```mlir
-    nvgpu.device_async_wait %0
-    ```
-  }];
-  let arguments = (ins NVGPU_DeviceAsyncToken:$asyncDependencies,
-                       OptionalAttr<I32Attr>:$numGroups);
-  let assemblyFormat = [{
-    $asyncDependencies attr-dict
-  }];
-}
-
-def NVGPU_MBarrierCreateOp : NVGPU_Op<"mbarrier.create", []> {
-  let summary = "Creates a `nvgpu.mbarrier` object.";
-  let description = [{
-    The Op generates one or more `mbarrier` object, which is a barrier created in 
-    shared memory and supports various synchronization behaviors for threads.
-
-    The `mbarrier` object has the following type and alignment requirements:
-      Type: .b64, Alignment: 8, Memory space: .shared
-    
-    Example:
-    ```mlir
-      %barrier = nvgpu.mbarrier.create -> !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
-    ```
-    }];
-  let arguments = (ins);
-  let results = (outs NVGPU_MBarrierGroup:$barriers);
-  let assemblyFormat = [{
-     attr-dict `->` type($barriers)
-  }];
-}
-
-def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier.init", []> {
-  let summary = "Initialize the `nvgpu.mbarrier`.";
-  let description = [{
-    The Op initializes the `mbarrier` object with the given number of threads.
-
-    Example:
-    ```mlir
-      %num_threads = gpu.block_dim x
-      %barrier = nvgpu.mbarrier.create -> !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
-      nvgpu.mbarrier.init %barrier, %num_threads : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
-    ```
-  }];
-  let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$count, Index:$mbarId, Optional<I1>:$predicate);
-  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)";
-}
-
-def NVGPU_MBarrierTestWaitOp : NV...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/129846


More information about the Mlir-commits mailing list