[Mlir-commits] [mlir] 03bb10d - [MLIR][XeGPU] Add dpas, atomic, and named barrier ops (#88973)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 24 12:29:15 PDT 2024


Author: Chao Chen
Date: 2024-04-24T14:29:11-05:00
New Revision: 03bb10dfb3725ec2c31fb66deede96d066f2b49a

URL: https://github.com/llvm/llvm-project/commit/03bb10dfb3725ec2c31fb66deede96d066f2b49a
DIFF: https://github.com/llvm/llvm-project/commit/03bb10dfb3725ec2c31fb66deede96d066f2b49a.diff

LOG: [MLIR][XeGPU] Add dpas, atomic, and named barrier ops (#88973)

---------

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/test/Dialect/XeGPU/XeGPUOps.mlir
    mlir/test/Dialect/XeGPU/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
index f1740e9ed929a6..3f8cac4dc07c3c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
@@ -2,12 +2,12 @@ add_mlir_dialect(XeGPU xegpu)
 add_mlir_doc(XeGPU XeGPU Dialects/ -gen-dialect-doc -dialect=xegpu)
 
 set(LLVM_TARGET_DEFINITIONS XeGPU.td)
-mlir_tablegen(XeGPUAttrs.h.inc -gen-attrdef-decls)
-mlir_tablegen(XeGPUAttrs.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(XeGPUAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=xegpu)
+mlir_tablegen(XeGPUAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xegpu)
 add_public_tablegen_target(MLIRXeGPUAttrsIncGen)
 add_dependencies(mlir-headers MLIRXeGPUAttrsIncGen)
 
-set(LLVM_TARGET_DEFINITIONS XeGPU.td)
+set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
 mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
 mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRXeGPUEnumsIncGen)

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index eca9255ff3974b..7ac0cf77fe59bb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_XEGPU_IR_XEGPU_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -19,7 +20,7 @@
 
 namespace mlir {
 namespace xegpu {
-// placeholder
+class TensorDescType;
 } // namespace xegpu
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 6579d07ec26215..f3ca09a6a68ea8 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
 
 include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
+include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/EnumAttr.td"
 
 class XeGPUAttr<string name, string attrMnemonic, list<Trait> traits = [],
@@ -98,4 +99,21 @@ def XeGPU_CacheHintAttr
     let assemblyFormat = "`<` $value `>`";
 }
 
+def XeGPU_FenceScopeWorkgroup: I32EnumAttrCase<"Workgroup", 0, "workgroup">;
+def XeGPU_FenceScopeGPU: I32EnumAttrCase<"GPU", 1, "gpu">;
+def XeGPU_FenceScope: I32EnumAttr<"FenceScope",
+      "The enumeration for the scope of fence operation.",
+      [XeGPU_FenceScopeWorkgroup, XeGPU_FenceScopeGPU]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
+}
+
+def XeGPU_FenceScopeAttr:
+  EnumAttr<XeGPU_Dialect, XeGPU_FenceScope, "fence_scope"> {
+    let summary = [{Describes the scope of a fence.}];
+    let description = [{`workgroup` means that the scope is within each work group.
+                        `gpu` means the scope is across work groups within the GPU.}];
+    let assemblyFormat = "$value";
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
\ No newline at end of file

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index c2f09319c790e0..765f218f95d269 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -17,12 +17,14 @@ def XeGPU_Dialect : Dialect {
     let summary = "The XeGPU dialect that models Intel GPU's ISA";
     let description = [{
       The XeGPU dialect models Intel Xe ISA semantics but works at vector and
-      TensorDesc data type. It provides 1:1 mappings to match Xe instructions 
+      TensorDesc data type. It provides 1:1 mappings to match Xe instructions
       like DPAS and 2D block load. The matrix size being processed at this level
       exactly matches the hardware instructions or the intrinsic supported by
       the lower-level GPU compiler.
     }];
 
+    let dependentDialects = ["arith::ArithDialect"];
+
     let useDefaultTypePrinterParser = true;
     let useDefaultAttributePrinterParser = true;
 }

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c6f7f83441b96c..88f2e1acfeeb58 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -9,7 +9,7 @@
 #ifndef MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
 #define MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
 
-include "mlir/IR/AttrTypeBase.td"
+include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
 include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
 include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
@@ -36,7 +36,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
 
     static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
                                      ::mlir::OperationState &result) {
-      if (mlir::succeeded(parser.parseLess())) {
+      if (mlir::succeeded(parser.parseOptionalLess())) {
         if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
           return failure();
       }
@@ -254,7 +254,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
     a block of data from memory to register. It takes a set of optional cache
     hints for each level of cache, L1, L2 and L3. If hardware does not have a
     correspoding cache, Corresponding cache hint attribute will be masked.
-    vnni transform is an hardware feature for Intel GPU, which is used to
+    VNNI transformation is an hardware feature for Intel GPU, which is used to
     do data packing during the load for B operand of matrix operation, if
     the bit width of the data type is less then 32 bits, e.g., fp16. And
     transpose is another Intel hardware feature, which will do transpose
@@ -425,10 +425,6 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
     %0 = memref.alloc() : memref<1024xf32>
     %1 = xegpu.create_tdesc %0[0, 4, 8, 12] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
     ```
-
-
-
-
   }];
 
   let arguments = (ins XeGPU_BaseAddrType: $source,
@@ -663,4 +659,153 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
   }];
 }
 
+def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]> {
+  let summary = "It performs mma computation";
+
+  let description = [{DPAS performs matrix multiplication on matrix A of `mxk`
+    size and matrix B of `kxn` size, and optionally accumulates into matrix C of `mxn`
+    size, producing a result of the same size, where `m=8`, `n=16` and
+    `k=8 * 32/bit_width_of_elem_type`. So for the fp16 data type, the matrices are
+    `A: vector<8x16xf16>`, `B: vector<16x16xf16>`, and `C/D: vector<8x16xf32>`.
+    Besides the matrix size requirements, DPAS also requires A and B to be loaded
+    with the required data layout. Specifically, VNNI layout is required for the
+    B operand. It is achieved by setting `vnni_axis = 0` on the corresponding
+    `load_nd` operation. To keep both operands as 3D vectors, operand A is loaded
+    by setting `vnni_axis = 1`, which does not change its physical layout in
+    registers. Due to the VNNI transformation, A and B operands are represented
+    as 3D vectors, with the last dimension representing the VNNI factor, which is
+    computed as `32/bit_width_of_elem_type`. Therefore, `A: vector<8x16xf16>` is
+    represented as `A: vector<8x8x2xf16>`, and `B: vector<16x16xf16>` as `B: vector<8x16x2xf16>`.
+
+    Note: on PVC, the hardware can perform load with VNNI transformation when data
+          element type is 16-bit or lower precision, taking 2 or 4 elements from
+          the first dimension and inserting them into the newly added innermost dimension.
+  }];
+
+  let arguments = (ins
+    XeGPU_DpasOpType : $lhs,
+    XeGPU_DpasOpType : $rhs,
+    Optional<XeGPU_Vector2DType>: $acc);
+  let results = (outs XeGPU_Vector2DType: $result);
+
+  let extraClassDeclaration = [{
+    VectorType getLhsType() {
+      return getLhs().getType();
+    }
+
+    VectorType getRhsType() {
+      return getRhs().getType();
+    }
+
+    VectorType getAccType() {
+      if (getAcc())
+        return getAcc().getType();
+      return {};
+    }
+
+    VectorType getResultType() {
+      return getResult().getType();
+    }
+  }];
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs (`,` $acc^)? attr-dict `:` type($lhs)`,` type($rhs) (`,` type($acc)^)?  `->` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", [
+      MemoryEffects<[MemRead, MemWrite]>, // Not Pure: the op updates memory.
+      AllElementTypesMatch<["tensorDesc", "value", "result"]>,
+      AllShapesMatch<["tensorDesc", "mask", "value", "result"]>]> {
+  let summary = "Atomic read-modify-write operation on the TensorDesc.";
+  let description = [{
+    The `xegpu.atomic_rmw` operation provides a way to perform a read-modify-write
+    operation on the region described by the `TensorDesc` free from data races. The
+    `kind` enumeration specifies the modification to be performed. The `mask` operand
+    has the same shape as the `TensorDesc`, and is used to enable or disable specific
+    data points of the `TensorDesc`. The `value` operand represents the new value to
+    be applied during the modification.
+  }];
+
+  let arguments = (ins
+    AtomicRMWKindAttr:$kind,
+    XeGPU_TensorDesc:$tensorDesc,
+    XeGPU_MaskType:$mask,
+    XeGPU_ValueType:$value);
+
+  let results = (outs XeGPU_ValueType:$result);
+
+  let assemblyFormat = [{
+    $kind $tensorDesc `,` $mask `,` $value attr-dict `:`
+    type($tensorDesc) `,` type($mask) `,` type($value) `->` type($result)
+  }];
+}
+
+def XeGPU_AllocNbarrierOp: XeGPU_Op<"alloc_nbarrier", []> {
+  let summary = "It allocates a set of named barriers.";
+  let description = [{AllocNbarrier creates a set of named barriers as
+    specified by `nbarrier_num`. Named barriers are workgroup-level resources
+    and are shared by all threads in the workgroup. For example, there are
+    up to 32 barriers (range 0-31) for each XeCore on PVC. A typical use case
+    is that a workgroup is partitioned into N subgroups of threads (N <= 32),
+    and each subgroup coordinates its work with a separate barrier, with ids
+    ranging from 0 to N-1.}];
+  let arguments = (ins I64Attr: $nbarrier_num);
+  let assemblyFormat = "$nbarrier_num attr-dict";
+}
+
+def XeGPU_InitNbarrierOp: XeGPU_Op<"init_nbarrier", []> {
+  let summary = "It assigns a named barrier to the current thread.";
+  let description = [{InitNbarrierOp assigns the named barrier with the specified
+      barrier ID (0-31) to the current thread. Multiple threads may bind to the
+      same named barrier, and `participant_thread_num` specifies the total
+      number of threads associated with the nbarrier. It returns an object of
+      NbarrierType representing the barrier.}];
+
+  let arguments = (ins I8: $nbarrier_id,
+                       I8: $participant_thread_num);
+  let results = (outs XeGPU_Nbarrier: $result);
+  let assemblyFormat = [{
+    $nbarrier_id `,` $participant_thread_num attr-dict `:`
+    type($nbarrier_id) `,` type($participant_thread_num) `->` qualified(type($result))
+  }];
+}
+
+def XeGPU_NbarrierArriveOp: XeGPU_Op<"nbarrier_arrive", []> {
+  let summary = "It signals the arrival at the named barrier.";
+  let description = [{NbarrierArriveOp signals the hardware (or other threads)
+    that the current thread has produced its data for the consumer threads. Once
+    the hardware has been signalled by `participant_thread_num` threads for the named
+    barrier, it notifies the threads waiting on the named barrier to continue their work.}];
+
+  let arguments = (ins XeGPU_Nbarrier: $nbarrier);
+  let assemblyFormat = [{ $nbarrier attr-dict `:` qualified(type($nbarrier))}];
+}
+
+def XeGPU_NbarrierWaitOp: XeGPU_Op<"nbarrier_wait", []> {
+  let summary = "It waits for a named barrier.";
+  let description = [{NbarrierWaitOp signals the hardware which named barrier
+    the current thread is waiting for, so that the thread is notified when the
+    named barrier is completed.}];
+  let arguments = (ins XeGPU_Nbarrier: $nbarrier);
+  let assemblyFormat = [{ $nbarrier attr-dict `:` qualified(type($nbarrier)) }];
+}
+
+def XeGPU_FenceOp: XeGPU_Op<"fence", []> {
+  let summary = "It synchronizes memory accesses.";
+  let description = [{It synchronizes memory accesses between a write and the
+    following read or write.
+    1. `memory_kind` describes the memory kind. "global" means the global memory,
+        "slm" means the shared local memory.
+    2. `fence_scope` describes the scope of fence. "workgroup" means that the scope would be
+        within each workgroup. "gpu" means the scope would be across workgroups within the GPU.
+  }];
+  let arguments = (ins XeGPU_MemoryScopeAttr: $memory_kind,
+                       XeGPU_FenceScopeAttr: $fence_scope);
+  let assemblyFormat = [{`memory_kind` `=` `` $memory_kind `,` `fence_scope` `=` `` $fence_scope attr-dict}];
+  let extraClassDeclaration = extraBaseClassDeclaration;
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 4cd4e5411653c1..bab0e4afb1e5ed 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -151,4 +151,15 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
 
 }
 
+
+def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
+  let summary = "!xegpu.nbarrier: a custom XeGPU type representing a named barrier.";
+
+  let extraClassDeclaration = [{
+    static NbarrierType get(mlir::MLIRContext *context) {
+      return Base::get(context);
+    }
+  }];
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 23c5749c2309de..22959224d56c2f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -406,6 +406,28 @@ LogicalResult StoreScatterOp::verify() {
 
   return success();
 }
+//===----------------------------------------------------------------------===//
+// XeGPU_DpasOp
+//===----------------------------------------------------------------------===//
+LogicalResult DpasOp::verify() {
+  int64_t lhsRank = getLhsType().getRank();
+  int64_t rhsRank = getRhsType().getRank();
+  if (lhsRank != rhsRank || lhsRank != 3)
+    return emitOpError(
+        "lhs and rhs rank does not match for dpas op, or their rank is not 3.");
+  if (getAcc() && getAccType() != getResultType())
+    return emitOpError("Accumulator and Result for dpas op should have the "
+                       "same type (both shape and element type).");
+  auto lhsShape = getLhsType().getShape();    // [M, K / vnni, vnni]
+  auto rhsShape = getRhsType().getShape();    // [K / vnni, N, vnni]
+  auto resShape = getResultType().getShape(); // must be [M, N]
+  if (lhsShape[1] != rhsShape[0] || lhsShape[2] != rhsShape[2])
+    return emitOpError("K-dimension or vnni-factor mismatch.");
+  if (resShape.size() != 2 || resShape[0] != lhsShape[0] ||
+      resShape[1] != rhsShape[1])
+    return emitOpError("M or N dimension mismatch between operands and result.");
+  return success();
+}
 
 } // namespace xegpu
 } // namespace mlir

diff  --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index f0945c79a94ac3..00d32d2a2ee943 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -80,7 +80,7 @@ gpu.func @test_prefetch_vc(%src: ui64) {
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
   %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
   // CHECK: xegpu.prefetch %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>> 
+  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
   gpu.return
 }
 
@@ -121,4 +121,59 @@ gpu.func @test_create_update_tdesc_vc(%src: ui64) {
   gpu.return
 }
 
-}
\ No newline at end of file
+// CHECK: gpu.func @test_dpas_vc(%[[arg0:.*]]: vector<8x8x2xf16>, %[[arg1:.*]]: vector<8x16x2xf16>)
+gpu.func @test_dpas_vc(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
+  // CHECK: %{{.*}} = xegpu.dpas %[[arg0]], %[[arg1]] : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+  %1 = xegpu.dpas %a, %b: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_atomic_rmw(%[[arg0:.*]]: ui64, %[[arg1:.*]]: vector<16xf32>, %[[arg2:.*]]: vector<16xi1>)
+gpu.func @test_atomic_rmw(%src: ui64, %value : vector<16xf32>, %mask : vector<16xi1>) {
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
+  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
+  //CHECK: %[[R1:.*]] = xegpu.atomic_rmw addf %[[R0]], %[[arg2]], %[[arg1]] : <16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
+  xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @alloc_nbarrier({{.*}}) {
+gpu.func @alloc_nbarrier() {
+  // CHECK: xegpu.alloc_nbarrier
+  xegpu.alloc_nbarrier 8
+  gpu.return
+}
+
+// CHECK: gpu.func @init_nbarrier({{.*}}) {
+gpu.func @init_nbarrier() {
+  //CHECK: %[[c1:.*]] = arith.constant 1 : i8
+  //CHECK: %[[c16:.*]] = arith.constant 16 : i8
+  %nbarrier_id = arith.constant 1 : i8
+  %threads_count = arith.constant 16 : i8
+  //CHECK: xegpu.init_nbarrier %[[c1]], %[[c16]] : i8, i8 -> !xegpu.nbarrier
+  %nbarrier = xegpu.init_nbarrier %nbarrier_id, %threads_count : i8, i8 -> !xegpu.nbarrier
+  gpu.return
+}
+
+// CHECK: gpu.func @nbarrier_arrive(%[[arg0:.*]]: !xegpu.nbarrier) {
+gpu.func @nbarrier_arrive(%nbarrier : !xegpu.nbarrier) {
+  //CHECK: xegpu.nbarrier_arrive %[[arg0]] : !xegpu.nbarrier
+  xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier
+  gpu.return
+}
+
+// CHECK: gpu.func @nbarrier_wait(%[[arg0:.*]]: !xegpu.nbarrier) {
+gpu.func @nbarrier_wait(%nbarrier : !xegpu.nbarrier) {
+  //CHECK: xegpu.nbarrier_wait %[[arg0]] : !xegpu.nbarrier
+  xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @fence({{.*}}) {
+gpu.func @fence() {
+  //CHECK: xegpu.fence memory_kind = global, fence_scope = workgroup
+  xegpu.fence memory_kind = global, fence_scope = workgroup
+  gpu.return
+}
+
+}

diff  --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 5e29361ec69087..7819ad60b97d92 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -156,4 +156,32 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
   xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
           !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1>
   return
+}
+
+// -----
+func.func @test_dpas_vc_1(%a : vector<8x4x2xf16>, %b: vector<8x16x2xf16>) {
+  // expected-error at +1 {{K-dimension or vnni-factor mismatch}}
+  %1 = xegpu.dpas %a, %b : vector<8x4x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+  return
+}
+
+// -----
+func.func @test_dpas_vc_2(%a : vector<8x16xf16>, %b: vector<8x16x2xf16>) {
+  // expected-error at +1 {{lhs and rhs rank does not match for dpas op, or their rank is not 3}}
+  %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+  return
+}
+
+// -----
+func.func @test_dpas_vc_3(%a : vector<8x16xf16>, %b: vector<16x16xf16>) {
+  // expected-error at +1 {{lhs and rhs rank does not match for dpas op, or their rank is not 3}}
+  %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return
+}
+
+// -----
+func.func @test_dpas_vc_4(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>, %c : vector<8x16xf16>) {
+  // expected-error at +1 {{Accumulator and Result for dpas op should have the same type}}
+  %1 = xegpu.dpas %a, %b, %c : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf16> -> vector<8x16xf32>
+  return
 }
\ No newline at end of file


        


More information about the Mlir-commits mailing list