[Mlir-commits] [mlir] [MLIR] Add XeGPU dialect for Intel GPU (PR #78483)
Chao Chen
llvmlistbot at llvm.org
Thu Jan 18 08:15:43 PST 2024
https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/78483
>From 35440b9b0751dec934049aed9257ae2bbcfabe13 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 17 Jan 2024 17:45:24 +0000
Subject: [PATCH 1/2] add XeGPU dialect definition
---
mlir/include/mlir/Dialect/CMakeLists.txt | 1 +
.../include/mlir/Dialect/XeGPU/CMakeLists.txt | 1 +
.../mlir/Dialect/XeGPU/IR/CMakeLists.txt | 14 +
mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h | 52 +
mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.td | 14 +
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 150 ++
.../mlir/Dialect/XeGPU/IR/XeGPUDialect.td | 46 +
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 505 +++++
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 170 ++
mlir/include/mlir/InitAllDialects.h | 4 +-
mlir/lib/Dialect/CMakeLists.txt | 1 +
mlir/lib/Dialect/XeGPU/CMakeLists.txt | 1 +
mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt | 15 +
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 385 ++++
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 1929 +++++++++++++++++
mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir | 110 +
mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir | 43 +
mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir | 38 +
mlir/test/Dialect/XeGPU/IR/barrier_ops.mlir | 54 +
.../Dialect/XeGPU/IR/create_nd_tdesc.mlir | 111 +
.../Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir | 115 +
mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir | 11 +
.../Dialect/XeGPU/IR/create_tdesc_vc.mlir | 51 +
mlir/test/Dialect/XeGPU/IR/invalid_vc.mlir | 70 +
.../test/Dialect/XeGPU/IR/load_gather_vc.mlir | 50 +
mlir/test/Dialect/XeGPU/IR/load_nd.mlir | 164 ++
mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir | 69 +
.../test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir | 62 +
mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir | 71 +
.../test/Dialect/XeGPU/IR/simple_gemm_vc.mlir | 65 +
mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir | 83 +
mlir/test/Dialect/XeGPU/IR/store_scatter.mlir | 29 +
.../Dialect/XeGPU/IR/store_scatter_vc.mlir | 29 +
.../Dialect/XeGPU/IR/update_nd_offset.mlir | 27 +
.../Dialect/XeGPU/IR/update_offset_vc.mlir | 29 +
35 files changed, 4568 insertions(+), 1 deletion(-)
create mode 100644 mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
create mode 100644 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.td
create mode 100644 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
create mode 100644 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
create mode 100644 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
create mode 100644 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
create mode 100644 mlir/lib/Dialect/XeGPU/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
create mode 100644 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
create mode 100644 mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/barrier_ops.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/create_tdesc_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/invalid_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/load_gather_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/load_nd.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/store_scatter.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/store_scatter_vc.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/update_nd_offset.mlir
create mode 100644 mlir/test/Dialect/XeGPU/IR/update_offset_vc.mlir
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 1c4569ecfa5848..e0eb421291ded7 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -39,3 +39,4 @@ add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
add_subdirectory(X86Vector)
+add_subdirectory(XeGPU)
diff --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..f1740e9ed929a6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect(XeGPU xegpu)
+add_mlir_doc(XeGPU XeGPU Dialects/ -gen-dialect-doc -dialect=xegpu)
+
+set(LLVM_TARGET_DEFINITIONS XeGPU.td)
+mlir_tablegen(XeGPUAttrs.h.inc -gen-attrdef-decls)
+mlir_tablegen(XeGPUAttrs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRXeGPUAttrsIncGen)
+add_dependencies(mlir-headers MLIRXeGPUAttrsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS XeGPU.td)
+mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
+mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
+add_dependencies(mlir-headers MLIRXeGPUEnumsIncGen)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
new file mode 100644
index 00000000000000..a05e046a0e0c0b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -0,0 +1,52 @@
+//===- XeGPU.h - MLIR dialect for XeGPU -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_IR_XEGPU_H
+#define MLIR_DIALECT_XEGPU_IR_XEGPU_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+
+namespace mlir {
+
+/// Return the list of Range (i.e. offset, size, stride). Each Range
+/// entry contains either the dynamic value or a ConstantIndexOp constructed
+/// with `b` at location `loc`.
+SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
+ OpBuilder &b, Location loc);
+
+} // namespace mlir
+
+namespace mlir {
+namespace xegpu {
+
+class TensorDescType;
+
+} // namespace xegpu
+} // namespace mlir
+
+#include "mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc"
+#include "mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc"
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h.inc"
+
+#endif // MLIR_DIALECT_XEGPU_IR_XEGPU_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.td
new file mode 100644
index 00000000000000..232e962870716c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.td
@@ -0,0 +1,14 @@
+//===- XeGPU.td - XeGPU dialect definition ------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_IR_XEGPU_TD
+#define MLIR_DIALECT_XEGPU_IR_XEGPU_TD
+
+include "mlir/Dialect/XeGPU/IR/XeGPUOps.td"
+
+#endif // MLIR_DIALECT_XEGPU_IR_XEGPU_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
new file mode 100644
index 00000000000000..ed3d9bbc772567
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -0,0 +1,150 @@
+//===- XeGPUAttrs.td - XeGPU dialect attributes definition --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
+#define MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
+
+include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
+include "mlir/IR/EnumAttr.td"
+
+class XeGPUAttr<string name, string attrMnemonic, list<Trait> traits = [],
+ string baseCppClass = "::mlir::Attribute">
+ : AttrDef<XeGPU_Dialect, name, traits, baseCppClass> {
+ let mnemonic = attrMnemonic;
+}
+
+def XeGPU_ScatteredAttr : XeGPUAttr<"Scattered", "scattered"> {
+ let summary = "Scattered attribute for scattered read and write operation.";
+ let description = [{An attribute represent scattered read and write operation.
+ It does not (need to) have meaningful input values. The existence of itself
+ implies scattered read/write.}];
+
+ let assemblyFormat = "";
+}
+
+def XeGPU_SgMapAttr: XeGPUAttr<"SubGroupMap", "sg_map"> {
+ let parameters = (ins
+ "mlir::DenseI32ArrayAttr":$wi_layout,
+ "mlir::DenseI32ArrayAttr":$wi_data
+ );
+
+ // In format of #xegpu.sg_map<wi_layout = [2, 4], wi_data = [2, 4]>
+ let assemblyFormat = "`<` struct(params) `>`";
+
+ let genVerifyDecl = true;
+
+ let builders = [
+ AttrBuilder<(ins
+ "llvm::ArrayRef<int32_t>":$wiLayout,
+ "llvm::ArrayRef<int32_t>":$wiData
+ )>
+ ];
+}
+
+def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
+ let parameters = (ins
+ DefaultValuedParameter<"xegpu::MemoryScopeKind", "xegpu::MemoryScopeKind::GLOBAL">: $memory_scope,
+ DefaultValuedParameter<"int", "1">: $array_length,
+ DefaultValuedParameter<"bool", "true">: $boundary_check,
+ OptionalParameter<"xegpu::ScatteredAttr">: $scattered,
+ OptionalParameter<"xegpu::SubGroupMapAttr"> : $map
+ );
+
+ let builders = [
+ AttrBuilder<(ins
+ CArg<"xegpu::MemoryScopeKind", "xegpu::MemoryScopeKind::GLOBAL">:$memory_scope,
+ CArg<"int", "1">:$array_length,
+ CArg<"xegpu::ScatteredAttr", "{}">:$scattered,
+ CArg<"xegpu::SubGroupMapAttr", "{}">:$map
+ )>
+ ];
+
+ let extraClassDeclaration = [{
+ bool hasNonDefaultAttrs();
+ }];
+
+ let hasCustomAssemblyFormat = true;
+}
+
+def ARG_TYPE_VECTOR : I32EnumAttrCase<"VECTOR", 0, "vector">;
+def ARG_TYPE_SCALAR : I32EnumAttrCase<"SCALAR", 1, "scalar">;
+def XeGPU_ArgTypeKind : I32EnumAttr<"ArgTypeKind",
+ "Argument type for Invoke_SIMD op",
+ [ARG_TYPE_VECTOR, ARG_TYPE_SCALAR]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xegpu";
+}
+
+def MODE_SIMT : I32EnumAttrCase<"SIMT", 0, "simt">;
+def MODE_VC : I32EnumAttrCase<"VC", 1, "vc">;
+def XeGPU_ModeKind : I32EnumAttr<"ModeKind",
+ "The Mode an operator runs on",
+ [MODE_SIMT, MODE_VC]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xegpu";
+}
+
+def MEMORY_SCOPE_GLOBAL: I32EnumAttrCase<"GLOBAL", 0, "global">;
+def MEMORY_SCOPE_SHARED: I32EnumAttrCase<"SLM", 1, "slm">;
+def XeGPU_MemoryScopeKind: I32EnumAttr<"MemoryScopeKind",
+ "The scope of the memory the tensor descriptor is created for",
+ [MEMORY_SCOPE_GLOBAL, MEMORY_SCOPE_SHARED]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xegpu";
+}
+
+def CACHE_KIND_CACHED: I32EnumAttrCase<"CACHED", 0, "cached">; // valid for read and write
+def CACHE_KIND_UNCACHED: I32EnumAttrCase<"UNCACHED", 1, "uncached">; // valid for read and write
+def CACHE_KIND_STREAMING: I32EnumAttrCase<"STREAMING", 2, "streaming">; // valid for read only
+def CACHE_KIND_INVALIDATE: I32EnumAttrCase<"READ_INVALIDATE", 3, "read_invalidate">; // valid for read only
+def CACHE_KIND_WRITE_BACK: I32EnumAttrCase<"WRITE_BACK", 4, "write_back">; // valid for write only
+def CACHE_KIND_WRITE_THROUGH: I32EnumAttrCase<"WRITE_THROUGH", 5, "write_through">; // valid for write only
+
+
+
+def XeGPU_CacheKind : I32EnumAttr<"CacheKind", "Cache kind",
+ [CACHE_KIND_CACHED, CACHE_KIND_UNCACHED,
+ CACHE_KIND_STREAMING, CACHE_KIND_INVALIDATE,
+ CACHE_KIND_WRITE_BACK, CACHE_KIND_WRITE_THROUGH]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xegpu";
+}
+
+def XeGPU_ArgTypeAttr : EnumAttr<XeGPU_Dialect, XeGPU_ArgTypeKind, "arg_type_kind">;
+def XeGPU_ModeAttr : EnumAttr<XeGPU_Dialect, XeGPU_ModeKind, "mode_kind">;
+def XeGPU_MemoryScopeAttr : EnumAttr<XeGPU_Dialect, XeGPU_MemoryScopeKind, "memory_scope_kind">;
+def XeGPU_CacheAttr : EnumAttr<XeGPU_Dialect, XeGPU_CacheKind, "cache_kind">;
+
+// RMW kind attribute
+def ATOMIC_RMW_KIND_ADDF : I32EnumAttrCase<"addf", 0>;
+def ATOMIC_RMW_KIND_ADDI : I32EnumAttrCase<"addi", 1>;
+def ATOMIC_RMW_KIND_ASSIGN : I32EnumAttrCase<"assign", 2>;
+def ATOMIC_RMW_KIND_MAXF : I32EnumAttrCase<"maxf", 3>;
+def ATOMIC_RMW_KIND_MAXS : I32EnumAttrCase<"maxs", 4>;
+def ATOMIC_RMW_KIND_MAXU : I32EnumAttrCase<"maxu", 5>;
+def ATOMIC_RMW_KIND_MINF : I32EnumAttrCase<"minf", 6>;
+def ATOMIC_RMW_KIND_MINS : I32EnumAttrCase<"mins", 7>;
+def ATOMIC_RMW_KIND_MINU : I32EnumAttrCase<"minu", 8>;
+def ATOMIC_RMW_KIND_MULF : I32EnumAttrCase<"mulf", 9>;
+def ATOMIC_RMW_KIND_MULI : I32EnumAttrCase<"muli", 10>;
+def ATOMIC_RMW_KIND_ORI : I32EnumAttrCase<"ori", 11>;
+def ATOMIC_RMW_KIND_ANDI : I32EnumAttrCase<"andi", 12>;
+
+def XeGPU_AtomicRMWKind : I32EnumAttr<"AtomicRMWKind",
+ "Operation type for AtomicRMW",
+ [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
+ ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
+ ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+ ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
+ ATOMIC_RMW_KIND_ANDI]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xegpu";
+}
+def XeGPU_AtomicRMWKindAttr : EnumAttr<XeGPU_Dialect, XeGPU_AtomicRMWKind, "atomic_rmw_kind">;
+
+#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
new file mode 100644
index 00000000000000..f85ccb32cc43b0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -0,0 +1,46 @@
+//===- XeGPUDialect.td - XeGPU dialect definition -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_IR_XEGPUDIALECT_TD
+#define MLIR_DIALECT_XEGPU_IR_XEGPUDIALECT_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/ShapedOpInterfaces.td"
+
+def XeGPU_Dialect : Dialect {
+ let name = "xegpu";
+ let cppNamespace = "::mlir::xegpu";
+ let summary = "The XeGPU dialect that models Intel GPU's ISA";
+ let description = [{
+ The XeGPU dialect models Intel Xe ISA semantics but works at vector and
+ TensorDesc data type. It provides 1:1 mappings to match Xe instructions
+ like DPAS and 2D block load. The matrix size being processed at this level
+ exactly matches the hardware instructions or the intrinsic supported by
+ the lower-level GPU compiler.
+ }];
+
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "memref::MemRefDialect"
+ ];
+
+ let useDefaultTypePrinterParser = true;
+ let useDefaultAttributePrinterParser = true;
+}
+
+#endif // MLIR_DIALECT_XEGPU_IR_XEGPUDIALECT_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
new file mode 100644
index 00000000000000..766590f6a3f878
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -0,0 +1,505 @@
+//===- XeGPUOps.td - XeGPU dialect operations definition ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
+#define MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
+
+include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
+include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
+include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
+
+
+// Base class for dialect operations. This operation inherits from the base
+// `Op` class in OpBase.td, and provides:
+// * The parent dialect of the operation.
+// * The mnemonic for the operation, or the name without the dialect prefix.
+// * A list of traits for the operation.
+class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
+ Op<XeGPU_Dialect, mnemonic, traits>;
+
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSegments]> {
+
+ let summary = "create nd tensor descriptor operation";
+ let description = [{
+ The "create_nd_tdesc" operation creates a TensorDescType which represents
+ a sub-view of a 2D memory region (It can be extended to support N-D memory
+ region if needed in future). Elements in the subview are contiguous in each
+ dimension. It encodes the following important information for supporting
+ Intel hardware features:
+
+ * source: an object representing (starting address/pointer of) a 2D memory region.
+ It can be either a 2D memref object, or simply a pointer represented by uint64_t type.
+ * offsets: two index values representing offsets from the "source" at each dimension
+ at which the subview of the target memory will be created. It is encoded via two
+ variables, including "dynamic_offsets" and "static_offsets", such that it can
+ accept various forms, such as operands (e.g., [%c0, %c1]) and attributes (e.g., [2, 4]).
+ * shape: the shape information of the memory region pointed by the "source". It is
+ typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
+ But if "source" is simply a pointer represented as uint64_t type, or a memref
+ type without shape information e.g., memref<?x?xf16>, the shape information has
+ to be explicitly passed via the "dynamic_shape" argument. Currently "dynamic_shape"
+ only accepts operands (e.g., [%c4096, %c4096]), not attributes (e.g., [4096, 4096]).
+ * strides: the strides of the memory region pointed by the "source". Similar to shape,
+ it is typically encoded via the MemRefType of the source too. But if "source" is
+ simply a pointer represented as uint64_t type, or a memref type without shape
+ information e.g., memref<?x?xf16>, the strides information has to be explicitly
+ passed via the "dynamic_strides" argument. And it currently only accepts operands too.
+
+ Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
+ %0 = memref.alloc() : memref<32x24xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.create_nd_tdesc %0[%c0, %c1]: memref<32x24xf32> -> TensorDesc<8x16xf32>
+
+ Example 2 (suppose the tensor shape inferred by the compiler is 8x16):
+ %0 = memref.alloc(%h, %w) : memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.create_nd_tdesc %0[%c0, %c1], [%h, %w], [%w, %c1]: memref<?x?xf32> -> TensorDesc<8x16xf32>
+
+ Example 3 (suppose the tensor shape inferred by the compiler is 8x16):
+ %0 = ... : ui64
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.create_nd_tdesc %0[%c0, %c1], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
+ }];
+
+ let arguments = (ins XeGPU_BaseAddrType: $source,
+ Variadic<Index>: $dynamic_offsets,
+ Variadic<Index>: $dynamic_shape,
+ Variadic<Index>: $dynamic_strides,
+ DenseI64ArrayAttr: $static_offsets,
+ DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode);
+ let results = (outs XeGPU_TensorDesc:$TensorDesc);
+
+ let hasCustomAssemblyFormat = 1;
+ let skipDefaultBuilders = 1;
+ let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "Type": $TensorDesc, "Value": $source, "ValueRange": $offsets,
+ "ValueRange": $shape, "ValueRange": $strides,
+ "llvm::ArrayRef<int64_t>": $static_offsets,
+ CArg<"xegpu::ModeKind", "xegpu::ModeKind::SIMT">: $mode)>,
+
+ OpBuilder<(ins "Type": $tdesc, "Value": $source,
+ "llvm::ArrayRef<OpFoldResult>": $offsets,
+ CArg<"xegpu::ModeKind", "xegpu::ModeKind::SIMT">: $mode)>,
+
+ OpBuilder<(ins "Type": $tdesc, "Value": $source,
+ "llvm::ArrayRef<OpFoldResult>": $offsets,
+ "ValueRange": $shape, "ValueRange": $stride,
+ CArg<"xegpu::ModeKind", "xegpu::ModeKind::SIMT">: $mode)>
+ ];
+
+ let extraClassDeclaration = [{
+ /// Returns the type of the source memref operand.
+ Type getSourceType() {
+ return getSource().getType();
+ }
+
+ /// Returns the type of the result TensorDesc.
+ xegpu::TensorDescType getTensorDescType();
+
+ /// Returns the offsets info to the source. It consolidates
+ /// information from both dynamic_offsets and static_offsets
+ /// parameters. static_offsets parameter always has the expected
+ /// rank, and some dims may hold a ShapeType::kDynamic value
+ /// indicating the corresponding value should be from dynamic_offsets.
+ llvm::SmallVector<OpFoldResult> getOffsets();
+
+ /// returns the shape info of the source. It is either from the
+ /// memref type, if source is a memref with static shape
+ /// information or from the dynamic_shape parameter. If both
+ /// exist, the dynamic_shape parameter will be used and the
+ /// shape information from memref type will be ignored.
+ llvm::SmallVector<OpFoldResult> getShape();
+
+ /// returns the strides info of the source. It is either from the
+ /// memref type, if source is a memref with static shape
+ /// information or from the dynamic_strides parameter. If both
+ /// exist, the dynamic_strides parameter will be used and the
+ /// strides information from memref type will be ignored.
+ llvm::SmallVector<OpFoldResult> getStrides();
+
+ /// return the shape embedded in the memref type of the source.
+ /// If source is not a memref type, an array of kDynamic will be returned.
+ llvm::ArrayRef<int64_t> getStaticShape();
+
+ /// return the strides embedded in the memref type of the source.
+ /// If source is not a memref type, an array of kDynamic will be returned.
+ llvm::ArrayRef<int64_t> getStaticStrides();
+
+ /// Return the element type of the TensorDesc
+ Type getElementType();
+
+ /// Return the shape of the TensorDesc
+ llvm::ArrayRef<int64_t> getTensorDescShape();
+ }];
+
+}
+
+def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> {
+ let summary = "loads an n-D block from memory (represented by TensorDesc) "
+ "to registers (represented by vector)";
+ let description = [{
+ LoadNDOp essentially mimics the hardware block read instruction to read
+ a block of data from memory to register. It takes a set of cache hints
+ for each level of cache, L1, L2 and L3. If hardware does not have a
+ corresponding cache, the corresponding cache hint attribute will be masked.
+ If both transpose and vnni_axis are present at the same time, transpose is
+ assumed to be performed first, followed by the vnni transform.
+ }];
+
+ let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ OptionalAttr<I64Attr>: $vnni_axis,
+ OptionalAttr<XeGPU_CacheAttr>: $l1_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l2_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l3_hint,
+ OptionalAttr<DenseI64ArrayAttr>: $transpose,
+ DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode);
+ let results = (outs XeGPU_ValueType: $value);
+
+ let extraClassDeclaration = [{
+ VectorType getValueType() {
+ return llvm::dyn_cast<VectorType>(getValue().getType());
+ }
+
+ xegpu::TensorDescType getTensorDescType() {
+ return getTensorDesc().getType();
+ }
+ }];
+
+ // Format: xegpu.load_nd %1 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint=streaming}
+ // : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def XeGPU_StoreNDOp : XeGPU_Op<"store_nd", []> {
+ let summary = "stores a n-D block register region back to memory, currently only supports 2D";
+ let arguments = (ins XeGPU_ValueType: $value,
+ XeGPU_TensorDesc: $TensorDesc,
+ OptionalAttr<XeGPU_CacheAttr>: $l1_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l2_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l3_hint,
+ DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode);
+
+ // Format: xegpu.store_nd %3, %2 {l1_hint = write_back, l2_hint = uncached}
+ // : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def XeGPU_PrefetchNDOp : XeGPU_Op<"prefetch_nd", []> {
+ let summary = "prefetches a nD block to cache";
+ let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ OptionalAttr<XeGPU_CacheAttr>: $l1_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l2_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l3_hint,
+ DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode);
+
+ // Format: xegpu.prefetch_nd %tdesc {l1_hint = cached, l2_hint = uncached}:
+ // !xegpu.tensor_desc<8x16xf16>
+ let hasCustomAssemblyFormat = 1;
+}
+
+def XeGPU_UpdateNDOffsetOp : XeGPU_Op<"update_nd_offset", []> {
+ let summary = "update the offsets for the given tensor descriptor";
+
+ let arguments = (ins
+ XeGPU_TensorDesc: $TensorDesc,
+ Variadic<Index>: $offsets,
+ DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode);
+
+ let results = (outs XeGPU_TensorDesc: $result);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure]> {
+ let summary = "create scattered tensor descriptors (TensorDesc).";
+ let description = [{
+ "create_tdesc" is similar to "create_nd_tdesc" in that it creates
+ a Tensor Descriptor (TensorDescType) for a memory region. While "create_nd_tdesc"
+ is for creating contiguous subviews, "create_tdesc" is for creating non-contiguous
+ (scattered) subviews. It is designed to work only with VectorCompute (VC) mode and
+ accepts the following parameters:
+
+ * source: a 1D memref or pointer (uint64_t) representing the memory object.
+ * offsets: a 1D vector containing the offset of each access point; its size is the supported
+ group size, e.g., vector<16xindex>. And each element in the vector corresponds
+ to a work item (SIMT lane) in the subgroup.
+ * chunk_size_per_lane: [optional attribute] indicates the number of contiguous elements
+ accessed for each offset, default is 1.
+
+ Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
+ %a = memref.alloc() : memref<1024xf32>
+ %c0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+ %1 = xegpu.create_tdesc %a, %c0: memref<1024xf32> -> TensorDesc<4xf32>
+
+ Example 2. It assumes subgroup size is 4, and each work item accesses 8 elements.
+ It will access 32 data elements in total: a[0:7], a[16:23], a[32:39], a[64:71]
+ %0 = memref.alloc() : memref<1024xf32>
+ %c0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+ %1 = xegpu.create_tdesc %0, %c0 {chunk_size_per_lane = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
+ }];
+
+ let arguments = (ins XeGPU_BaseAddrType: $source,
+ XeGPU_OffsetType: $offsets,
+ DefaultValuedAttr<I64Attr, "1">: $chunk_size_per_lane,
+ DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode);
+ let results = (outs XeGPU_TensorDesc:$TensorDesc);
+
+ let builders = [
+ OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source,
+ "Value": $offsets, CArg<"uint32_t", "1"> : $chunk_size_per_lane)>,
+ OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source,
+ "Value": $offsets, "IntegerAttr": $chunk_size_per_lane)>
+ ];
+ let skipDefaultBuilders = 1;
+
+ // Format: xegpu.create_tdesc %src, %offsets {mode=simt, chunk_size_per_lane=1}
+ // : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def XeGPU_LoadGatherOp : XeGPU_Op<"load"> {
+ let summary = "load a scalar at source[offset].";
+
+ let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ XeGPU_MaskType: $mask,
+ OptionalAttr<I64Attr>: $vnni_axis,
+ OptionalAttr<DenseI64ArrayAttr>: $transpose,
+ OptionalAttr<XeGPU_CacheAttr>: $l1_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l2_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l3_hint,
+ DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode);
+ let results = (outs XeGPU_ValueType: $value);
+
+ let builders = [
+ OpBuilder<(ins "mlir::Type": $value, "mlir::Value": $TensorDesc,
+ "mlir::Value": $mask, "mlir::IntegerAttr": $vnni_axis,
+ CArg<"mlir::DenseI64ArrayAttr", "mlir::DenseI64ArrayAttr()">: $transpose,
+ CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l1_hint,
+ CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l2_hint,
+ CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l3_hint)>,
+
+ OpBuilder<(ins "mlir::Type": $value, "mlir::Value": $TensorDesc,
+ "mlir::Value": $mask, "mlir::IntegerAttr": $vnni_axis,
+ CArg<"DenseI64ArrayAttr", "DenseI64ArrayAttr()">: $transpose,
+ CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l1_hint,
+ CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l2_hint,
+ CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l3_hint)>
+ ];
+ let skipDefaultBuilders = 1;
+
+ // Format: %2 = xegpu.load %1, %0 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached}
+ // : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32>
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", []> {
+ let summary = "store a scalar to source[offset].";
+
+ let arguments = (ins
+ XeGPU_ValueType: $value,
+ XeGPU_TensorDesc: $TensorDesc,
+ XeGPU_MaskType: $mask,
+ OptionalAttr<XeGPU_CacheAttr>: $l1_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l2_hint,
+ OptionalAttr<XeGPU_CacheAttr>: $l3_hint,
+ DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode
+ );
+
+ let builders = [
+ OpBuilder<(ins "Value": $value, "Value": $TensorDesc, "Value": $mask,
+ CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l1_hint,
+ CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l2_hint,
+ CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l3_hint)>,
+ OpBuilder<(ins "Value": $value, "Value": $TensorDesc, "Value": $mask,
+ CArg<"xegpu::CacheKind", "xegpu::CacheKind::WRITE_BACK">: $l1_hint,
+ CArg<"xegpu::CacheKind", "xegpu::CacheKind::WRITE_BACK">: $l2_hint,
+ CArg<"xegpu::CacheKind", "xegpu::CacheKind::WRITE_BACK">: $l3_hint)>
+ ];
+ let skipDefaultBuilders = 1;
+
+ // Format: xegpu.store %v, %1, %0 {l1_hint = write_back, l2_hint = uncached}
+ // : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
+  let summary = "prefetches a nD block to cache";
+
+  // The descriptor to prefetch, plus optional per-cache-level hints.
+  let arguments = (ins
+    XeGPU_TensorDesc: $TensorDesc,
+    OptionalAttr<XeGPU_CacheAttr>: $l1_hint,
+    OptionalAttr<XeGPU_CacheAttr>: $l2_hint,
+    OptionalAttr<XeGPU_CacheAttr>: $l3_hint,
+    DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode
+  );
+
+  // The default ODS builders are skipped; only the two builders below are
+  // declared (attribute-typed hints, or enum hints defaulting to CACHED).
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value": $TensorDesc,
+      CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l1_hint,
+      CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l2_hint,
+      CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l3_hint)>,
+    OpBuilder<(ins "Value": $TensorDesc,
+      CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l1_hint,
+      CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l2_hint,
+      CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l3_hint)>
+  ];
+
+  // Format: xegpu.prefetch %tdesc {l1_hint = cached, l2_hint = uncached}:
+  //           !xegpu.tensor_desc<8x16xf16>
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", []> {
+  let summary = "update the offsets for the given tensor descriptor";
+
+  // Takes a descriptor and a rank-1 vector of index offsets and yields a new
+  // descriptor.
+  let arguments = (ins
+    XeGPU_TensorDesc: $TensorDesc,
+    XeGPU_OffsetType: $offsets,
+    DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode
+  );
+  let results = (outs XeGPU_TensorDesc: $result);
+
+  // Only this builder is declared; the default ODS builders are skipped.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Type": $result, "Value": $TensorDesc, "Value": $offsets)>
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
+def XeGPU_DpasOp : XeGPU_Op<"dpas"> {
+  let summary = "performs dpas computation";
+  let arguments = (ins
+    XeGPU_DpasOpType : $lhs,
+    XeGPU_DpasOpType : $rhs,
+    Optional<XeGPU_Vector2DType>: $acc,
+    DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode
+  );
+  let results = (outs XeGPU_Vector2DType: $result);
+  let hasCustomAssemblyFormat = 1;
+
+  let extraClassDeclaration = [{
+    VectorType getLhsType() {
+      return ::llvm::cast<VectorType>(getLhs().getType());
+    }
+
+    VectorType getRhsType() {
+      return ::llvm::cast<VectorType>(getRhs().getType());
+    }
+
+    // `acc` is an optional operand: return a null VectorType when it is
+    // absent instead of calling getType() on a null Value.
+    VectorType getAccType() {
+      if (::mlir::Value acc = getAcc())
+        return ::llvm::cast<VectorType>(acc.getType());
+      return VectorType();
+    }
+
+    VectorType getResultType() {
+      return getResult().getType();
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_InvokeSIMDOp : XeGPU_Op<"invoke_SIMD", []> {
+  let summary = "Invoke_SIMD operation";
+  let description = [{
+    The `xegpu.invoke_SIMD` operation works similarly to a direct call to a
+    function, but it is specific to Intel GPUs.
+  }];
+
+  // No custom or declarative assembly format is given, so the op uses the
+  // generic operation syntax.
+  let arguments = (ins FlatSymbolRefAttr:$callee,
+                       Variadic<AnyType>:$operands,
+                       XeGPU_ArgTypeAttr: $argType);
+  let results = (outs Variadic<AnyType>);
+
+  let builders = [
+    OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results,
+      "xegpu::ArgTypeKindAttr":$argType, CArg<"ValueRange", "{}">:$operands)>,
+    OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results,
+      "xegpu::ArgTypeKindAttr":$argType, CArg<"ValueRange", "{}">:$operands)>,
+    OpBuilder<(ins "llvm::StringRef":$callee, "TypeRange":$results,
+      "xegpu::ArgTypeKindAttr":$argType, CArg<"ValueRange", "{}">:$operands)>
+  ];
+}
+
+def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> {
+  let summary = "perform read-modify-write operation that is free from data races.";
+  let arguments = (ins
+    XeGPU_AtomicRMWKindAttr:$kind,
+    XeGPU_TensorDesc:$tensorDesc,
+    XeGPU_MaskType:$mask,
+    Optional<XeGPU_ValueType>:$value,
+    DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode
+  );
+
+  let results = (outs XeGPU_ValueType:$result);
+  let hasCustomAssemblyFormat = 1;
+
+  // Builders accepting the RMW kind either as an attribute or as the raw enum.
+  let builders = [
+    OpBuilder<(ins "Type": $result, "xegpu::AtomicRMWKindAttr": $kind,
+      "Value": $tensorDesc, "Value": $mask, "Value": $value)>,
+    OpBuilder<(ins "Type": $result, "xegpu::AtomicRMWKind": $kind,
+      "Value": $tensorDesc, "Value": $mask, "Value": $value)>
+  ];
+
+  let skipDefaultBuilders = 1;
+  let hasVerifier = 1;
+}
+
+// Allocates `nbarrierCount` named barriers; printed as
+// `xegpu.alloc_nbarrier <count> attr-dict`.
+def XeGPU_AllocNbarrierOp: XeGPU_Op<"alloc_nbarrier", []> {
+  let summary = "allocate a specific number of named barriers.";
+  let assemblyFormat = "$nbarrierCount attr-dict";
+  let arguments = (ins I64Attr: $nbarrierCount);
+}
+
+
+// Creates a named barrier from i8 id/role operands and i8 producer/consumer
+// count attributes; uses a hand-written parser/printer.
+def XeGPU_CreateNbarrierOp: XeGPU_Op<"create_nbarrier", []> {
+  let summary = "create a named barrier.";
+  let results = (outs XeGPU_Nbarrier: $result);
+  let arguments = (ins
+    I8: $nbarrier_id,
+    I8: $nbarrier_role,
+    I8Attr: $num_producers,
+    I8Attr: $num_consumers,
+    DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode
+  );
+  let hasCustomAssemblyFormat = 1;
+}
+
+// Signals arrival at the named barrier held in `payload`.
+def XeGPU_NbarrierArriveOp: XeGPU_Op<"nbarrier_arrive", []> {
+  let summary = "arrive at a named barrier.";
+  let assemblyFormat = [{ $payload attr-dict `:` qualified(type($payload))}];
+  let arguments = (ins XeGPU_Nbarrier: $payload);
+}
+
+// Blocks until the named barrier held in `payload` completes.
+def XeGPU_NbarrierWaitOp: XeGPU_Op<"nbarrier_wait", []> {
+  let summary = "wait for a named barrier.";
+  let assemblyFormat = [{ $payload attr-dict `:` qualified(type($payload)) }];
+  let arguments = (ins XeGPU_Nbarrier: $payload);
+}
+
+// Operand-free marker op; only an attribute dictionary is printed.
+def XeGPU_CompileHintOp: XeGPU_Op<"compile_hint", []> {
+  let summary = "prevents the compiler from scheduling.";
+  let assemblyFormat = [{ attr-dict }];
+}
+
+// Memory fence; all three string parameters live in the attribute dictionary.
+def XeGPU_MfenceOp: XeGPU_Op<"mfence", []> {
+  let summary = "lsc fence.";
+  let assemblyFormat = [{ attr-dict }];
+  let arguments = (ins
+    StrAttr: $memory_kind,
+    StrAttr: $fence_op,
+    StrAttr: $fence_scope
+  );
+}
+
+#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
new file mode 100644
index 00000000000000..b3dceff9587ada
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -0,0 +1,170 @@
+//===- XeGPUTypes.td - XeGPU dialect types definition -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD
+#define MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD
+
+include "mlir/IR/BuiltinTypes.td"
+
+include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
+include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
+
+// Scalar integer element types: signless, signed and unsigned, 1 to 64 bits.
+def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
+// Scalar floating-point element types.
+def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
+// Any scalar (integer or floating-point) element type.
+def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
+// Source of a tensor descriptor: a rank-1 or rank-2 memref of scalars, or a
+// 32/64-bit integer (presumably a raw address -- confirm).
+def XeGPU_BaseAddrType: AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1, 2]>, UI64, UI32, I64, I32]>;
+// lhs/rhs operand type of dpas: a rank-2 or rank-3 vector of scalars.
+def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>;
+// Offsets are a rank-1 vector of index; the scalar-index variant below is
+// currently not accepted.
+// def XeGPU_OffsetType: AnyTypeOf<[VectorOfRankAndType<[1], [Index]>, Index]>;
+def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
+// Mask: a rank-1 or rank-2 vector of i1, or a scalar i1.
+def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1,2], [I1]>, I1]>;
+// Loaded/stored value: a scalar, or a vector of rank 1 to 4 of scalars.
+def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
+
+// A rank-2 vector of scalars (used for the dpas accumulator and result).
+def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
+
+// Base class shared by every XeGPU type: it binds the type to XeGPU_Dialect
+// and sets the mnemonic used in textual IR (e.g. `tensor_desc`).
+class XeGPUTypeDef<string name, string typeMnemonic,
+                   list<Trait> traits = [],
+                   string baseCppClass = "::mlir::Type">
+    : TypeDef<XeGPU_Dialect, name, traits, baseCppClass> {
+  let mnemonic = typeMnemonic;
+}
+
+// TensorDesc contains dim and element type info
+def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
+        [ShapedTypeInterface], "::mlir::TensorType"> {
+  let summary = "TensorDesc describing all kinds of memory and tensors, including scatter tensor, 1d tensor, 2d tensor, … 5d tensor";
+  let description = [{
+    TensorDesc is a type designed to describe all kinds of memory, scatter tensor, 1d tensor, 2d tensor, … 5d tensor.
+    Unlike the builtin tensor type in MLIR, it only contains the metadata that describes a region
+    of the data of interest, as well as some features that are unique to Intel hardware. It does not hold the data
+    directly by itself. It is designed mainly to support 2d block load/store and DPAS (matrix multiplication instruction)
+    on Intel GPU. It mainly encodes the following information:
+
+    * shape: the sizes/shape of the data block of interest, e.g., 8x16 means 8 rows
+      and each row contains 16 contiguous data elements. The rows could be
+      either contiguous or not, depending on whether the encoding attribute
+      is set or not.
+    * element_type: the data type of the data element, e.g., f16, f32.
+
+    Similar to the builtin tensor, it also provides an optional attribute to encode the following information via the TensorDescAttr object:
+    * memory_scope (xegpu::MemoryScope): [optional] where the data is located, global memory or shared memory. It defaults to global.
+    * array_length (int): [optional] The number of contiguous blocks with size as `shape`
+      that will be loaded by a block load at a time. It defaults to 1.
+    * boundary_check (bool): [optional] indicates whether the operation detects the boundary and pads with zero for out-of-boundary access. It defaults to true.
+    * scattered (xegpu::ScatteredAttr): [optional] It is a unit attribute. It can only be empty or a ScatteredAttr, indicating
+      whether the TensorDesc is blocked (empty, default) or scattered (ScatteredAttr). If it is
+      blocked, rows are contiguous in the corresponding dimension; otherwise, rows may not be contiguous.
+    * mapping (xegpu::SubGroupMapAttr): [optional] Used to guide the compiler to distribute the workload across different threads. It defaults to none.
+      In the TensorDescAttr syntax this parameter is spelled `map`.
+
+    For convenience, its attribute field can also take either a "ScatteredAttr" or a "SubGroupMapAttr" directly if and only
+    if all other parameters take their default values.
+
+    Syntax:
+
+    ```
+    TensorDesc-type ::= `tensor_desc` `<` dim-list element-type (attr-list)? `>`
+    element-type ::= float-type | integer-type | index-type
+    dim-list := (static-dim-list `x`)?
+    static-dim-list ::= decimal-literal `x` decimal-literal
+    attr-list = (, memory_scope = value)? (, array_length = value)? (, boundary_check = value)? (, ScatteredAttr)? (, mapping)?
+    ```
+
+    Examples:
+
+    ```mlir
+    // A block TensorDesc with 3x42 i32 elements
+    !xegpu.tensor_desc<3x42xi32>
+
+    // A block TensorDesc with 4x5 f32 elements
+    !xegpu.tensor_desc<4x5xf32>
+
+    // A Scattered TensorDesc with 16x4 f32 elements
+    !xegpu.tensor_desc<16x4xf32, #xegpu.scattered>
+
+    // A TensorDesc with 8x16 f16 elements.
+    // It will be distributed across 16 hardware threads, organized as [2, 8],
+    // and each accesses 2 contiguous elements in dim 1.
+    !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+
+    // A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
+    !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
+    ```
+  }];
+
+  let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
+                        "mlir::Type": $elementType,
+                        OptionalParameter<"mlir::Attribute">: $encoding);
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "llvm::ArrayRef<int64_t>":$shape, "mlir::Type":$elementType,
+      CArg<"mlir::Attribute", "{}"> : $encoding
+    )>,
+    TypeBuilder<(ins
+      "llvm::ArrayRef<int64_t>": $shape, "mlir::Type": $elementType,
+      "mlir::xegpu::MemoryScopeKind": $memory_scope, "int": $array_length,
+      "bool": $boundary_check, "mlir::xegpu::ScatteredAttr": $scattered,
+      "mlir::xegpu::SubGroupMapAttr": $mapping
+    )>,
+    TypeBuilderWithInferredContext<(ins
+      "llvm::ArrayRef<int64_t>": $shape, "mlir::Type": $elementType,
+      "mlir::xegpu::MemoryScopeKind": $memory_scope, "int": $array_length,
+      "bool": $boundary_check, "mlir::xegpu::ScatteredAttr": $scattered,
+      "mlir::xegpu::SubGroupMapAttr": $mapping
+    )>
+  ];
+
+  let extraClassDeclaration = [{
+    using TensorType::clone;
+    using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
+    using mlir::ShapedType::Trait<TensorDescType>::getRank;
+    using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
+    using mlir::ShapedType::Trait<TensorDescType>::isDynamicDim;
+    using mlir::ShapedType::Trait<TensorDescType>::hasStaticShape;
+    using mlir::ShapedType::Trait<TensorDescType>::getNumDynamicDims;
+    using mlir::ShapedType::Trait<TensorDescType>::getDimSize;
+    using mlir::ShapedType::Trait<TensorDescType>::getDynamicDimIndex;
+
+    // Returns a TensorDescType with the same shape and encoding but a new
+    // element type. Built directly rather than via TensorType::cloneWith,
+    // which only handles builtin ranked/unranked tensors and would assert on
+    // a TensorDescType.
+    TensorDescType clone(::mlir::Type elementType) {
+      return TensorDescType::get(getShape(), elementType, getEncoding());
+    }
+
+    TensorDescAttr getEncodingAsTensorDescAttr() const {
+      return llvm::dyn_cast_if_present<TensorDescAttr>(getEncoding());
+    }
+
+    SubGroupMapAttr getEncodingAsMapAttr() const {
+      return llvm::dyn_cast_if_present<SubGroupMapAttr>(getEncoding());
+    }
+
+    ScatteredAttr getEncodingAsScatteredAttr() const {
+      return llvm::dyn_cast_if_present<ScatteredAttr>(getEncoding());
+    }
+
+    xegpu::MemoryScopeKind getMemoryScope();
+    int getArrayLength();
+    bool getBoundaryCheck();
+    xegpu::ScatteredAttr getScattered();
+    xegpu::SubGroupMapAttr getMapping();
+  }];
+
+  let hasCustomAssemblyFormat = true;
+}
+
+
+def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
+  let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
+
+  // The type has no parameters, so the only factory needs just a context.
+  let extraClassDeclaration = [{
+    static NbarrierType get(mlir::MLIRContext *context) {
+      return Base::get(context);
+    }
+  }];
+}
+
+#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e04..838b7b87b09b64 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -87,6 +87,7 @@
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
@@ -138,7 +139,8 @@ inline void registerAllDialects(DialectRegistry ®istry) {
transform::TransformDialect,
ub::UBDialect,
vector::VectorDialect,
- x86vector::X86VectorDialect>();
+ x86vector::X86VectorDialect,
+ xegpu::XeGPUDialect>();
// clang-format on
// Register all external models.
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 68776a695cac4d..f5eeaaed5af97d 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(UB)
add_subdirectory(Utils)
add_subdirectory(Vector)
add_subdirectory(X86Vector)
+add_subdirectory(XeGPU)
set(LLVM_OPTIONAL_SOURCES
Traits.cpp
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..2e99f39ed86d2e
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -0,0 +1,15 @@
+# XeGPU dialect IR library: dialect/attribute/type implementation
+# (XeGPUDialect.cpp) and operation implementation (XeGPUOps.cpp).
+add_mlir_dialect_library(MLIRXeGPUDialect
+  XeGPUDialect.cpp
+  XeGPUOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/XeGPU
+
+  # Generated-header (IncGen) targets these sources depend on.
+  DEPENDS
+  MLIRXeGPUIncGen
+  MLIRXeGPUAttrsIncGen
+  MLIRXeGPUEnumsIncGen
+
+  # NOTE(review): both sources include Arith, Linalg, MemRef and Tensor
+  # dialect headers. If symbols from those dialects are actually used, the
+  # matching MLIR*Dialect libraries probably need to be linked here -- confirm.
+  LINK_LIBS PUBLIC
+  MLIRIR
+)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
new file mode 100644
index 00000000000000..60ab50227c2247
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -0,0 +1,385 @@
+//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <llvm/ADT/TypeSwitch.h>
+#include <llvm/Support/Debug.h>
+#include <mlir/Dialect/XeGPU/IR/XeGPU.h>
+
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/Linalg/IR/Linalg.h>
+#include <mlir/Dialect/MemRef/IR/MemRef.h>
+#include <mlir/Dialect/Tensor/IR/Tensor.h>
+#include <mlir/Dialect/Utils/StaticValueUtils.h>
+#include <mlir/IR/Builders.h>
+#include <mlir/IR/DialectImplementation.h>
+#include <mlir/IR/TypeUtilities.h>
+
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+
+/// Registers every TableGen-generated type, operation and attribute of the
+/// XeGPU dialect. Each list is expanded from the corresponding *.cpp.inc file
+/// selected by the GET_*_LIST macro defined just before the include.
+void XeGPUDialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
+      >();
+  addOperations<
+#define GET_OP_LIST
+#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
+      >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
+      >();
+}
+
+/// Returns true only when the environment variable MLIR_XEGPU_PRINT_DEFAULTS
+/// is set to exactly "true". The attribute/type printers consult it to decide
+/// whether parameters holding default values are printed explicitly.
+bool printDefaultValues() {
+  const char *flag = getenv("MLIR_XEGPU_PRINT_DEFAULTS");
+  if (!flag)
+    return false;
+  return std::string(flag) == "true";
+}
+
+/// Builds a SubGroupMapAttr from plain int32 arrays. Both arrays must have
+/// exactly two entries (checked by assertion in debug builds).
+SubGroupMapAttr SubGroupMapAttr::get(mlir::MLIRContext *context,
+                                     llvm::ArrayRef<int32_t> wiLayout,
+                                     llvm::ArrayRef<int32_t> wiData) {
+  assert(wiLayout.size() == 2 && wiData.size() == 2 &&
+         "wiLayout and wiData should be 2D arrays.\n");
+  auto layoutAttr = mlir::DenseI32ArrayAttr::get(context, wiLayout);
+  auto dataAttr = mlir::DenseI32ArrayAttr::get(context, wiData);
+  return Base::get(context, layoutAttr, dataAttr);
+}
+
+/// Verifies that both wi_layout and wi_data hold exactly two integers.
+/// Each failure emits a diagnostic and yields failure().
+mlir::LogicalResult SubGroupMapAttr::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    mlir::DenseI32ArrayAttr layout, mlir::DenseI32ArrayAttr data) {
+  if (layout.size() != 2)
+    return emitError()
+           << "Failed to parse SubGroupMapAttr: missing wi_layout which "
+              "is to be an integer array of size 2.\n";
+
+  if (data.size() != 2)
+    return emitError()
+           << "Failed to parse SubGroupMapAttr: missing wi_data which is "
+              "to be an integer array of size 2.\n";
+
+  return mlir::success();
+}
+
+/// Parses the body of a TensorDescAttr, e.g.
+///   <memory_scope = slm, array_length = 2, boundary_check = false,
+///    #xegpu.scattered, map = #xegpu.sg_map<...>>
+/// Every element is optional and may appear at most once. Omitted parameters
+/// default to: global memory scope, array_length 1, boundary_check true, no
+/// scattered attribute, no map. Unknown keys and repeated elements are
+/// rejected with a diagnostic. Returns a null Attribute on failure.
+mlir::Attribute TensorDescAttr::parse(mlir::AsmParser &parser,
+                                      mlir::Type type) {
+  mlir::FailureOr<xegpu::MemoryScopeKind> memory_scope;
+  mlir::FailureOr<int> array_length;
+  mlir::FailureOr<bool> boundary_check;
+  mlir::FailureOr<xegpu::ScatteredAttr> scattered;
+  mlir::FailureOr<xegpu::SubGroupMapAttr> map;
+
+  bool seen_memory_scope = false;
+  bool seen_array_length = false;
+  bool seen_boundary_check = false;
+  bool seen_scattered = false;
+  bool seen_map = false;
+
+  // Parse literal '<'
+  if (parser.parseLess())
+    return {};
+
+  // Parses one list element: either `key = value` or a bare ScatteredAttr.
+  // Unknown or repeated entries are diagnosed at their own location, so they
+  // are not skipped and then reported as an unrelated "expected '>'" error.
+  auto parseElt = [&]() -> mlir::ParseResult {
+    auto eltLoc = parser.getCurrentLocation();
+    llvm::StringRef paramKey;
+
+    if (parser.parseOptionalKeyword(&paramKey)) {
+      // Not a keyword: the only positional element is the ScatteredAttr.
+      if (seen_scattered)
+        return parser.emitError(
+            eltLoc, "duplicate scattered attribute in TensorDescAttr");
+      scattered = mlir::FieldParser<xegpu::ScatteredAttr>::parse(parser);
+      if (mlir::failed(scattered))
+        return parser.emitError(
+            parser.getCurrentLocation(),
+            "Failed to parse 'scattered' attr of TensorDescAttr, which is to "
+            "be a `xegpu::ScatteredAttr`");
+      seen_scattered = true;
+      return mlir::success();
+    }
+
+    // Rejects a repeated key, then consumes the '=' that must follow it.
+    auto startParam = [&](bool &seen) -> mlir::ParseResult {
+      if (seen)
+        return parser.emitError(eltLoc, "duplicate parameter '")
+               << paramKey << "' in TensorDescAttr";
+      seen = true;
+      return parser.parseEqual();
+    };
+
+    if (paramKey == "memory_scope") {
+      if (startParam(seen_memory_scope))
+        return mlir::failure();
+      memory_scope =
+          mlir::FieldParser<mlir::xegpu::MemoryScopeKind>::parse(parser);
+      if (mlir::failed(memory_scope))
+        return parser.emitError(
+            parser.getCurrentLocation(),
+            "Failed to parse the 'memory_scope' of TensorDescAttr, which is "
+            "to be a `xegpu::MemoryScope`");
+      return mlir::success();
+    }
+
+    if (paramKey == "array_length") {
+      if (startParam(seen_array_length))
+        return mlir::failure();
+      array_length = ::mlir::FieldParser<int>::parse(parser);
+      if (mlir::failed(array_length))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "Failed to parse the 'array_length' of "
+                                "TensorDescAttr, which is to be a `int`");
+      return mlir::success();
+    }
+
+    if (paramKey == "boundary_check") {
+      if (startParam(seen_boundary_check))
+        return mlir::failure();
+      boundary_check = ::mlir::FieldParser<bool>::parse(parser);
+      if (::mlir::failed(boundary_check))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "Failed to parse the 'boundary_check' of "
+                                "TensorDescAttr, which is to be a `bool`");
+      return mlir::success();
+    }
+
+    if (paramKey == "map") {
+      if (startParam(seen_map))
+        return mlir::failure();
+      map = ::mlir::FieldParser<xegpu::SubGroupMapAttr>::parse(parser);
+      if (::mlir::failed(map))
+        return parser.emitError(
+            parser.getCurrentLocation(),
+            "Failed to parse the 'map' of TensorDescAttr, which is to be a "
+            "`xegpu::SubGroupMapAttr`");
+      return mlir::success();
+    }
+
+    return parser.emitError(eltLoc, "unknown parameter '")
+           << paramKey << "' in TensorDescAttr";
+  };
+
+  if (parser.parseCommaSeparatedList(parseElt))
+    return {};
+
+  // Parse literal '>'
+  if (parser.parseGreater())
+    return {};
+  return TensorDescAttr::get(
+      parser.getContext(),
+      memory_scope.value_or(xegpu::MemoryScopeKind::GLOBAL),
+      array_length.value_or(1), boundary_check.value_or(true),
+      scattered.value_or(xegpu::ScatteredAttr()),
+      map.value_or(xegpu::SubGroupMapAttr()));
+}
+
+/// Prints `<param, param, ...>`. A parameter equal to its default (global
+/// scope, array_length 1, boundary_check true) is omitted unless
+/// printDefaultValues() is on. The scattered and map attributes are printed
+/// whenever present.
+void TensorDescAttr::print(::mlir::AsmPrinter &printer) const {
+  const bool printDefaults = printDefaultValues();
+  bool needsComma = false;
+  // Writes ", " before every parameter except the first one printed.
+  auto separate = [&]() {
+    if (needsComma)
+      printer << ", ";
+    needsComma = true;
+  };
+
+  printer << "<";
+
+  if (printDefaults || getMemoryScope() != xegpu::MemoryScopeKind::GLOBAL) {
+    separate();
+    printer << "memory_scope = ";
+    printer.printStrippedAttrOrType(getMemoryScope());
+  }
+
+  if (printDefaults || getArrayLength() != 1) {
+    separate();
+    printer << "array_length = ";
+    printer.printStrippedAttrOrType(getArrayLength());
+  }
+
+  if (printDefaults || !getBoundaryCheck()) {
+    separate();
+    printer << "boundary_check = ";
+    printer.printStrippedAttrOrType(getBoundaryCheck());
+  }
+
+  if (auto scatteredAttr = getScattered()) {
+    separate();
+    printer.printStrippedAttrOrType(scatteredAttr);
+  }
+
+  if (auto mapAttr = getMap()) {
+    separate();
+    printer << "map = ";
+    printer.printStrippedAttrOrType(mapAttr);
+  }
+
+  printer << ">";
+}
+
+/// Returns true if any parameter differs from its default: global memory
+/// scope, boundary checking enabled, array length 1, and no scattered or map
+/// attribute.
+bool TensorDescAttr::hasNonDefaultAttrs() {
+  return getMemoryScope() != MemoryScopeKind::GLOBAL || !getBoundaryCheck() ||
+         getArrayLength() != 1 || getScattered() || getMap();
+}
+
+/// Convenience builder without a boundary_check argument; boundary checking is
+/// always enabled (the attribute's default).
+TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context,
+                                   xegpu::MemoryScopeKind memory_scope,
+                                   int array_length,
+                                   xegpu::ScatteredAttr scattered,
+                                   xegpu::SubGroupMapAttr map) {
+  constexpr bool kBoundaryCheck = true;
+  return Base::get(context, memory_scope, array_length, kBoundaryCheck,
+                   scattered, map);
+}
+
+/// Parses `<dims x element-type (, encoding-attr)?>`. The encoding may be any
+/// attribute (a TensorDescAttr, a ScatteredAttr or a SubGroupMapAttr) and is
+/// left null when omitted. Returns a null Type on failure.
+mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
+  llvm::SmallVector<int64_t> shape;
+  mlir::Type elementType;
+  mlir::Attribute encoding;
+
+  if (parser.parseLess())
+    return {};
+
+  // Dimension list, e.g. `8x16x`, followed by the element type.
+  auto loc = parser.getCurrentLocation();
+  if (mlir::failed(parser.parseDimensionList(shape))) {
+    parser.emitError(loc, "failed to parse parameter 'shape'");
+    return {};
+  }
+
+  loc = parser.getCurrentLocation();
+  if (mlir::failed(parser.parseType(elementType))) {
+    parser.emitError(loc, "failed to parse parameter 'elementType'");
+    return {};
+  }
+
+  // An optional encoding attribute follows a comma.
+  if (mlir::succeeded(parser.parseOptionalComma())) {
+    auto parsed = mlir::FieldParser<mlir::Attribute>::parse(parser);
+    if (mlir::failed(parsed)) {
+      parser.emitError(
+          parser.getCurrentLocation(),
+          "Failed to parse the attribute field for TensorDescType.\n");
+      return {};
+    }
+    encoding = *parsed;
+  }
+
+  if (parser.parseGreater())
+    return {};
+
+  return TensorDescType::get(parser.getContext(), shape, elementType,
+                             encoding);
+}
+
+/// Prints the type body, e.g. `<8x16xf16, #xegpu.scattered>`.
+/// When printDefaultValues() is on, the encoding is always printed as a full
+/// TensorDescAttr. A bare map or scattered encoding is expanded with default
+/// values, and so is a missing encoding. Otherwise a TensorDescAttr encoding
+/// is printed only if it has non-default values, and any other encoding is
+/// printed as-is.
+void TensorDescType::print(::mlir::AsmPrinter &printer) const {
+  printer << "<";
+
+  auto shape = getShape();
+  for (int64_t dim : shape) {
+    if (mlir::ShapedType::isDynamic(dim))
+      printer << '?';
+    else
+      printer << dim;
+    printer << 'x';
+  }
+  printer << getElementType();
+
+  if (printDefaultValues()) {
+    mlir::Attribute encoding = getEncoding();
+    if (auto attr = getEncodingAsMapAttr()) {
+      encoding = TensorDescAttr::get(getContext(), MemoryScopeKind::GLOBAL, 1,
+                                     {}, attr);
+    } else if (auto attr = getEncodingAsScatteredAttr()) {
+      encoding = TensorDescAttr::get(getContext(), MemoryScopeKind::GLOBAL, 1,
+                                     attr, {});
+    } else if (!encoding) {
+      // No encoding at all: print the all-defaults TensorDescAttr rather than
+      // a null attribute, which would print as unparseable text.
+      encoding = TensorDescAttr::get(getContext(), MemoryScopeKind::GLOBAL, 1,
+                                     {}, {});
+    }
+    printer << ", " << encoding;
+  } else if (auto encoding = getEncodingAsTensorDescAttr()) {
+    if (encoding.hasNonDefaultAttrs())
+      printer << ", " << encoding;
+  } else if (auto encoding = getEncoding()) {
+    printer << ", " << encoding;
+  }
+  printer << ">";
+}
+
+/// Builds a TensorDescType with an arbitrary (possibly null) encoding,
+/// taking the context from the element type.
+TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
+                                   mlir::Type elementType,
+                                   mlir::Attribute encoding) {
+  mlir::MLIRContext *ctx = elementType.getContext();
+  return Base::get(ctx, shape, elementType, encoding);
+}
+
+/// Builds a TensorDescType whose encoding is a TensorDescAttr assembled from
+/// the individual descriptor parameters.
+TensorDescType TensorDescType::get(mlir::MLIRContext *context,
+                                   llvm::ArrayRef<int64_t> shape,
+                                   mlir::Type elementType,
+                                   mlir::xegpu::MemoryScopeKind memory_scope,
+                                   int array_length, bool boundary_check,
+                                   mlir::xegpu::ScatteredAttr scattered,
+                                   mlir::xegpu::SubGroupMapAttr mapping) {
+  TensorDescAttr encoding =
+      TensorDescAttr::get(context, memory_scope, array_length, boundary_check,
+                          scattered, mapping);
+  return Base::get(context, shape, elementType, encoding);
+}
+
+/// Same as the overload above, with the context taken from the element type.
+TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
+                                   mlir::Type elementType,
+                                   mlir::xegpu::MemoryScopeKind memory_scope,
+                                   int array_length, bool boundary_check,
+                                   mlir::xegpu::ScatteredAttr scattered,
+                                   mlir::xegpu::SubGroupMapAttr mapping) {
+  return TensorDescType::get(elementType.getContext(), shape, elementType,
+                             memory_scope, array_length, boundary_check,
+                             scattered, mapping);
+}
+
+/// Memory scope from a TensorDescAttr encoding; global when there is none.
+xegpu::MemoryScopeKind TensorDescType::getMemoryScope() {
+  auto tdescAttr = getEncodingAsTensorDescAttr();
+  return tdescAttr ? tdescAttr.getMemoryScope() : MemoryScopeKind::GLOBAL;
+}
+
+/// Array length from a TensorDescAttr encoding; 1 when there is none.
+int TensorDescType::getArrayLength() {
+  auto tdescAttr = getEncodingAsTensorDescAttr();
+  return tdescAttr ? tdescAttr.getArrayLength() : 1;
+}
+
+/// Boundary-check flag from a TensorDescAttr encoding; true when there is
+/// none.
+bool TensorDescType::getBoundaryCheck() {
+  auto tdescAttr = getEncodingAsTensorDescAttr();
+  return tdescAttr ? tdescAttr.getBoundaryCheck() : true;
+}
+
+/// Scattered marker, taken either from a TensorDescAttr encoding or from an
+/// encoding that is itself a ScatteredAttr; null otherwise.
+xegpu::ScatteredAttr TensorDescType::getScattered() {
+  if (auto tdescAttr = getEncodingAsTensorDescAttr())
+    return tdescAttr.getScattered();
+  return getEncodingAsScatteredAttr();
+}
+
+/// Sub-group map, taken either from a TensorDescAttr encoding or from an
+/// encoding that is itself a SubGroupMapAttr; null otherwise.
+xegpu::SubGroupMapAttr TensorDescType::getMapping() {
+  if (auto tdescAttr = getEncodingAsTensorDescAttr())
+    return tdescAttr.getMap();
+  return getEncodingAsMapAttr();
+}
+
+} // namespace xegpu
+} // namespace mlir
+
+#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
+#define GET_ATTRDEF_CLASSES
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
+#define GET_TYPEDEF_CLASSES
+#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
new file mode 100644
index 00000000000000..627680e84ec949
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -0,0 +1,1929 @@
+//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <llvm/ADT/TypeSwitch.h>
+#include <llvm/Support/Debug.h>
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/Linalg/IR/Linalg.h>
+#include <mlir/Dialect/MemRef/IR/MemRef.h>
+#include <mlir/Dialect/Tensor/IR/Tensor.h>
+#include <mlir/Dialect/Utils/StaticValueUtils.h>
+#include <mlir/Dialect/XeGPU/IR/XeGPU.h>
+#include <mlir/IR/Builders.h>
+#include <mlir/IR/DialectImplementation.h>
+#include <mlir/IR/TypeUtilities.h>
+#include <numeric>
+#include <type_traits>
+
+#define DEBUG_TYPE "xegpu"
+
+namespace mlir {
+class Token;
+
+namespace xegpu {
+
+extern bool printDefaultValues();
+
+/// Render the elements of \p array as "[a, b, c]" for use in diagnostics.
+/// When \p breakline is set, each ", " separator is followed by a newline and
+/// two tabs. An empty container renders as "[]".
+/// \p array must support size(), operator[], back() and empty().
+template <typename T>
+static std::string makeString(const T &array, bool breakline = false) {
+  std::string buf;
+  llvm::raw_string_ostream os(buf);
+  os << "[";
+  // back() on an empty container is undefined behavior, so the empty case is
+  // handled explicitly.
+  if (!array.empty()) {
+    for (size_t i = 1; i < array.size(); i++) {
+      os << array[i - 1] << ", ";
+      if (breakline)
+        os << "\n\t\t";
+    }
+    os << array.back();
+  }
+  os << "]";
+  os.flush();
+  return buf;
+}
+
+/// Rank of \p value's type: 0 for integer/index/float scalars and the shaped
+/// rank for memrefs and vectors. Any other type is treated as unreachable.
+static size_t getRankOf(Value value) {
+  Type type = value.getType();
+  if (type.isIntOrIndexOrFloat())
+    return 0;
+  if (auto memrefTy = llvm::dyn_cast_if_present<MemRefType>(type))
+    return memrefTy.getRank();
+  if (auto vectorTy = llvm::dyn_cast_if_present<VectorType>(type))
+    return vectorTy.getRank();
+  llvm_unreachable("Unsupported value for getRankOf");
+}
+
+/// Permute \p shape in place so that shape[i] becomes the old shape[trans[i]]
+/// for every i < trans.size(). Dimensions past trans.size() are left as they
+/// are. The caller must ensure every entry of \p trans indexes into \p shape.
+static void transpose(llvm::ArrayRef<int64_t> trans,
+                      std::vector<int64_t> &shape) {
+  const std::vector<int64_t> original(shape);
+  size_t dim = 0;
+  for (int64_t src : trans)
+    shape[dim++] = original[src];
+}
+
+/// Check that \p shape can be distributed across work items as described by
+/// \p sgMap. On success, each shape[i] is replaced by the per-work-item extent
+/// shape[i] / wi_layout[i]. A null \p sgMap is always accepted and leaves
+/// \p shape unchanged.
+///
+/// For every dimension i the following must hold:
+///   shape[i] % wi_layout[i] == 0 && shape[i] % wi_data[i] == 0 &&
+///   (shape[i] % (wi_layout[i] * wi_data[i]) == 0 ||
+///    (wi_layout[i] * wi_data[i]) % shape[i] == 0)
+///
+/// Returns false in three cases: the ranks differ, a layout/data entry is
+/// not positive, or a dimension violates the rules above. \p shape may
+/// already be partially rewritten when false is returned.
+static bool verifyAndInferShape(std::vector<int64_t> &shape,
+                                SubGroupMapAttr sgMap) {
+  if (!sgMap)
+    return true;
+
+  auto wiLayout = sgMap.getWiLayout();
+  auto wiData = sgMap.getWiData();
+
+  if ((int64_t)shape.size() != wiData.size() ||
+      (int64_t)shape.size() != wiLayout.size())
+    return false;
+
+  for (size_t i = 0; i < shape.size(); i++) {
+    // Widen to int64_t so the product below cannot overflow a narrower
+    // attribute element type.
+    int64_t layout = wiLayout[i];
+    int64_t data = wiData[i];
+    // A zero factor would make the modulo operations below divide by zero.
+    if (layout <= 0 || data <= 0)
+      return false;
+
+    int64_t distUnit = layout * data;
+    if ((shape[i] % distUnit != 0 && distUnit % shape[i] != 0) ||
+        shape[i] % layout != 0 || shape[i] % data != 0)
+      return false;
+    shape[i] /= layout;
+  }
+  return true;
+}
+
+/// Parse an optional attribute dictionary `{...}` into \p result.
+///
+/// This follows the generic attribute-dictionary syntax, except that these
+/// XeGPU keys are parsed with their custom syntax:
+///   - `mode`                          as a ModeKindAttr,
+///   - `l1_hint`, `l2_hint`, `l3_hint` as CacheKindAttr,
+///   - `transpose`                     as `[4, 5]` or `array<i64: 4, 5>`.
+/// A key without `= value` becomes a UnitAttr. Returns success when no `{` is
+/// present. `{}` is accepted as an empty dictionary.
+static ParseResult
+parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser,
+                                     OperationState &result) {
+  // no optional attributes, return success
+  if (failed(parser.parseOptionalLBrace()))
+    return success();
+
+  // Accept an empty dictionary `{}`.
+  if (succeeded(parser.parseOptionalRBrace()))
+    return success();
+
+  // Keys are interned as StringAttr. A StringRef into the per-element
+  // `nameId` string below would dangle once that element has been parsed.
+  llvm::SmallDenseSet<StringAttr, 8> seenKeys;
+  auto parseElt = [&]() -> ParseResult {
+    // The name of an attribute can either be a keyword or a string. Unlike
+    // mlir::parseOptionalAttrList, Token::bare_identifier and Token::inttype
+    // keys may not be handled here.
+    std::string nameId;
+    auto loc = parser.getCurrentLocation();
+    if (parser.parseOptionalKeywordOrString(&nameId))
+      return parser.emitError(loc, "invalid attribute name: ")
+             << nameId << ".\n";
+
+    if (nameId.empty())
+      return parser.emitError(loc, "expected valid attribute name");
+
+    if (!seenKeys.insert(parser.getBuilder().getStringAttr(nameId)).second)
+      return parser.emitError(loc, "duplicate key '")
+             << nameId << "' in dictionary attribute.";
+
+    // Lazy load a dialect in the context if there is a possible namespace.
+    auto splitName = StringRef(nameId).split('.');
+    if (!splitName.second.empty())
+      parser.getContext()->getOrLoadDialect(splitName.first);
+
+    // The '=' is optional: without it the key is a unit attribute. The
+    // optional form avoids emitting an "expected '='" diagnostic.
+    if (failed(parser.parseOptionalEqual())) {
+      result.addAttribute(nameId, parser.getBuilder().getUnitAttr());
+      return success();
+    }
+
+    // xegpu specific attributes
+    if (nameId == "mode") {
+      ModeKindAttr attr;
+      return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId,
+                                                     result.attributes);
+    }
+
+    if (nameId == "l1_hint" || nameId == "l2_hint" || nameId == "l3_hint") {
+      CacheKindAttr attr;
+      return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId,
+                                                     result.attributes);
+    }
+
+    if (nameId == "transpose") {
+      // In the form of `array<i64: 4, 5>`.
+      if (failed(parser.parseOptionalLSquare())) {
+        DenseI64ArrayAttr attr;
+        return parser.parseAttribute(attr, nameId, result.attributes);
+      }
+
+      // In the form of `[4, 5]`; mirrors DenseI64ArrayAttr::parse().
+      Attribute attr;
+      if (succeeded(parser.parseOptionalRSquare())) {
+        // Handle the empty list case.
+        attr = DenseI64ArrayAttr::get(parser.getContext(), {});
+      } else {
+        attr = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
+        // Stop at the first error instead of also reporting a missing ']'.
+        if (!attr || failed(parser.parseRSquare()))
+          return failure();
+      }
+      result.addAttribute(nameId, attr);
+      return success();
+    }
+
+    Attribute attr;
+    return parser.parseAttribute(attr, nameId, result.attributes);
+  };
+
+  if (parser.parseCommaSeparatedList(parseElt))
+    return failure();
+
+  return parser.parseRBrace();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_CreateNdDescOp
+//===----------------------------------------------------------------------===//
+/// Build a CreateNdDescOp from explicit dynamic offset/shape/stride operands
+/// plus the static offset list. Each ShapedType::kDynamic entry of
+/// \p static_offsets is a placeholder for the next value of \p offsets. When
+/// \p shape is empty, the rank is taken from \p source.
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type TensorDesc, Value source, ValueRange offsets,
+                           ValueRange shape, ValueRange strides,
+                           llvm::ArrayRef<int64_t> static_offsets,
+                           ModeKind mode) {
+  const size_t rank = shape.empty() ? getRankOf(source) : shape.size();
+  const size_t numDynamicOffsets =
+      std::count_if(static_offsets.begin(), static_offsets.end(),
+                    [](int64_t d) { return ShapedType::isDynamic(d); });
+
+  // Invariants checked here:
+  //   - shape and strides are given together;
+  //   - the full (static + dynamic) offset list has the descriptor rank;
+  //   - there is one dynamic offset value per kDynamic placeholder.
+  assert(shape.size() == strides.size() && rank == static_offsets.size() &&
+         offsets.size() == numDynamicOffsets);
+
+  state.addOperands(source);
+  state.addOperands(offsets);
+  state.addOperands(shape);
+  state.addOperands(strides);
+
+  const int32_t segmentSizes[] = {1, static_cast<int32_t>(offsets.size()),
+                                  static_cast<int32_t>(shape.size()),
+                                  static_cast<int32_t>(strides.size())};
+  state.addAttribute(getOperandSegmentSizesAttrName(state.name),
+                     builder.getDenseI32ArrayAttr(segmentSizes));
+  state.addAttribute(getStaticOffsetsAttrName(state.name),
+                     builder.getDenseI64ArrayAttr(static_offsets));
+  state.addAttribute(getModeAttrName(state.name),
+                     xegpu::ModeKindAttr::get(builder.getContext(), mode));
+  state.addTypes(TensorDesc);
+}
+
+/// Build a CreateNdDescOp over a statically shaped memref \p source. The
+/// offsets may mix constants and SSA values. Shape and strides come from the
+/// memref type, so no dynamic shape/stride operands are created.
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, Value source,
+                           llvm::ArrayRef<OpFoldResult> offsets,
+                           ModeKind mode) {
+  auto memrefTy = llvm::dyn_cast_if_present<MemRefType>(source.getType());
+  assert(memrefTy && memrefTy.hasStaticShape() &&
+         offsets.size() == getRankOf(source));
+
+  llvm::SmallVector<Value> dynOffsets;
+  llvm::SmallVector<int64_t> constOffsets;
+  dispatchIndexOpFoldResults(offsets, dynOffsets, constOffsets);
+
+  build(builder, state, tdesc, source, dynOffsets,
+        /*shape=*/ValueRange({}), /*strides=*/ValueRange({}), constOffsets,
+        mode);
+}
+
+/// Build a CreateNdDescOp with explicit dynamic \p shape and \p stride
+/// operands. \p offsets may mix constants and SSA values. All three lists must
+/// be non-empty and of equal length.
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, Value source,
+                           llvm::ArrayRef<OpFoldResult> offsets,
+                           ValueRange shape, ValueRange stride, ModeKind mode) {
+  assert(!shape.empty() && !offsets.empty() && !stride.empty() &&
+         shape.size() == stride.size() && shape.size() == offsets.size());
+
+  llvm::SmallVector<Value> dynOffsets;
+  llvm::SmallVector<int64_t> constOffsets;
+  dispatchIndexOpFoldResults(offsets, dynOffsets, constOffsets);
+
+  build(builder, state, tdesc, source, dynOffsets, /*shape=*/shape,
+        /*strides=*/stride, constOffsets, mode);
+}
+
+/// Parse the custom form of CreateNdDescOp:
+///   %source [offsets] (, [shape], [strides])? {attrs}? : srcTy -> tdescTy
+/// The offsets go through parseDynamicIndexList, which mixes constants and SSA
+/// values. Shape and strides are optional but must appear together. All
+/// offset/shape/stride SSA operands are resolved as `index`.
+/// operandSegmentSizes is computed here from the parsed operand counts.
+ParseResult CreateNdDescOp::parse(OpAsmParser &parser, OperationState &result) {
+ // parse the source operand
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand> sourceOperands(1);
+ llvm::SMLoc sourceOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperand(sourceOperands[0]))
+ return failure();
+
+ // parse the offset operand, in format of [x, y]
+ // (constants go into static_offsets, SSA values into offsetsOperands)
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> offsetsOperands;
+ DenseI64ArrayAttr static_offsetsAttr;
+ llvm::SMLoc offsetsOperandsLoc = parser.getCurrentLocation();
+ if (parseDynamicIndexList(parser, offsetsOperands, static_offsetsAttr))
+ return failure();
+ result.addAttribute("static_offsets", static_offsetsAttr);
+
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> shapeOperands;
+ llvm::SMLoc shapeOperandsLoc;
+
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> stridesOperands;
+ llvm::SMLoc stridesOperandsLoc;
+ // parse optional shape and strides, shape and strides should always come
+ // together
+ if (succeeded(parser.parseOptionalComma())) {
+ // parse shape part, in form of [x, y]
+ if (parser.parseLSquare())
+ return failure();
+ shapeOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperandList(shapeOperands))
+ return failure();
+ if (parser.parseRSquare())
+ return failure();
+
+ if (parser.parseComma())
+ return failure();
+
+ // parse stride part, in form of [x, y]
+ if (parser.parseLSquare())
+ return failure();
+ stridesOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperandList(stridesOperands))
+ return failure();
+ if (parser.parseRSquare())
+ return failure();
+ }
+
+ // Location of the attribute dictionary, used for inherent-attr diagnostics.
+ auto loc = parser.getCurrentLocation();
+ if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+ return failure();
+
+ if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+ return parser.emitError(loc)
+ << "'" << result.name.getStringRef() << "' op ";
+ })))
+ return failure();
+
+ if (parser.parseColon())
+ return failure();
+
+ llvm::SmallVector<Type> sourceTypes(1);
+ if (parser.parseType(sourceTypes[0]))
+ return failure();
+
+ if (parser.parseArrow())
+ return failure();
+
+ llvm::SmallVector<Type> TensorDescTypes(1);
+ if (parser.parseType(TensorDescTypes[0]))
+ return failure();
+ // Segment order: source, dynamic offsets, shape, strides.
+ result.addAttribute("operandSegmentSizes",
+ parser.getBuilder().getDenseI32ArrayAttr(
+ {1, static_cast<int32_t>(offsetsOperands.size()),
+ static_cast<int32_t>(shapeOperands.size()),
+ static_cast<int32_t>(stridesOperands.size())}));
+
+ result.addTypes(TensorDescTypes);
+ // Operands must be resolved in the same order as the segments above.
+ if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
+ result.operands))
+ return failure();
+
+ Type indexType = parser.getBuilder().getIndexType();
+ if (parser.resolveOperands(offsetsOperands, indexType, offsetsOperandsLoc,
+ result.operands))
+ return failure();
+ if (parser.resolveOperands(shapeOperands, indexType, shapeOperandsLoc,
+ result.operands))
+ return failure();
+ if (parser.resolveOperands(stridesOperands, indexType, stridesOperandsLoc,
+ result.operands))
+ return failure();
+ return success();
+}
+
+/// Print the custom form
+///   %source [offsets] (, [shape])? (, [strides])? {attrs}? : srcTy -> tdescTy
+/// static_offsets and operandSegmentSizes are always elided because the
+/// operand list encodes them. `mode` is elided when it is SIMT, unless
+/// default values are requested.
+void CreateNdDescOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getSource();
+  printDynamicIndexList(printer, *this, getDynamicOffsets(),
+                        getStaticOffsetsAttr());
+
+  auto dynShape = getDynamicShape();
+  if (!dynShape.empty())
+    printer << ", [" << dynShape << "]";
+
+  auto dynStrides = getDynamicStrides();
+  if (!dynStrides.empty())
+    printer << ", [" << dynStrides << "]";
+
+  llvm::SmallVector<llvm::StringRef> elided{"static_offsets",
+                                            "operandSegmentSizes"};
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+
+  printer << " : " << getSourceType() << " -> " << getTensorDescType();
+}
+
+/// Verify CreateNdDescOp:
+///   - the descriptor is not scattered;
+///   - VC mode has no mapping and SIMT mode has one;
+///   - shape/strides are available, either as operands or from a statically
+///     shaped memref source;
+///   - offsets, shape, strides and source all have the same rank (2 for a
+///     scalar source such as a raw address).
+LogicalResult CreateNdDescOp::verify() {
+  auto mode = getMode();
+  auto isScattered = getTensorDescType().getScattered();
+  auto mapping = getTensorDescType().getMapping();
+
+  if (isScattered) {
+    return emitOpError("Encoding Attribute of TensorDesc is not expected for "
+                       "non-scattered operators.\n");
+  }
+
+  if (mode == ModeKind::VC && mapping) {
+    return emitOpError("Mapping attribute of TensorDesc is not expected "
+                       "for VC mode operations.\n");
+  }
+
+  if (mode == ModeKind::SIMT && !mapping) {
+    return emitOpError("Expecting SgMap attribute for SIMT mode operators.\n");
+  }
+
+  // getShape()/getStrides() can only fall back to the source type when it is
+  // a statically shaped memref; otherwise they hit llvm_unreachable. Report
+  // the problem as a verifier error instead of crashing.
+  auto memrefTy = llvm::dyn_cast_if_present<MemRefType>(getSourceType());
+  bool hasStaticSource = memrefTy && memrefTy.hasStaticShape();
+  if (!hasStaticSource &&
+      (getDynamicShape().empty() || getDynamicStrides().empty()))
+    return emitOpError("Expecting shape and strides operands when the source "
+                       "is not a statically shaped memref.");
+
+  auto offsetRank = getOffsets().size();
+  auto shapeRank = getShape().size();
+  auto stridesRank = getStrides().size();
+  auto sourceRank = getRankOf(getSource());
+  auto baseRank = sourceRank ? sourceRank : 2;
+
+  if (offsetRank != shapeRank || shapeRank != stridesRank ||
+      shapeRank != baseRank)
+    return emitOpError(
+        "Expecting the rank of shape, strides, offsets and memref type "
+        "should match with each other (they currently should be 2D).");
+
+  return success();
+}
+
+/// Type of the tensor descriptor produced by this op.
+xegpu::TensorDescType CreateNdDescOp::getTensorDescType() {
+  auto desc = getTensorDesc();
+  return desc.getType();
+}
+
+/// Offsets as a mixed list of constants and SSA values.
+///
+/// static_offsets is walked in order. Each ShapedType::kDynamic placeholder is
+/// replaced by the next dynamic offset operand; every other entry becomes an
+/// index-typed IntegerAttr. When static_offsets is empty, the dynamic offsets
+/// are returned unchanged.
+llvm::SmallVector<OpFoldResult> CreateNdDescOp::getOffsets() {
+  auto dynamicOffsets = getDynamicOffsets();
+  auto staticOffsets = getStaticOffsets();
+
+  llvm::SmallVector<OpFoldResult> mixed;
+  if (staticOffsets.empty()) {
+    mixed.append(dynamicOffsets.begin(), dynamicOffsets.end());
+    return mixed;
+  }
+
+  auto indexTy = IndexType::get(getContext());
+  auto nextDynamic = dynamicOffsets.begin();
+  for (int64_t offset : staticOffsets) {
+    if (!ShapedType::isDynamic(offset)) {
+      mixed.push_back(IntegerAttr::get(indexTy, offset));
+      continue;
+    }
+    assert(nextDynamic != dynamicOffsets.end());
+    mixed.push_back(*nextDynamic++);
+  }
+  return mixed;
+}
+
+/// Static shape of the source. For a memref source this is the memref shape.
+/// Otherwise it is a list of ShapedType::kDynamic whose length is the rank of
+/// the produced TensorDesc.
+llvm::ArrayRef<int64_t> CreateNdDescOp::getStaticShape() {
+  if (auto srcTy = llvm::dyn_cast_if_present<MemRefType>(getSourceType()))
+    return srcTy.getShape();
+
+  // The all-dynamic fallback is interned in the MLIRContext, so the returned
+  // ArrayRef outlives this call and always matches this op's rank. A
+  // function-local static would be sized by whichever op called first.
+  auto rank = getTensorDescType().getRank();
+  llvm::SmallVector<int64_t> dyn(rank, ShapedType::kDynamic);
+  return DenseI64ArrayAttr::get(getContext(), dyn).asArrayRef();
+}
+
+/// Shape of the source as a mixed list. Dynamic shape operands are used when
+/// present; otherwise the constant dimensions of a statically shaped memref
+/// are used. If neither is available the function hits llvm_unreachable.
+llvm::SmallVector<OpFoldResult> CreateNdDescOp::getShape() {
+  llvm::SmallVector<OpFoldResult> result;
+
+  if (auto dynShape = getDynamicShape(); !dynShape.empty()) {
+    result.append(dynShape.begin(), dynShape.end());
+    return result;
+  }
+
+  auto memrefTy = llvm::dyn_cast_if_present<MemRefType>(getSourceType());
+  if (!memrefTy || !memrefTy.hasStaticShape())
+    llvm_unreachable("Unexpected error in CreateNdDescOp. "
+                     "The shape information is missing.\n");
+
+  auto indexTy = IndexType::get(getContext());
+  for (int64_t dim : memrefTy.getShape())
+    result.push_back(IntegerAttr::get(indexTy, dim));
+  return result;
+}
+
+/// Static strides of the source. For a memref source these are the memref's
+/// strides. Otherwise it is a list of ShapedType::kDynamic whose length is the
+/// rank of the produced TensorDesc.
+llvm::ArrayRef<int64_t> CreateNdDescOp::getStaticStrides() {
+  // Both results are interned in the MLIRContext so the returned ArrayRef
+  // stays valid after this call. Returning a view of the local vector from
+  // getStridesAndOffset would dangle, and a function-local static would be
+  // sized by whichever op called first.
+  if (auto srcTy = llvm::dyn_cast_if_present<MemRefType>(getSourceType())) {
+    auto [strides, offset] = getStridesAndOffset(srcTy);
+    return DenseI64ArrayAttr::get(getContext(), strides).asArrayRef();
+  }
+
+  auto rank = getTensorDescType().getRank();
+  llvm::SmallVector<int64_t> dyn(rank, ShapedType::kDynamic);
+  return DenseI64ArrayAttr::get(getContext(), dyn).asArrayRef();
+}
+
+/// Strides of the source as a mixed list. Dynamic stride operands are used
+/// when present; otherwise the constant strides of a statically shaped memref
+/// are used. If neither is available the function hits llvm_unreachable.
+llvm::SmallVector<OpFoldResult> CreateNdDescOp::getStrides() {
+  llvm::SmallVector<OpFoldResult> result;
+
+  if (auto dynStrides = getDynamicStrides(); !dynStrides.empty()) {
+    result.append(dynStrides.begin(), dynStrides.end());
+    return result;
+  }
+
+  auto memrefTy = llvm::dyn_cast_if_present<MemRefType>(getSourceType());
+  if (!memrefTy || !memrefTy.hasStaticShape())
+    llvm_unreachable("Unexpected error in CreateNdDescOp. The strides "
+                     "information is missing.\n");
+
+  auto indexTy = IndexType::get(getContext());
+  auto [staticStrides, offset] = getStridesAndOffset(memrefTy);
+  for (int64_t stride : staticStrides)
+    result.push_back(IntegerAttr::get(indexTy, stride));
+  return result;
+}
+
+/// Element type of the TensorDesc produced by this op.
+Type CreateNdDescOp::getElementType() {
+  xegpu::TensorDescType descTy = getTensorDescType();
+  return descTy.getElementType();
+}
+
+/// Shape of the TensorDesc produced by this op.
+llvm::ArrayRef<int64_t> CreateNdDescOp::getTensorDescShape() {
+  xegpu::TensorDescType descTy = getTensorDescType();
+  return descTy.getShape();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadNDOp
+//===----------------------------------------------------------------------===//
+
+/// Parse the custom form of LoadNDOp:
+///   %tdesc {attrs}? : tdescTy -> valueTy
+/// The single operand is resolved against tdescTy, and valueTy becomes the
+/// result type.
+ParseResult LoadNDOp::parse(OpAsmParser &parser, OperationState &result) {
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand> Operands(1);
+ llvm::SMLoc OperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperand(Operands[0]))
+ return failure();
+
+ // Location of the attribute dictionary, used for inherent-attr diagnostics.
+ auto loc = parser.getCurrentLocation();
+ if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+ return failure();
+
+ if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+ return parser.emitError(loc)
+ << "'" << result.name.getStringRef() << "' op ";
+ })))
+ return failure();
+
+ if (parser.parseColon())
+ return failure();
+
+ // Type of the tensor descriptor operand.
+ llvm::SmallVector<Type> Types(1);
+ if (parser.parseType(Types[0]))
+ return failure();
+
+ if (parser.parseArrow())
+ return failure();
+
+ // Type of the loaded value (the op result).
+ llvm::SmallVector<Type> valueTypes(1);
+ if (parser.parseType(valueTypes[0]))
+ return failure();
+
+ result.addTypes(valueTypes);
+ if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands))
+ return failure();
+
+ return success();
+}
+
+/// Print `%tdesc {attrs}? : tdescTy -> valueTy`. `mode` is elided when it is
+/// SIMT, unless default values are requested.
+void LoadNDOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getTensorDesc();
+
+  llvm::SmallVector<llvm::StringRef> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+
+  printer << " : " << getTensorDesc().getType() << " -> "
+          << getValue().getType();
+}
+
+/// Verify a LoadNDOp on a 2D TensorDesc.
+///
+/// The result must be a vector with the TensorDesc element type. Its shape
+/// must equal the TensorDesc shape after applying these steps in order:
+///   1. SIMT distribution through the sub-group mapping (SIMT mode only);
+///   2. the optional `transpose` permutation;
+///   3. the optional VNNI split of dimension `vnni_axis`. The factor is the
+///      last result dimension and is appended as a new innermost dimension;
+///   4. a leading `array_length` dimension when array_length > 1.
+/// An out-of-range `transpose` is ignored with a warning. An invalid
+/// `vnni_axis`, or a VNNI factor that does not evenly divide the split
+/// dimension, is an error.
+LogicalResult LoadNDOp::verify() {
+  auto tdescTy = getTensorDescType();
+  auto valueTy = getValueType();
+
+  if (tdescTy.getRank() != 2)
+    return emitOpError(
+        "The TensorDesc for LoadNDOp should be a 2D TensorDesc.");
+
+  if (!valueTy)
+    return emitOpError("Invalid result, it should be a VectorType.\n");
+
+  auto tdescElemTy = tdescTy.getElementType();
+  auto valueElemTy = valueTy.getElementType();
+
+  if (tdescElemTy != valueElemTy)
+    return emitOpError(
+        "Value should have the same element type as TensorDesc.");
+
+  auto mode = getMode();
+  auto tdescShape = tdescTy.getShape().vec();
+  auto valueShape = valueTy.getShape().vec();
+  auto array_len = tdescTy.getArrayLength();
+
+  if (mode == ModeKind::SIMT) {
+    auto sgMap = tdescTy.getMapping();
+    if (!sgMap) {
+      return emitOpError(
+          "Expecting SgMap attribute for SIMT mode operators.\n");
+    }
+
+    if (!verifyAndInferShape(tdescShape, sgMap)) {
+      return emitOpError("Failed to infer the shape.")
+             << "The new shape[i] should meet the following condistions "
+                "for SubGroupMapAttr: "
+             << "\n\ttdescShape[i] % mma_block_size[i] == 0 (if it has) && "
+             << "\n\ttdescShape[i] % wi_layout[i] == 0 && "
+             << "\n\ttdescShape[i] % wi_data[i] == 0 && "
+             << "\n\t(tdescShape[i] % (wi_layout[i] * wi_data[i]) == 0 || "
+             << "\n\t (wi_layout[i] * wi_data[i]) % tdescShape[i] == 0).\n";
+    }
+  }
+
+  if (getTranspose()) {
+    auto trans = getTranspose().value();
+    // Every permutation entry must index into the shape; otherwise
+    // transpose() would read out of bounds.
+    auto rank = static_cast<int64_t>(tdescShape.size());
+    bool validTranspose =
+        tdescShape.size() >= trans.size() &&
+        llvm::all_of(trans, [&](int64_t d) { return d >= 0 && d < rank; });
+    if (validTranspose)
+      transpose(trans, tdescShape);
+    else
+      emitWarning("Invalid transpose attr. It is ignored.");
+  }
+
+  if (getVnniAxis()) {
+    auto axis = static_cast<int64_t>(getVnniAxis().value());
+    // Guard the tdescShape[axis] access below.
+    if (axis < 0 || axis >= static_cast<int64_t>(tdescShape.size()))
+      return emitOpError("Invalid vnni_axis ")
+             << axis << ", it should be less than the rank of the TensorDesc ("
+             << tdescShape.size() << ").";
+    // The VNNI factor is the last result dimension, so the result must have
+    // at least one dimension.
+    if (valueShape.empty())
+      return emitOpError("Expecting a non-0D result when vnni_axis is set.");
+    auto vnni_factor = valueShape.back();
+    // Floor division would silently drop elements, so the factor must divide
+    // the split dimension exactly.
+    if (vnni_factor <= 0 || tdescShape[axis] % vnni_factor != 0)
+      return emitOpError("The vnni factor (")
+             << vnni_factor << ") should evenly divide dimension " << axis
+             << " of the TensorDesc (" << tdescShape[axis] << ").";
+    tdescShape[axis] /= vnni_factor;
+    tdescShape.push_back(vnni_factor);
+  }
+
+  if (array_len > 1) {
+    auto it = tdescShape.begin();
+    tdescShape.insert(it, array_len);
+  }
+
+  if (tdescShape != valueShape)
+    return emitOpError("Result shape doesn't match TensorDesc shape.")
+           << "\nThe expected shape is " << makeString(tdescShape) << "."
+           << "\nBut the given shape is " << makeString(valueShape) << "."
+           << "\nIn VC mode, when VNNI is not enabled, the result should have "
+           << "the same shape (or transposed shape if transpose is enabled) "
+           << "as TensorDesc; \nwhen VNNI is enabled, the result should have "
+           << "one more dimention than the TensorDesc, with last dimention "
+           << "having vnni factor, \nbut having same number of total data "
+           << "elements. The vnni factor are typically calculated as "
+           << "simd_lane_width / elementTypeBitWidth. \nFor element type "
+           << "having more than 32 bits, vnni shouldn't be used. \nIn SIMT "
+           << "mode, the shape is derived from the mapping attributes.\n";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreNDOp
+//===----------------------------------------------------------------------===//
+/// Parse the custom form of StoreNDOp:
+///   %value, %tdesc {attrs}? : valueTy, tdescTy
+/// The type list is matched positionally against the two operands. The op
+/// has no results.
+ParseResult StoreNDOp::parse(OpAsmParser &parser, OperationState &result) {
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand> Operands(2);
+ llvm::SMLoc OperandsLoc = parser.getCurrentLocation();
+ // parse value
+ if (parser.parseOperand(Operands[0]))
+ return failure();
+
+ if (parser.parseComma())
+ return failure();
+
+ // parse TensorDesc
+ if (parser.parseOperand(Operands[1]))
+ return failure();
+
+ // parse optional attributes
+ auto loc = parser.getCurrentLocation();
+ if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+ return failure();
+
+ if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+ return parser.emitError(loc)
+ << "'" << result.name.getStringRef() << "' op ";
+ })))
+ return failure();
+
+ if (parser.parseColon())
+ return failure();
+
+ // Types of (value, TensorDesc), in operand order.
+ llvm::SmallVector<Type> Types;
+ if (parser.parseTypeList(Types))
+ return failure();
+
+ if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands))
+ return failure();
+
+ return success();
+}
+
+/// Print `%value, %tdesc {attrs}? : valueTy, tdescTy`. `mode` is elided when
+/// it is SIMT, unless default values are requested.
+void StoreNDOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getValue() << ", " << getTensorDesc();
+
+  llvm::SmallVector<llvm::StringRef> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+
+  printer << " : " << getValue().getType() << ", "
+          << getTensorDesc().getType();
+}
+
+/// Verify a StoreNDOp on a 2D TensorDesc. The stored value must be a vector
+/// with the TensorDesc element type. Its shape must equal the TensorDesc shape
+/// in VC mode, or the per-work-item shape derived from the sub-group mapping
+/// in any other mode.
+LogicalResult StoreNDOp::verify() {
+  auto dstTy = getTensorDesc().getType();
+  auto valTy = llvm::dyn_cast<VectorType>(getValue().getType());
+
+  if (dstTy.getRank() != 2)
+    return emitOpError(
+        "The TensorDesc for StoreNdOp should be a 2D TensorDesc.");
+  if (!valTy)
+    return emitOpError("Invalid value operand, it should be a VectorType.\n");
+  if (dstTy.getElementType() != valTy.getElementType())
+    return emitOpError("The elem type of value (vector) shape doesn't match "
+                       "the elem type of memory (dst) shape.\n");
+
+  // VC mode: no mapping; the value covers the whole descriptor.
+  if (getMode() == ModeKind::VC) {
+    if (dstTy.getShape() != valTy.getShape())
+      return emitOpError("In VC mode, the value (vector) shape doesn't match "
+                         "the memory (dst) shape.\n");
+    return success();
+  }
+
+  // Otherwise the value must match the shape derived from the mapping.
+  SubGroupMapAttr sgMap = dstTy.getMapping();
+  if (!sgMap)
+    return emitOpError("Expecting SgMap attribute for SIMT mode operators.\n");
+
+  std::vector<int64_t> expectedShape = dstTy.getShape().vec();
+  if (!verifyAndInferShape(expectedShape, sgMap))
+    return emitOpError("Failed to infer the shape.")
+           << "The new shape[i] should meet the following condistions "
+              "for SubGroupMapAttr: "
+           << "\n\ttdescShape[i] % mma_block_size[i] == 0 (if it has) && "
+           << "\n\ttdescShape[i] % wi_layout[i] == 0 && "
+           << "\n\ttdescShape[i] % wi_data[i] == 0 && "
+           << "\n\t(tdescShape[i] % (wi_layout[i] * wi_data[i]) == 0 || "
+           << "\n\t (wi_layout[i] * wi_data[i]) % tdescShape[i] == 0).\n";
+
+  if (expectedShape != valTy.getShape().vec())
+    return emitOpError(
+        "In SIMT mode, the value (vector) shape doesn't match the memory"
+        "(dst) shape as derived according to the mapping rule.\n");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_PrefetchNDOp
+//===----------------------------------------------------------------------===//
+/// Parse the custom form of PrefetchNDOp:
+///   %tdesc {attrs}? : tdescTy
+/// The op has a single TensorDesc operand and no results.
+ParseResult PrefetchNDOp::parse(OpAsmParser &parser, OperationState &result) {
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand> TensorDescOperands(1);
+ llvm::SmallVector<Type> TensorDescTypes(1);
+ llvm::SMLoc TensorDescOperandsLoc;
+
+ TensorDescOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperand(TensorDescOperands[0]))
+ return failure();
+
+ // Location of the attribute dictionary, used for inherent-attr diagnostics.
+ auto loc = parser.getCurrentLocation();
+ if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+ return failure();
+
+ if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+ return parser.emitError(loc)
+ << "'" << result.name.getStringRef() << "' op ";
+ })))
+ return failure();
+
+ if (parser.parseColon())
+ return failure();
+
+ if (parser.parseType(TensorDescTypes[0]))
+ return failure();
+ if (parser.resolveOperands(TensorDescOperands, TensorDescTypes,
+ TensorDescOperandsLoc, result.operands))
+ return failure();
+ return success();
+}
+
+/// Print `%tdesc {attrs}? : tdescTy`. `mode` is elided when it is SIMT, unless
+/// default values are requested.
+void PrefetchNDOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getTensorDesc();
+
+  llvm::SmallVector<llvm::StringRef> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+
+  printer << " : " << getTensorDesc().getType();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_UpdateNDOffsetOp
+//===----------------------------------------------------------------------===//
+/// Parse the custom form of UpdateNDOffsetOp:
+///   %tdesc, ([offsets])? {attrs}? : tdescTy -> resultTy
+/// The comma after the descriptor is required even when no offsets are
+/// given. Offsets are resolved as `index` SSA values.
+ParseResult UpdateNDOffsetOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  llvm::SmallVector<OpAsmParser::UnresolvedOperand> TensorDescOperands(1);
+  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> offsetsOperands;
+  llvm::SmallVector<Type> TensorDescTypes(1);
+  llvm::SmallVector<Type> resultTypes(1);
+  llvm::SMLoc TensorDescOperandsLoc;
+  llvm::SMLoc offsetsOperandsLoc;
+
+  TensorDescOperandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperand(TensorDescOperands[0]))
+    return failure();
+  if (parser.parseComma())
+    return failure();
+
+  // parse offsets, e.g., [x, y]
+  if (succeeded(parser.parseOptionalLSquare())) {
+    offsetsOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(offsetsOperands))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+  }
+
+  // Capture the location *before* the attribute dictionary, as the other
+  // parsers in this file do, so inherent-attribute errors point at it.
+  auto loc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+    return failure();
+
+  if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
+    return failure();
+
+  if (parser.parseColon())
+    return failure();
+
+  if (parser.parseType(TensorDescTypes[0]))
+    return failure();
+  if (parser.parseArrow())
+    return failure();
+
+  if (parser.parseType(resultTypes[0]))
+    return failure();
+  result.addTypes(resultTypes);
+  if (parser.resolveOperands(TensorDescOperands, TensorDescTypes,
+                             TensorDescOperandsLoc, result.operands))
+    return failure();
+
+  Type indexType = parser.getBuilder().getIndexType();
+  if (parser.resolveOperands(offsetsOperands, indexType, offsetsOperandsLoc,
+                             result.operands))
+    return failure();
+  return success();
+}
+
+/// Print `%tdesc, ([offsets])? {attrs}? : tdescTy -> resultTy`. `mode` is
+/// elided when it is SIMT, unless default values are requested.
+void UpdateNDOffsetOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getTensorDesc() << ",";
+  auto offsets = getOffsets();
+  if (!offsets.empty())
+    printer << " [" << offsets << "]";
+
+  llvm::SmallVector<llvm::StringRef> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+
+  printer << " : " << getTensorDesc().getType() << " -> "
+          << getResult().getType();
+}
+
+/// The number of offsets must equal the rank of the tensor descriptor.
+LogicalResult UpdateNDOffsetOp::verify() {
+  int64_t rank = getTensorDesc().getType().getRank();
+  if (static_cast<int64_t>(getOffsets().size()) == rank)
+    return success();
+  return emitOpError("Invalid number of offsets.");
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_CreateDescOp
+//===----------------------------------------------------------------------===//
+/// Build a CreateDescOp in VC mode. The plain integer chunk size is stored as
+/// an i32 attribute.
+void CreateDescOp::build(OpBuilder &builder, OperationState &state,
+                         TensorDescType TensorDesc, Value source, Value offsets,
+                         uint32_t chunk_size_per_lane) {
+  state.addOperands(source);
+  state.addOperands(offsets);
+
+  auto &props = state.getOrAddProperties<Properties>();
+  props.chunk_size_per_lane =
+      builder.getIntegerAttr(builder.getIntegerType(32), chunk_size_per_lane);
+  props.mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC);
+
+  state.addTypes(TensorDesc);
+}
+
+/// Build a CreateDescOp in VC mode. A null \p chunk_size_per_lane leaves the
+/// property unset.
+void CreateDescOp::build(OpBuilder &builder, OperationState &state,
+                         TensorDescType TensorDesc, Value source, Value offsets,
+                         IntegerAttr chunk_size_per_lane) {
+  state.addOperands(source);
+  state.addOperands(offsets);
+
+  auto &props = state.getOrAddProperties<Properties>();
+  if (chunk_size_per_lane)
+    props.chunk_size_per_lane = chunk_size_per_lane;
+  props.mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC);
+
+  state.addTypes(TensorDesc);
+}
+
+/// Parse the custom form of CreateDescOp:
+///   %source, %offsets {attrs}? : sourceTy, offsetsTy -> tdescTy
+/// Both operand types are given explicitly, and tdescTy becomes the result
+/// type.
+ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) {
+ llvm::SmallVector<OpAsmParser::UnresolvedOperand> Operands(2);
+ llvm::SmallVector<Type> Types(2);
+ llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+ // parse the source operand
+ if (parser.parseOperand(Operands[0]))
+ return failure();
+
+ if (parser.parseComma())
+ return failure();
+
+ // parse the offset operand
+ if (parser.parseOperand(Operands[1]))
+ return failure();
+
+ // parse the optional attributes
+ auto loc = parser.getCurrentLocation();
+ if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+ return failure();
+
+ if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+ return parser.emitError(loc)
+ << "'" << result.name.getStringRef() << "' op ";
+ })))
+ return failure();
+
+ if (parser.parseColon())
+ return failure();
+
+ // Source type.
+ if (parser.parseType(Types[0]))
+ return failure();
+ if (parser.parseComma())
+ return failure();
+
+ // Offsets type.
+ if (parser.parseType(Types[1]))
+ return failure();
+ if (parser.parseArrow())
+ return failure();
+
+ // Result (tensor descriptor) type.
+ llvm::SmallVector<Type> TensorDescTypes(1);
+ if (parser.parseType(TensorDescTypes[0]))
+ return failure();
+
+ result.addTypes(TensorDescTypes);
+ if (parser.resolveOperands(Operands, Types, operandsLoc, result.operands))
+ return failure();
+ return success();
+}
+
+/// Print `%source, %offsets {attrs}? : sourceTy, offsetsTy -> tdescTy`.
+/// Unless default values are requested, `mode` is elided when it is SIMT and
+/// `chunk_size_per_lane` is elided when it is 1.
+void CreateDescOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getSource() << ", " << getOffsets();
+
+  llvm::SmallVector<llvm::StringRef> elided;
+  if (!printDefaultValues()) {
+    if (getMode() == xegpu::ModeKind::SIMT)
+      elided.push_back("mode");
+    if (getChunkSizePerLane() == 1)
+      elided.push_back("chunk_size_per_lane");
+  }
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+
+  printer << " : " << getSource().getType() << ", " << getOffsets().getType()
+          << " -> " << getTensorDesc().getType();
+}
+
+/// Verifies CreateDescOp:
+///  - only VC mode is supported, and the TensorDesc must not carry a mapping
+///    attribute;
+///  - the source must have rank <= 2, as reported by getRankOf();
+///  - the TensorDesc must carry ScatteredAttr;
+///  - the TensorDesc shape must equal the offsets shape, followed by
+///    chunk_size_per_lane when that is not 1. Vector offsets must be 1D; a
+///    non-vector offset contributes no dimension.
+LogicalResult CreateDescOp::verify() {
+  auto tdescTy = getTensorDesc().getType();
+  auto chunkSize = getChunkSizePerLane();
+
+  if (getMode() == ModeKind::SIMT || tdescTy.getMapping()) {
+    return emitOpError("CreateDescOp only support VC mode and mapping "
+                       "attribute of TensorDesc is not expected.\n");
+  }
+
+  if (getRankOf(getSource()) > 2)
+    return emitOpError(
+        "Expecting the source is a 1D/2D memref or pointer (uint64_t).");
+
+  if (!tdescTy.getScattered())
+    return emitOpError(
+        "Expecting the presence of ScatteredAttr for tensor descriptor.");
+
+  // Infer the expected TensorDesc shape. A single dyn_cast replaces the
+  // isa + dyn_cast pair; SmallVector keeps these short shapes off the heap.
+  llvm::SmallVector<int64_t> expectedShape;
+  if (auto offsetVecTy = llvm::dyn_cast<VectorType>(getOffsets().getType())) {
+    if (offsetVecTy.getRank() != 1)
+      return emitOpError("Expecting the offset is a 1D vector.");
+    auto offsetShape = offsetVecTy.getShape();
+    expectedShape.assign(offsetShape.begin(), offsetShape.end());
+  }
+
+  if (chunkSize != 1)
+    expectedShape.push_back(static_cast<int64_t>(chunkSize));
+
+  // Compare as ArrayRefs instead of materializing a std::vector copy.
+  if (llvm::ArrayRef<int64_t>(expectedShape) != tdescTy.getShape()) {
+    return emitOpError("Expecting dimensions of offsets is the same as the "
+                       "tensor descriptor, or one less than.");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadGatherOp
+//===----------------------------------------------------------------------===//
+/// Builds a VC-mode LoadGatherOp that yields `value` from `TensorDesc` under
+/// `mask`. Null optional attributes (vnni_axis, transpose, cache hints) are
+/// left unset in the properties.
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value,
+                         Value TensorDesc, Value mask, IntegerAttr vnni_axis,
+                         DenseI64ArrayAttr transpose, CacheKindAttr l1_hint,
+                         CacheKindAttr l2_hint, CacheKindAttr l3_hint) {
+  state.addOperands({TensorDesc, mask});
+
+  auto &props = state.getOrAddProperties<Properties>();
+  if (vnni_axis)
+    props.vnni_axis = vnni_axis;
+  if (transpose)
+    props.transpose = transpose;
+  if (l1_hint)
+    props.l1_hint = l1_hint;
+  if (l2_hint)
+    props.l2_hint = l2_hint;
+  if (l3_hint)
+    props.l3_hint = l3_hint;
+  props.mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC);
+
+  state.addTypes(value);
+}
+
+/// Builds a VC-mode LoadGatherOp from plain CacheKind values. All three
+/// cache hints are always materialized as attributes; vnni_axis and
+/// transpose are set only when non-null.
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value,
+                         Value TensorDesc, Value mask, IntegerAttr vnni_axis,
+                         DenseI64ArrayAttr transpose, CacheKind l1_hint,
+                         CacheKind l2_hint, CacheKind l3_hint) {
+  auto *ctx = builder.getContext();
+  state.addOperands({TensorDesc, mask});
+
+  auto &props = state.getOrAddProperties<Properties>();
+  if (vnni_axis)
+    props.vnni_axis = vnni_axis;
+  if (transpose)
+    props.transpose = transpose;
+  props.l1_hint = CacheKindAttr::get(ctx, l1_hint);
+  props.l2_hint = CacheKindAttr::get(ctx, l2_hint);
+  props.l3_hint = CacheKindAttr::get(ctx, l3_hint);
+  props.mode = ModeKindAttr::get(ctx, ModeKind::VC);
+
+  state.addTypes(value);
+}
+
+/// Parses LoadGatherOp:
+///   %tdesc, %mask {attr-dict} : tdesc-type, mask-type -> value-type
+/// The attribute dictionary is checked against the op's inherent attributes.
+ParseResult LoadGatherOp::parse(OpAsmParser &parser, OperationState &result) {
+  llvm::SmallVector<OpAsmParser::UnresolvedOperand> operandList(2);
+  llvm::SmallVector<Type> operandTypes(2);
+
+  // `%tdesc, %mask`
+  llvm::SMLoc startLoc = parser.getCurrentLocation();
+  if (parser.parseOperand(operandList[0]) || parser.parseComma() ||
+      parser.parseOperand(operandList[1]))
+    return failure();
+
+  llvm::SMLoc attrLoc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+    return failure();
+  auto emitAttrError = [&]() {
+    return parser.emitError(attrLoc)
+           << "'" << result.name.getStringRef() << "' op ";
+  };
+  if (failed(verifyInherentAttrs(result.name, result.attributes, emitAttrError)))
+    return failure();
+
+  // `: tdesc-type, mask-type -> value-type`
+  Type valueType;
+  if (parser.parseColon() || parser.parseType(operandTypes[0]) ||
+      parser.parseComma() || parser.parseType(operandTypes[1]) ||
+      parser.parseArrow() || parser.parseType(valueType))
+    return failure();
+
+  result.addTypes(valueType);
+  return parser.resolveOperands(operandList, operandTypes, startLoc,
+                                result.operands);
+}
+
+/// Prints LoadGatherOp in the form accepted by LoadGatherOp::parse. `mode`
+/// is elided when it is SIMT, unless printDefaultValues() is set.
+void LoadGatherOp::print(OpAsmPrinter &printer) {
+  llvm::SmallVector<llvm::StringRef> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+
+  printer << ' ' << getTensorDesc() << ", " << getMask();
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+  printer << " : " << getTensorDesc().getType() << ", "
+          << getMask().getType() << " -> " << getValue().getType();
+}
+
+/// Verifies LoadGatherOp:
+///  - the TensorDesc must carry ScatteredAttr;
+///  - the value and the TensorDesc must have the same element type;
+///  - the mask shape must equal the TensorDesc shape (a scalar counts as [1]);
+///  - only VC mode is supported, and the TensorDesc must not have a mapping;
+///  - the value shape must equal the TensorDesc shape after the optional
+///    transpose and the optional VNNI split along `vnni_axis`. The VNNI factor
+///    is taken from the last dimension of the value.
+LogicalResult LoadGatherOp::verify() {
+  auto tdescTy = getTensorDesc().getType();
+  auto maskTy = getMask().getType();
+  auto valueTy = getValue().getType();
+
+  if (!tdescTy.getScattered())
+    return emitOpError(
+        "LoadGatherOp only works on TensorDesc with ScatteredAttr.");
+
+  // Element type of a scalar, vector, or TensorDesc type.
+  auto getElementType = [](Type type) -> Type {
+    if (type.isIntOrIndexOrFloat())
+      return type;
+    if (auto vecTy = llvm::dyn_cast<VectorType>(type))
+      return vecTy.getElementType();
+    if (auto descTy = llvm::dyn_cast<TensorDescType>(type))
+      return descTy.getElementType();
+    llvm_unreachable("Unsupported type.");
+  };
+
+  if (getElementType(tdescTy) != getElementType(valueTy))
+    return emitOpError(
+        "Value should have the same element type as TensorDesc.");
+
+  // Shape of a scalar (treated as [1]) or vector type.
+  auto getShape = [](Type type) -> std::vector<int64_t> {
+    if (type.isIntOrIndexOrFloat())
+      return {1};
+    if (auto vecTy = llvm::dyn_cast<VectorType>(type))
+      return vecTy.getShape().vec();
+    llvm_unreachable("Unsupported type.");
+  };
+
+  std::vector<int64_t> maskShape = getShape(maskTy);
+  std::vector<int64_t> valueShape = getShape(valueTy);
+  std::vector<int64_t> tdescShape = tdescTy.getShape().vec();
+
+  if (tdescShape != maskShape)
+    return emitOpError("Mask should have the same shape as TensorDesc.");
+
+  if (getMode() == ModeKind::SIMT || tdescTy.getMapping()) {
+    return emitOpError("LoadGatherOp only supports VC mode and mapping "
+                       "attribute of TensorDesc is not expected.\n");
+  }
+
+  if (getTransposeAttr()) {
+    auto trans = getTranspose().value();
+    // Returning the InFlightDiagnostic here would convert it to failure() and
+    // reject the op, contradicting "It is ignored". Emit the warning and skip
+    // the transpose instead.
+    if (tdescShape.size() < trans.size())
+      emitWarning("Invalid transpose attr. It is ignored.");
+    else
+      transpose(trans, tdescShape);
+  }
+
+  if (auto vnniAxis = getVnniAxis()) {
+    auto axis = static_cast<int64_t>(*vnniAxis);
+    // Reject an axis outside the TensorDesc rank before indexing tdescShape.
+    if (axis < 0 || axis >= static_cast<int64_t>(tdescShape.size()))
+      return emitOpError("Invalid vnni_axis: ")
+             << axis << ". It must index a dimension of the TensorDesc.";
+    int64_t vnniFactor = valueShape.back();
+    tdescShape[axis] /= vnniFactor;
+    tdescShape.push_back(vnniFactor);
+  }
+
+  if (valueShape != tdescShape)
+    return emitOpError(
+        "Result shape doesn't match TensorDesc shape. When VNNI is not "
+        "enabled, the result should have the same shape (or transposed shape "
+        "if transpose is also enabled) as TensorDesc. When VNNI is enabled, "
+        "the result should have one more dimension than the TensorDesc, "
+        "with last dimension having vnni factor, but having same number of "
+        "total data elements. The vnni factor are typically calculated as "
+        "simd_lane_width/elementTypeBitWidth. For element type having "
+        "more than 32 bits, vnni shouldn't be used.\n");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreScatterOp
+//===----------------------------------------------------------------------===//
+/// Builds a VC-mode StoreScatterOp that writes `value` through `TensorDesc`
+/// under `mask`. Null cache-hint attributes are left unset.
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value TensorDesc, Value mask,
+                           CacheKindAttr l1_hint, CacheKindAttr l2_hint,
+                           CacheKindAttr l3_hint) {
+  state.addOperands({value, TensorDesc, mask});
+
+  auto &props = state.getOrAddProperties<Properties>();
+  if (l1_hint)
+    props.l1_hint = l1_hint;
+  if (l2_hint)
+    props.l2_hint = l2_hint;
+  if (l3_hint)
+    props.l3_hint = l3_hint;
+  props.mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC);
+}
+
+/// Builds a VC-mode StoreScatterOp from plain CacheKind values. All three
+/// cache hints are always materialized as attributes.
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value TensorDesc, Value mask,
+                           CacheKind l1_hint, CacheKind l2_hint,
+                           CacheKind l3_hint) {
+  auto *ctx = builder.getContext();
+  state.addOperands(value);
+  state.addOperands(TensorDesc);
+  state.addOperands(mask);
+
+  auto &props = state.getOrAddProperties<Properties>();
+  props.l1_hint = CacheKindAttr::get(ctx, l1_hint);
+  props.l2_hint = CacheKindAttr::get(ctx, l2_hint);
+  props.l3_hint = CacheKindAttr::get(ctx, l3_hint);
+  props.mode = ModeKindAttr::get(ctx, ModeKind::VC);
+}
+
+/// Parses StoreScatterOp:
+///   %value, %tdesc, %mask {attr-dict} : value-type, tdesc-type, mask-type
+/// Operands and types are parsed as comma-separated lists and paired up.
+ParseResult StoreScatterOp::parse(OpAsmParser &parser, OperationState &result) {
+  llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  llvm::SmallVector<Type> operandTypes;
+
+  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperandList(operands))
+    return failure();
+
+  llvm::SMLoc attrLoc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+    return failure();
+  auto emitAttrError = [&]() {
+    return parser.emitError(attrLoc)
+           << "'" << result.name.getStringRef() << "' op ";
+  };
+  if (failed(verifyInherentAttrs(result.name, result.attributes, emitAttrError)))
+    return failure();
+
+  if (parser.parseColon() || parser.parseTypeList(operandTypes))
+    return failure();
+
+  return parser.resolveOperands(operands, operandTypes, operandsLoc,
+                                result.operands);
+}
+
+/// Prints StoreScatterOp in the form accepted by StoreScatterOp::parse.
+/// `mode` is elided when it is SIMT, unless printDefaultValues() is set.
+void StoreScatterOp::print(OpAsmPrinter &printer) {
+  llvm::SmallVector<llvm::StringRef> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+
+  printer << ' ' << getValue() << ", " << getTensorDesc() << ", " << getMask();
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+  printer << " : " << getValue().getType() << ", "
+          << getTensorDesc().getType() << ", " << getMask().getType();
+}
+
+/// Verifies StoreScatterOp:
+///  - VC mode only, and the TensorDesc must not have a mapping attribute;
+///  - the TensorDesc must carry ScatteredAttr;
+///  - mask, value and TensorDesc must all have the same shape (a scalar counts
+///    as shape [1]).
+LogicalResult StoreScatterOp::verify() {
+  auto tdescTy = getTensorDesc().getType();
+
+  if (getMode() != ModeKind::VC || tdescTy.getMapping())
+    return emitOpError("StoreScatterOp only supports VC mode and mapping "
+                       "attribute of TensorDesc is not expected.\n");
+
+  if (!tdescTy.getScattered())
+    return emitOpError("Invalid TensorDesc. StoreScatterOp only works on "
+                       "TensorDescs with ScatteredAttr.");
+
+  // Scalars are treated as shape [1]; vectors use their static shape.
+  auto shapeOf = [](Type type) -> std::vector<int64_t> {
+    if (type.isIntOrIndexOrFloat())
+      return {1};
+    if (auto vecTy = llvm::dyn_cast<VectorType>(type))
+      return vecTy.getShape().vec();
+    llvm_unreachable("Unsupported type.");
+  };
+
+  std::vector<int64_t> maskShape = shapeOf(getMask().getType());
+  std::vector<int64_t> valueShape = shapeOf(getValue().getType());
+  std::vector<int64_t> tdescShape = tdescTy.getShape().vec();
+
+  if (valueShape != maskShape)
+    return emitOpError("Mask and value should have the same shape/size");
+
+  if (tdescShape == valueShape)
+    return success();
+
+  return emitOpError("TensorDesc shape and value shape doesn't match. ")
+         << "The expected/derived value shape is: " << makeString(tdescShape)
+         << ".\nMask and value should have the same shape/size as "
+            "TensorDesc.\n";
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_PrefetchOp
+//===----------------------------------------------------------------------===//
+/// Builds a VC-mode PrefetchOp on `TensorDesc`. Null cache-hint attributes
+/// are left unset.
+void PrefetchOp::build(OpBuilder &builder, OperationState &state,
+                       Value TensorDesc, CacheKindAttr l1_hint,
+                       CacheKindAttr l2_hint, CacheKindAttr l3_hint) {
+  state.addOperands(TensorDesc);
+
+  auto &props = state.getOrAddProperties<Properties>();
+  if (l1_hint)
+    props.l1_hint = l1_hint;
+  if (l2_hint)
+    props.l2_hint = l2_hint;
+  if (l3_hint)
+    props.l3_hint = l3_hint;
+  props.mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC);
+}
+
+/// Builds a VC-mode PrefetchOp from plain CacheKind values. All three cache
+/// hints are always materialized as attributes.
+void PrefetchOp::build(OpBuilder &builder, OperationState &state,
+                       Value TensorDesc, CacheKind l1_hint, CacheKind l2_hint,
+                       CacheKind l3_hint) {
+  auto *ctx = builder.getContext();
+  state.addOperands(TensorDesc);
+
+  auto &props = state.getOrAddProperties<Properties>();
+  props.l1_hint = CacheKindAttr::get(ctx, l1_hint);
+  props.l2_hint = CacheKindAttr::get(ctx, l2_hint);
+  props.l3_hint = CacheKindAttr::get(ctx, l3_hint);
+  props.mode = ModeKindAttr::get(ctx, ModeKind::VC);
+}
+
+/// Parses PrefetchOp:
+///   %tdesc {attr-dict} : tdesc-type
+/// The attribute dictionary is checked against the op's inherent attributes.
+ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand tdescOperand;
+  Type tdescType;
+
+  if (parser.parseOperand(tdescOperand))
+    return failure();
+
+  llvm::SMLoc attrLoc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+    return failure();
+  auto emitAttrError = [&]() {
+    return parser.emitError(attrLoc)
+           << "'" << result.name.getStringRef() << "' op ";
+  };
+  if (failed(verifyInherentAttrs(result.name, result.attributes, emitAttrError)))
+    return failure();
+
+  if (parser.parseColon() || parser.parseType(tdescType))
+    return failure();
+
+  // One operand paired with one type: resolve it directly.
+  return parser.resolveOperand(tdescOperand, tdescType, result.operands);
+}
+
+/// Prints PrefetchOp in the form accepted by PrefetchOp::parse. `mode` is
+/// elided when it is SIMT, unless printDefaultValues() is set.
+void PrefetchOp::print(OpAsmPrinter &printer) {
+  llvm::SmallVector<llvm::StringRef> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+
+  printer << ' ' << getTensorDesc();
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+  printer << " : " << getTensorDesc().getType();
+}
+
+/// Verifies PrefetchOp:
+///  - each cache hint, if present, must be CACHED, UNCACHED, STREAMING or
+///    READ_INVALIDATE;
+///  - the TensorDesc must carry ScatteredAttr;
+///  - VC mode only, and the TensorDesc must not have a mapping attribute.
+LogicalResult PrefetchOp::verify() {
+  auto mode = getMode();
+  auto tdescTy = getTensorDesc().getType();
+  auto mapping = tdescTy.getMapping();
+
+  // An absent hint is always valid; the lambda needs no captures.
+  auto isValidHint = [](CacheKindAttr attr) -> bool {
+    if (!attr)
+      return true;
+    auto kind = attr.getValue();
+    return kind == CacheKind::CACHED || kind == CacheKind::UNCACHED ||
+           kind == CacheKind::STREAMING || kind == CacheKind::READ_INVALIDATE;
+  };
+
+  if (!isValidHint(getL1HintAttr()))
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
+
+  if (!isValidHint(getL2HintAttr()))
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
+
+  if (!isValidHint(getL3HintAttr()))
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  if (!tdescTy.getScattered())
+    return emitOpError("Invalid TensorDesc. PrefetchOp only works on "
+                       "TensorDescs with ScatteredAttr.");
+
+  if (mode != ModeKind::VC || mapping) {
+    return emitOpError("PrefetchOp only supports VC mode, and mapping "
+                       "attribute of TensorDesc is not expected.\n");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_UpdateOffsetOp
+//===----------------------------------------------------------------------===//
+/// Builds a VC-mode UpdateOffsetOp that applies `offsets` to `TensorDesc`
+/// and yields a value of type `result`.
+void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
+                           Type result, Value TensorDesc, Value offsets) {
+  state.addOperands({TensorDesc, offsets});
+  state.getOrAddProperties<Properties>().mode =
+      xegpu::ModeKindAttr::get(builder.getContext(), xegpu::ModeKind::VC);
+  state.addTypes(result);
+}
+
+/// Parses UpdateOffsetOp:
+///   %tdesc, %offsets {attr-dict} : tdesc-type, offsets-type -> result-type
+/// Operands and their types are parsed as comma-separated lists.
+ParseResult UpdateOffsetOp::parse(OpAsmParser &parser, OperationState &result) {
+  llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  llvm::SmallVector<Type> operandTypes;
+
+  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperandList(operands))
+    return failure();
+
+  llvm::SMLoc attrLoc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+    return failure();
+  auto emitAttrError = [&]() {
+    return parser.emitError(attrLoc)
+           << "'" << result.name.getStringRef() << "' op ";
+  };
+  if (failed(verifyInherentAttrs(result.name, result.attributes, emitAttrError)))
+    return failure();
+
+  Type resultType;
+  if (parser.parseColon() || parser.parseTypeList(operandTypes) ||
+      parser.parseArrow() || parser.parseType(resultType))
+    return failure();
+  result.addTypes(resultType);
+
+  return parser.resolveOperands(operands, operandTypes, operandsLoc,
+                                result.operands);
+}
+
+/// Prints UpdateOffsetOp in the form accepted by UpdateOffsetOp::parse.
+/// `mode` is elided when it is SIMT, unless printDefaultValues() is set.
+void UpdateOffsetOp::print(OpAsmPrinter &printer) {
+  llvm::SmallVector<llvm::StringRef> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+
+  printer << ' ' << getTensorDesc() << ", " << getOffsets();
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+  printer << " : " << getTensorDesc().getType() << ", "
+          << getOffsets().getType() << " -> " << getResult().getType();
+}
+
+/// Verifies UpdateOffsetOp:
+///  - VC mode only;
+///  - the result type must equal the input TensorDesc type;
+///  - the TensorDesc must carry ScatteredAttr;
+///  - the offsets must be a 1D vector whose length equals dim 0 of the
+///    TensorDesc.
+LogicalResult UpdateOffsetOp::verify() {
+  auto mode = getMode();
+  if (mode != ModeKind::VC)
+    return emitOpError("UpdateOffsetOp only work on VC mode.\n");
+
+  auto srcTy = getTensorDesc().getType();
+  auto resTy = getResult().getType();
+  if (srcTy != resTy)
+    return emitOpError("The result should have the same type (shape and "
+                       "encoding attribute) as the input TensorDesc.");
+
+  if (!srcTy.getScattered()) {
+    return emitOpError("Invalid TensorDesc. UpdateOffsetOp only works on "
+                       "TensorDescs with ScatteredAttr.");
+  }
+
+  auto offTy = llvm::dyn_cast<VectorType>(getOffsets().getType());
+  if (!offTy || offTy.getRank() != 1)
+    return emitOpError("The offset should be an 1D vector.\n");
+
+  // A rank-0 TensorDesc has no dim 0; check before indexing shape[0].
+  auto shape = srcTy.getShape();
+  if (shape.empty() || shape[0] != offTy.getShape()[0])
+    return emitOpError(
+        "The offset should have same length as the dim-0 of TensorDesc.");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_DpasOp
+//===----------------------------------------------------------------------===//
+/// Parses DpasOp:
+///   %lhs, %rhs[, %acc] {attr-dict} : lhs-type, rhs-type[, acc-type] -> result-type
+/// Operands and their types are parsed as comma-separated lists.
+ParseResult DpasOp::parse(OpAsmParser &parser, OperationState &result) {
+  llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  llvm::SmallVector<Type> operandTypes;
+
+  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperandList(operands))
+    return failure();
+
+  llvm::SMLoc attrLoc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+    return failure();
+  auto emitAttrError = [&]() {
+    return parser.emitError(attrLoc)
+           << "'" << result.name.getStringRef() << "' op ";
+  };
+  if (failed(verifyInherentAttrs(result.name, result.attributes, emitAttrError)))
+    return failure();
+
+  Type resultType;
+  if (parser.parseColon() || parser.parseTypeList(operandTypes) ||
+      parser.parseArrow() || parser.parseType(resultType))
+    return failure();
+  result.addTypes(resultType);
+
+  return parser.resolveOperands(operands, operandTypes, operandsLoc,
+                                result.operands);
+}
+
+/// Prints DpasOp in the form accepted by DpasOp::parse. The accumulator and
+/// its type are printed only when present. `mode` is elided when it is SIMT,
+/// unless printDefaultValues() is set.
+void DpasOp::print(OpAsmPrinter &printer) {
+  llvm::SmallVector<llvm::StringRef, 2> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+
+  Value acc = getAcc();
+  printer << ' ' << getLhs() << ", " << getRhs();
+  if (acc)
+    printer << ", " << acc;
+
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+  printer << " : " << getLhs().getType() << ", " << getRhs().getType();
+  if (acc)
+    printer << ", " << acc.getType();
+  printer << " -> " << getResult().getType();
+}
+
+/// Verifies DpasOp:
+///  - lhs and rhs must share the element type;
+///  - if an accumulator is present, its type must equal the result type;
+///  - lhs and rhs must both be rank 3.
+LogicalResult DpasOp::verify() {
+  auto lhsTy = getLhsType();
+  auto rhsTy = getRhsType();
+
+  if (lhsTy.getElementType() != rhsTy.getElementType())
+    return emitOpError("lhs and rhs element type does not match for dpas op");
+
+  bool accTypeMismatch = getAcc() && getAccType() != getResultType();
+  if (accTypeMismatch)
+    return emitOpError("Accumulator and Result for dpas op should have the "
+                       "same type (both shape and element type).");
+
+  // Equivalent to "ranks differ or lhs rank is not 3".
+  if (lhsTy.getRank() != 3 || rhsTy.getRank() != 3)
+    return emitOpError(
+        "lhs and rhs rank does not match for dpas op, or their rank is not 3.");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_InvokeSIMDOp
+//===----------------------------------------------------------------------===//
+/// Builds an InvokeSIMDOp that calls `callee` with `operands` and produces
+/// `results`. `argType` and `callee` are stored as discardable attributes.
+void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state,
+                         SymbolRefAttr callee, TypeRange results,
+                         ArgTypeKindAttr argType, ValueRange operands) {
+  state.addOperands(operands);
+  state.addTypes(results);
+  state.addAttribute("argType", argType);
+  state.addAttribute("callee", callee);
+}
+
+/// Convenience builder: wraps the callee name in a SymbolRefAttr and
+/// forwards to the SymbolRefAttr overload.
+void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state,
+                         StringAttr callee, TypeRange results,
+                         ArgTypeKindAttr argType, ValueRange operands) {
+  auto calleeRef = SymbolRefAttr::get(callee);
+  build(builder, state, calleeRef, results, argType, operands);
+}
+
+/// Convenience builder: interns `callee` as a StringAttr and forwards to the
+/// StringAttr overload.
+void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state,
+                         llvm::StringRef callee, TypeRange results,
+                         ArgTypeKindAttr argType, ValueRange operands) {
+  StringAttr calleeName = StringAttr::get(builder.getContext(), callee);
+  build(builder, state, calleeName, results, argType, operands);
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_AtomicRMWOp
+//===----------------------------------------------------------------------===//
+/// Builds a VC-mode AtomicRMWOp of the given `kind` on `tensorDesc` under
+/// `mask`. `value` is appended as a third operand only when non-null.
+void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result,
+                        AtomicRMWKindAttr kind, Value tensorDesc, Value mask,
+                        Value value) {
+  state.addOperands({tensorDesc, mask});
+  if (value)
+    state.addOperands(value);
+
+  auto &props = state.getOrAddProperties<Properties>();
+  props.kind = kind;
+  props.mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC);
+
+  state.addTypes(result);
+}
+
+/// Builds a VC-mode AtomicRMWOp from a plain AtomicRMWKind value. `value` is
+/// appended as a third operand only when non-null.
+void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result,
+                        AtomicRMWKind kind, Value tensorDesc, Value mask,
+                        Value value) {
+  auto *ctx = builder.getContext();
+  state.addOperands({tensorDesc, mask});
+  if (value)
+    state.addOperands(value);
+
+  auto &props = state.getOrAddProperties<Properties>();
+  props.kind = AtomicRMWKindAttr::get(ctx, kind);
+  props.mode = ModeKindAttr::get(ctx, ModeKind::VC);
+
+  state.addTypes(result);
+}
+
+/// Parses AtomicRMWOp:
+///   kind %tdesc, %mask[, %value] {attr-dict} : operand-types -> result-type
+/// The kind goes through parseCustomAttributeWithFallback, so both the bare
+/// and the fully qualified attribute spellings are accepted.
+ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) {
+  xegpu::AtomicRMWKindAttr kindAttr;
+  if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}))
+    return failure();
+  if (kindAttr)
+    result.getOrAddProperties<AtomicRMWOp::Properties>().kind = kindAttr;
+
+  llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands;
+  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperandList(operands))
+    return failure();
+
+  llvm::SMLoc attrLoc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+    return failure();
+  auto emitAttrError = [&]() {
+    return parser.emitError(attrLoc)
+           << "'" << result.name.getStringRef() << "' op ";
+  };
+  if (failed(verifyInherentAttrs(result.name, result.attributes, emitAttrError)))
+    return failure();
+
+  llvm::SmallVector<Type, 1> operandTypes;
+  Type resultType;
+  if (parser.parseColon() || parser.parseTypeList(operandTypes) ||
+      parser.parseArrow() || parser.parseCustomTypeWithFallback(resultType))
+    return failure();
+  result.addTypes(resultType);
+
+  return parser.resolveOperands(operands, operandTypes, operandsLoc,
+                                result.operands);
+}
+
+/// Prints AtomicRMWOp in the form accepted by AtomicRMWOp::parse. The kind
+/// is printed positionally (stripped), so `kind` is always removed from the
+/// attribute dictionary. `mode` is elided when it is SIMT, unless
+/// printDefaultValues() is set.
+void AtomicRMWOp::print(OpAsmPrinter &printer) {
+  llvm::SmallVector<llvm::StringRef, 2> elided{"kind"};
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+
+  printer.printStrippedAttrOrType(getKindAttr());
+  printer << ' ' << getTensorDesc() << ", " << getMask();
+  if (Value value = getValue())
+    printer << ", " << value;
+
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+  printer << " : ";
+  printer << getOperation()->getOperandTypes();
+  printer << " -> " << getResult().getType();
+}
+
+/// Verifies AtomicRMWOp: only VC mode is accepted.
+LogicalResult AtomicRMWOp::verify() {
+  if (getMode() == ModeKind::VC)
+    return success();
+  return emitOpError("AtomicRMWOp only work on VC mode.\n");
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_CreateNbarrierOp
+//===----------------------------------------------------------------------===//
+/// Parses CreateNbarrierOp:
+///   %id, %role {attr-dict} : (id-type, role-type) -> result-type
+/// The operand types are enclosed in parentheses.
+ParseResult CreateNbarrierOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+  llvm::SmallVector<Type> operandTypes;
+
+  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperandList(operands))
+    return failure();
+
+  llvm::SMLoc attrLoc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+    return failure();
+  auto emitAttrError = [&]() {
+    return parser.emitError(attrLoc)
+           << "'" << result.name.getStringRef() << "' op ";
+  };
+  if (failed(verifyInherentAttrs(result.name, result.attributes, emitAttrError)))
+    return failure();
+
+  // `: (type, type) -> result-type`
+  Type resultType;
+  if (parser.parseColon() || parser.parseLParen() ||
+      parser.parseTypeList(operandTypes) || parser.parseRParen() ||
+      parser.parseArrow() || parser.parseType(resultType))
+    return failure();
+
+  result.addTypes(resultType);
+  return parser.resolveOperands(operands, operandTypes, operandsLoc,
+                                result.operands);
+}
+
+/// Prints CreateNbarrierOp in the form accepted by CreateNbarrierOp::parse.
+/// `mode` is elided when it is SIMT, unless printDefaultValues() is set.
+void CreateNbarrierOp::print(OpAsmPrinter &printer) {
+  llvm::SmallVector<llvm::StringRef, 2> elided;
+  if (!printDefaultValues() && getMode() == xegpu::ModeKind::SIMT)
+    elided.push_back("mode");
+
+  printer << ' ' << getNbarrierId() << ", " << getNbarrierRole();
+  printer.printOptionalAttrDict((*this)->getAttrs(), elided);
+  printer << " : (" << getNbarrierId().getType() << ", "
+          << getNbarrierRole().getType() << ") -> " << getResult().getType();
+}
+
+} // namespace xegpu
+} // namespace mlir
+
+#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
+#define GET_OP_CLASSES
+#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
diff --git a/mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir
new file mode 100644
index 00000000000000..64a6f547fbd29d
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// Round-trips create_nd_tdesc in VC mode, once with SSA-value offsets and once
+// with constant offsets written inline.
+// NOTE: despite their names, %c0 and %c1 hold the constants 2 and 4.
+// CHECK-LABEL: func @test_create_nd_tdesc_vc({{.*}}) {
+func.func @test_create_nd_tdesc_vc(%src: memref<24x32xf32>) {
+ %c0 = arith.constant 2 : index
+ %c1 = arith.constant 4 : index
+
+ // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc}
+ : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %src[2, 4] {mode = vc}
+ : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ return
+}
+
+// Scattered create_tdesc from a ui64 base with 16 per-lane offsets and
+// chunk_size_per_lane = 2, giving a 16x2 descriptor in slm memory scope.
+// CHECK-LABEL: func @test_create_tdesc_vc({{.*}}) {
+func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) {
+ // CHECK: xegpu.create_tdesc {{.*}} {chunk_size_per_lane = 2 : i64, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr<memory_scope = slm, #xegpu.scattered>>
+ %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}
+ : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr<memory_scope = slm, #xegpu.scattered>>
+ return
+}
+
+// load_nd with vnni_axis = 0 and cache hints: an 8x16xf16 descriptor is loaded
+// as vector<4x16x2xf16>.
+// CHECK-LABEL: func @test_load_nd_vc({{.*}}) {
+func.func @test_load_nd_vc(%src: memref<24x32xf16>, %x : index, %y : index) {
+ // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc}
+ : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+ // CHECK: xegpu.load_nd {{.*}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>, vnni_axis = 0 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
+ %2 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
+ return
+}
+
+// Loads an 8x16xf16 block through one descriptor, then stores it through a
+// second descriptor; both operations carry cache hints.
+// CHECK-LABEL: func @test_store_nd_vc({{.*}}) {
+func.func @test_store_nd_vc(%src: memref<24x32xf16>, %dst: memref<24x32xf16>) {
+ %c0 = arith.constant 2 : index
+ %c1 = arith.constant 4 : index
+
+ // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc}
+ : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+ // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc}
+ : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+ // CHECK: xegpu.load_nd {{.*}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1 {mode=vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+
+ // CHECK: xegpu.store_nd {{%[0-9], %[0-9]}} {l1_hint = #xegpu<cache_kind write_back>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+ xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+ return
+}
+
+// dpas on two rank-3 f16 operands without an accumulator, producing
+// vector<8x16xf32>.
+// CHECK-LABEL: func @test_dpas_vc({{.*}}) {
+func.func @test_dpas_vc(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
+ // CHECK: xegpu.dpas {{.*}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+ %1 = xegpu.dpas %a, %b {mode = vc}: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+ return
+}
+
+// Creates a 2D descriptor, loads through it, then applies update_nd_offset to
+// it; the result type equals the input descriptor type.
+// CHECK-LABEL: func @test_update_nd_offset_vc({{.*}}) {
+func.func @test_update_nd_offset_vc(%src: memref<24x32xf32>) {
+ %c0 = arith.constant 2 : index
+ %c1 = arith.constant 4 : index
+
+ // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc}
+ : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: xegpu.load_nd {{%[0-9]}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ %2 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+
+ // CHECK: xegpu.update_nd_offset {{%[0-9]}}, [{{%c[0-9], %c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %3 = xegpu.update_nd_offset %1, [%c0, %c1] {mode = vc}: !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ return
+}
+
+// prefetch_nd with L1/L2 cache hints on a 2D f16 descriptor.
+// CHECK-LABEL: func @test_prefetch_nd_vc({{.*}}) {
+func.func @test_prefetch_nd_vc(%src: memref<24x32xf16>, %x : index, %y : index) {
+ // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+ // CHECK: xegpu.prefetch_nd {{%[0-9]}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>
+ xegpu.prefetch_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir b/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir
new file mode 100644
index 00000000000000..f80df161a543ac
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @test_atomic_rmw({{.*}}) {
+func.func @test_atomic_rmw(%base: ui64, %offs : vector<16xindex>, %val : vector<16xf32>, %pred : vector<16xi1>) {
+ // Scattered descriptor over 16 offsets with the default chunk size.
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc}
+ : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+
+ // Kind spelled with the fully qualified #xegpu<atomic_rmw_kind ...> attribute.
+ // CHECK: xegpu.atomic_rmw
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16xf32>
+ xegpu.atomic_rmw #xegpu<atomic_rmw_kind addf> %desc, %pred, %val {mode = vc}
+ : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_atomic_rmw_0({{.*}}) {
+func.func @test_atomic_rmw_0(%base: ui64, %offs : vector<16xindex>, %val : vector<16x2xf32>, %pred : vector<16xi1>) {
+ // chunk_size_per_lane = 2 gives a 16x2 scattered descriptor.
+ %desc = xegpu.create_tdesc %base, %offs {chunk_size_per_lane = 2, mode = vc} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>
+
+ // Kind spelled with the bare keyword form (mulf).
+ // CHECK: xegpu.atomic_rmw
+ // CHECK-SAME: tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32>
+ xegpu.atomic_rmw mulf %desc, %pred, %val {mode = vc} : !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_atomic_rmw_1({{.*}}) {
+func.func @test_atomic_rmw_1(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xi32>, %mask : vector<16xi1>) {
+ %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 2, mode=vc}
+ : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>
+
+ // Integer AND on i32 data: the result keeps the i32 element type of the
+ // descriptor and the value operand.
+ // CHECK: xegpu.atomic_rmw
+ // CHECK-SAME: !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xi32>
+ xegpu.atomic_rmw andi %1, %mask, %value {mode=vc}
+ : !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xi32>
+
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir b/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir
new file mode 100644
index 00000000000000..0f7229a02aa180
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @test_atomic_rmw({{.*}}) {
+func.func @test_atomic_rmw(%base: ui64, %offs : vector<16xindex>, %val : vector<16x1xf32>, %pred : vector<16xi1>) {
+ // 1-D scattered descriptor; the value operand has a trailing unit dimension.
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc}
+ : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+
+ // CHECK: xegpu.atomic_rmw
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32>
+ xegpu.atomic_rmw addf %desc, %pred, %val {mode = vc}
+ : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32> -> vector<16x1xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_atomic_rmw_0({{.*}}) {
+func.func @test_atomic_rmw_0(%base: ui64, %offs : vector<16xindex>, %val : vector<16x2xf32>, %pred : vector<16xi1>) {
+ // chunk_size_per_lane = 2 gives a 16x2 scattered descriptor.
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc, chunk_size_per_lane = 2}
+ : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>
+
+ // CHECK: xegpu.atomic_rmw
+ // CHECK-SAME: !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32>
+ xegpu.atomic_rmw mulf %desc, %pred, %val {mode = vc}
+ : !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_atomic_rmw_1({{.*}}) {
+func.func @test_atomic_rmw_1(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xi32>, %mask : vector<16xi1>) {
+ %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>
+
+ // Integer AND on i32 data: the result keeps the i32 element type of the
+ // descriptor and the value operand.
+ // CHECK: xegpu.atomic_rmw
+ // CHECK-SAME: !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xi32>
+ xegpu.atomic_rmw andi %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xi32>
+
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/barrier_ops.mlir b/mlir/test/Dialect/XeGPU/IR/barrier_ops.mlir
new file mode 100644
index 00000000000000..a1abc9e171bcaf
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/barrier_ops.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @alloc_nbarrier({{.*}}) {
+func.func @alloc_nbarrier() {
+ // Round-trip of the allocation op with the literal operand 8.
+ // CHECK: xegpu.alloc_nbarrier
+ xegpu.alloc_nbarrier 8
+
+ return
+}
+
+// CHECK-LABEL: func @create_nbarrier({{.*}}) {
+func.func @create_nbarrier() {
+ %id = arith.constant 1 : i8
+ %role = arith.constant 0 : i8
+ // The attributes are written in the same sorted order the expected output uses.
+ // CHECK: xegpu.create_nbarrier
+ // CHECK-SAME: {num_consumers = 32 : i8, num_producers = 32 : i8}
+ // CHECK-SAME: (i8, i8) -> !xegpu.nbarrier
+ %nb = xegpu.create_nbarrier %id, %role {num_consumers = 32 : i8, num_producers = 32 : i8}
+ : (i8, i8) -> !xegpu.nbarrier
+ return
+}
+
+// CHECK-LABEL: func @nbarrier_arrive({{.*}}) {
+func.func @nbarrier_arrive(%nb : !xegpu.nbarrier) {
+ // Arrive on a barrier handle received as a function argument.
+ // CHECK: xegpu.nbarrier_arrive
+ // CHECK-SAME: !xegpu.nbarrier
+ xegpu.nbarrier_arrive %nb : !xegpu.nbarrier
+ return
+}
+
+// CHECK-LABEL: func @nbarrier_wait({{.*}}) {
+func.func @nbarrier_wait(%nb : !xegpu.nbarrier) {
+ // Wait on a barrier handle received as a function argument.
+ // CHECK: xegpu.nbarrier_wait
+ // CHECK-SAME: !xegpu.nbarrier
+ xegpu.nbarrier_wait %nb : !xegpu.nbarrier
+ return
+}
+
+// CHECK-LABEL: func @compile_hint({{.*}}) {
+func.func @compile_hint() {
+ // Operand-free, attribute-free op; only its name is matched.
+ // CHECK: xegpu.compile_hint
+ xegpu.compile_hint
+
+ return
+}
+
+// CHECK-LABEL: func @mfence({{.*}}) {
+func.func @mfence() {
+ // Attributes written in the same sorted order as the expected output.
+ // CHECK: xegpu.mfence {fence_op = "none", fence_scope = "local", memory_kind = "ugm"}
+ xegpu.mfence {fence_op = "none", fence_scope = "local", memory_kind = "ugm"}
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir b/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir
new file mode 100644
index 00000000000000..cebf59f12939da
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir
@@ -0,0 +1,111 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// Subgroup map (wi_layout = [2, 8], wi_data = [1, 2]) attached to every
+// descriptor in this file. The expected output spells it out inline as
+// #xegpu.sg_map<...> rather than by alias name.
+#sg_map_fp16 = #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>
+
+// CHECK-LABEL: func @test_create_nd_tdesc_0({{.*}}) {
+func.func @test_create_nd_tdesc_0(%src: memref<24x32xf16>) {
+ // SSA names match the values the constants hold.
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+
+ // Offsets passed as SSA values.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4]
+ : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
+
+ // The same offsets given as integer literals.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %2 = xegpu.create_nd_tdesc %src[2, 4]
+ : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
+
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_1({{.*}}) {
+func.func @test_create_nd_tdesc_1(%mem: memref<24x32xf16>, %row : index, %col : index) {
+ // Offsets come straight from the function arguments.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_2({{.*}}) {
+func.func @test_create_nd_tdesc_2(%base: ui64, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Raw ui64 base with explicit shape [height, width] and strides [width, 1].
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %desc = xegpu.create_nd_tdesc %base[%row, %col], [%height, %width], [%width, %one]
+ : ui64 -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_3({{.*}}) {
+func.func @test_create_nd_tdesc_3(%mem: memref<?x?xf16>, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Dynamically shaped memref with explicit shape [height, width] and strides [width, 1].
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col], [%height, %width], [%width, %one]
+ : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
+ return
+}
+
+
+// CHECK-LABEL: func @test_create_nd_tdesc_4({{.*}}) {
+func.func @test_create_nd_tdesc_4(%mem: memref<?x?xf16>, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Same operands and result type as test_create_nd_tdesc_3.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col], [%height, %width], [%width, %one] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_5({{.*}}) {
+func.func @test_create_nd_tdesc_5(%mem: memref<?x?xf16>, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // memory_scope = slm, with the sg_map nested inside #xegpu.tdesc_attr.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = <wi_layout = [2, 8], wi_data = [1, 2]>>>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col], [%height, %width], [%width, %one] : memref<?x?xf16>
+ -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = #sg_map_fp16>>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_6({{.*}}) {
+func.func @test_create_nd_tdesc_6(%mem: memref<?x?xf16>, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Same operands and result type as test_create_nd_tdesc_5.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = <wi_layout = [2, 8], wi_data = [1, 2]>>>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col], [%height, %width], [%width, %one] : memref<?x?xf16>
+ -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = #sg_map_fp16>>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_7({{.*}}) {
+func.func @test_create_nd_tdesc_7(%mem: memref<1024xf16>, %off : index) {
+ // 1-D descriptor over a static 1-D memref.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<1024xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %desc = xegpu.create_nd_tdesc %mem[%off]
+ : memref<1024xf16> -> !xegpu.tensor_desc<16xf16, #sg_map_fp16>
+ return
+}
+
+
+// CHECK-LABEL: func @test_create_nd_tdesc_8({{.*}}) {
+func.func @test_create_nd_tdesc_8(%mem: memref<?x?xf16>, %width : index, %height : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Mixed offsets: literal 8 in the first dimension, an SSA value in the second.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = <wi_layout = [2, 8], wi_data = [1, 2]>>>
+ %desc = xegpu.create_nd_tdesc %mem[8, %col], [%height, %width], [%width, %one] : memref<?x?xf16>
+ -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = #sg_map_fp16>>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}}) {
+func.func @test_create_nd_tdesc_9(%mem: memref<?x?xf16>, %width : index, %height : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Larger 64x128 block with the same mixed offsets and slm scope.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<64x128xf16, #xegpu.tdesc_attr<memory_scope = slm, map = <wi_layout = [2, 8], wi_data = [1, 2]>>>
+ %desc = xegpu.create_nd_tdesc %mem[8, %col], [%height, %width], [%width, %one]
+ : memref<?x?xf16> -> !xegpu.tensor_desc<64x128xf16, #xegpu.tdesc_attr<memory_scope = slm, map = #sg_map_fp16>>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir b/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir
new file mode 100644
index 00000000000000..a21bf792fe0792
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir
@@ -0,0 +1,115 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// ----- SIMD -----
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_0({{.*}}) {
+func.func @test_create_nd_tdesc_vc_0(%src: memref<24x32xf32>) {
+ // SSA names match the values the constants hold.
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+
+ // Offsets passed as SSA values.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4] {mode = vc}
+ : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // The same offsets given as integer literals.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %src[2, 4] {mode = vc}
+ : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_1({{.*}}) {
+func.func @test_create_nd_tdesc_vc_1(%mem: memref<24x32xf32>, %row : index, %col : index) {
+ // Offsets come straight from the function arguments.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: %arg0[%arg1, %arg2]
+ // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_2({{.*}}) {
+func.func @test_create_nd_tdesc_vc_2(%base: ui64, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Raw ui64 base with explicit shape [height, width] and strides [width, 1].
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xf32>
+ %desc = xegpu.create_nd_tdesc %base[%row, %col], [%height, %width], [%width, %one] {mode = vc}
+ : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_3({{.*}}) {
+func.func @test_create_nd_tdesc_vc_3(%mem: memref<?x?xf32>, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Dynamically shaped memref with explicit shape and strides.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col], [%height, %width], [%width, %one] {mode = vc}
+ : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
+
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_4({{.*}}) {
+func.func @test_create_nd_tdesc_vc_4(%mem: memref<?x?xf32>, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Same operands and result type as test_create_nd_tdesc_vc_3.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col], [%height, %width], [%width, %one] {mode = vc}
+ : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_5({{.*}}) {
+func.func @test_create_nd_tdesc_vc_5(%mem: memref<?x?xf32>, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Descriptor tagged with memory_scope = slm via #xegpu.tdesc_attr.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col], [%height, %width], [%width, %one] {mode = vc} : memref<?x?xf32>
+ -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_6({{.*}}) {
+func.func @test_create_nd_tdesc_vc_6(%mem: memref<?x?xf32>, %width : index, %height : index, %row : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Same operands and result type as test_create_nd_tdesc_vc_5.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
+ %desc = xegpu.create_nd_tdesc %mem[%row, %col], [%height, %width], [%width, %one] {mode = vc} : memref<?x?xf32>
+ -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
+ return
+}
+
+
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_7({{.*}}) {
+func.func @test_create_nd_tdesc_vc_7(%mem: memref<1024xf32>, %off : index) {
+ // 1-D descriptor over a static 1-D memref.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
+ %desc = xegpu.create_nd_tdesc %mem[%off] {mode = vc}
+ : memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
+ return
+}
+
+
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_8({{.*}}) {
+func.func @test_create_nd_tdesc_vc_8(%mem: memref<?x?xf32>, %width : index, %height : index, %col : index) {
+ %one = arith.constant 1 : index
+ // Mixed offsets: literal 8 in the first dimension, an SSA value in the second.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
+ %desc = xegpu.create_nd_tdesc %mem[8, %col], [%height, %width], [%width, %one] {mode = vc} : memref<?x?xf32>
+ -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
+ return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_vc_9({{.*}}) {
+func.func @test_create_nd_tdesc_vc_9(%mem: memref<8x32xf32>) {
+ // 8x16 block at the origin of an 8x32 memref, with array_length = 2.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm, array_length = 2>>
+ %desc = xegpu.create_nd_tdesc %mem[0, 0] {mode = vc}
+ : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm, array_length = 2>>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir b/mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir
new file mode 100644
index 00000000000000..8fb5ac824ddb27
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @test_create_tdesc_vc({{.*}}) {
+func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) {
+ // Verify the printed form of the op, not only that the function exists.
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %1 = xegpu.create_tdesc %src, %offsets {mode=vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/IR/create_tdesc_vc.mlir b/mlir/test/Dialect/XeGPU/IR/create_tdesc_vc.mlir
new file mode 100644
index 00000000000000..245d862e302a7c
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/create_tdesc_vc.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+
+// CHECK-LABEL: func @test_create_tdesc_vc({{.*}}) {
+func.func @test_create_tdesc_vc(%base: ui64, %offs : vector<16xindex>) {
+ // Default chunk size: a 16xf32 descriptor for 16 offsets.
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc}
+ : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ return
+}
+
+// CHECK-LABEL: func @test_create_tdesc_vc_2({{.*}}) {
+func.func @test_create_tdesc_vc_2(%base: ui64, %offs : vector<16xindex>) {
+ // Scattered marker nested inside #xegpu.tdesc_attr with memory_scope = slm.
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<memory_scope = slm, #xegpu.scattered>>
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc}
+ : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<memory_scope = slm, #xegpu.scattered>>
+ return
+}
+
+// CHECK-LABEL: func @test_create_tdesc_vc_3({{.*}}) {
+func.func @test_create_tdesc_vc_3(%base: ui64, %offs : vector<16xindex>) {
+ // chunk_size_per_lane = 8 gives a 16x8 descriptor.
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 8 : i64, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc, chunk_size_per_lane = 8} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>
+ return
+}
+
+// CHECK-LABEL: func @test_create_tdesc_vc_4({{.*}}) {
+func.func @test_create_tdesc_vc_4(%base: ui64, %offs : vector<16xindex>) {
+ // chunk_size_per_lane = 2 combined with memory_scope = slm.
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 2 : i64, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr<memory_scope = slm, #xegpu.scattered>>
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc, chunk_size_per_lane = 2} : ui64, vector<16xindex>
+ -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr<memory_scope = slm, #xegpu.scattered>>
+ return
+}
+
+
+// CHECK-LABEL: func @test_create_tdesc_vc_5({{.*}}) {
+func.func @test_create_tdesc_vc_5(%mem: memref<?xf32>, %offs : vector<16xindex>) {
+ // Same as the slm chunked case, but with a dynamic 1-D memref as the source.
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 2 : i64, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr<memory_scope = slm, #xegpu.scattered>>
+ %desc = xegpu.create_tdesc %mem, %offs {mode = vc, chunk_size_per_lane = 2} : memref<?xf32>, vector<16xindex>
+ -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr<memory_scope = slm, #xegpu.scattered>>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/invalid_vc.mlir b/mlir/test/Dialect/XeGPU/IR/invalid_vc.mlir
new file mode 100644
index 00000000000000..4a92fa77c5815e
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/invalid_vc.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
+
+// -----
+func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+
+ // Two offsets and a 2-D descriptor against a 1-D memref.
+ // expected-error@+1 {{Expecting the rank of shape, strides, offsets and memref type should match with each other}}
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4] {mode = vc} : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
+// -----
+func.func @test_create_nd_tdesc_vc_3(%input: memref<?xf32>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+
+ // 2-D offsets, shape and strides against a 1-D memref.
+ // expected-error@+1 {{Expecting the rank of shape, strides, offsets and memref type should match with each other}}
+ %1 = xegpu.create_nd_tdesc %input[%c2, %c4], [%c8, %c16], [%c16, %c4] {mode = vc} : memref<?xf32> -> !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
+
+// -----
+func.func @test_create_nd_tdesc_vc_4(%input: memref<?x?xf32>) {
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+
+ // 1-D offsets, shape and strides against a 2-D memref.
+ // expected-error@+1 {{Expecting the rank of shape, strides, offsets and memref type should match with each other}}
+ %1 = xegpu.create_nd_tdesc %input[%c2], [%c8], [%c2] {mode = vc}
+ : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
+// -----
+func.func @test_create_nd_tdesc_vc_5(%input: memref<24x32x64xf32>) {
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+
+ // A 3-D memref source is rejected by the operand type constraint.
+ // expected-error@+1 {{operand #0 must be 1D/2D memref}}
+ %1 = xegpu.create_nd_tdesc %input[%c2, %c2, %c8] {mode = vc}
+ : memref<24x32x64xf32> -> !xegpu.tensor_desc<8x16x8xf32>
+ return
+}
+
+// -----
+func.func @test_create_tdesc(%src: ui64, %offsets : vector<16x8xindex>) {
+ // A 2-D (16x8) offsets vector is rejected; the constraint requires rank 1.
+ // expected-error@+1 {{operand #1 must be vector of index values of ranks 1}}
+ %1 = xegpu.create_tdesc %src, %offsets {mode = vc}
+ : ui64, vector<16x8xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>
+ return
+}
+
+// -----
+func.func @test_load_gather(%src: ui64, %offsets : vector<16xindex>) {
+ %0 = arith.constant dense<1>: vector<16x8xi1>
+ // A well-formed 16x8 f16 scattered descriptor; only the load below carries
+ // an expected diagnostic.
+ %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 8}
+ : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scattered>
+
+ // The 3-D result vector<8x8x4xf16> does not match the 16x8 descriptor shape.
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ %2 = xegpu.load %1, %0 {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached}
+ : !xegpu.tensor_desc<16x8xf16, #xegpu.scattered>, vector<16x8xi1> -> vector<8x8x4xf16>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/load_gather_vc.mlir b/mlir/test/Dialect/XeGPU/IR/load_gather_vc.mlir
new file mode 100644
index 00000000000000..a3cb890483e634
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/load_gather_vc.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+
+// CHECK-LABEL: func @test_load_gather_vc({{.*}}) {
+func.func @test_load_gather_vc(%base: ui64, %offs : vector<16xindex>) {
+ // All-true mask.
+ %all = arith.constant dense<1>: vector<16xi1>
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc}
+ : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+
+ // Masked gather-load with explicit cache hints.
+ // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+ %vals = xegpu.load %desc, %all {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_load_gather_vc_2({{.*}}) {
+func.func @test_load_gather_vc_2(%base: ui64, %offs : vector<16xindex>) {
+ // All-true 16x8 mask, one bit per element of the chunked descriptor.
+ %all = arith.constant dense<1>: vector<16x8xi1>
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 8 : i64, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc, chunk_size_per_lane = 8} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>
+
+ // Transposed load: the 16x8 descriptor is returned as an 8x16 vector.
+ // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>, transpose = array<i64: 1, 0>}
+ // CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32>
+ %vals = xegpu.load %desc, %all {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached}
+ : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_load_gather_vc_3({{.*}}) {
+func.func @test_load_gather_vc_3(%base: ui64, %offs : vector<16xindex>) {
+ // All-true mask.
+ %all = arith.constant dense<1>: vector<16xi1>
+ // chunk_size_per_lane = 1 is given explicitly; the expected output below
+ // contains no chunk_size_per_lane entry.
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %desc = xegpu.create_tdesc %base, %offs {mode = vc, chunk_size_per_lane = 1} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+
+ // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+ %vals = xegpu.load %desc, %all {mode = vc, l1_hint = cached, l2_hint = uncached}
+ : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/load_nd.mlir b/mlir/test/Dialect/XeGPU/IR/load_nd.mlir
new file mode 100644
index 00000000000000..0644565c3f002e
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/load_nd.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+#sg_map_fp16_a = #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>
+#sg_map_fp16_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#sg_map_fp16_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#sg_map_fp16_d = #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>
+// CHECK-LABEL: func @test_load_nd_fp16({{.*}}) {
+func.func @test_load_nd_fp16(%A: memref<24x32xf16>, %B : memref<24x32xf16>, %C : memref<24x32xf32>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // Each load yields one work item's fragment: tile shape / wi_layout, split along vnni_axis when set.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xf16>
+ // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %1 = xegpu.create_nd_tdesc %A[%c2, %c4]
+ : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 1 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>> -> vector<4x1x2xf16>
+ %2 = xegpu.load_nd %1 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> -> vector<4x1x2xf16>
+
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xf16>
+ // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %3 = xegpu.create_nd_tdesc %B[%c2, %c4]
+ : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 0 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1x2xf16>
+ %4 = xegpu.load_nd %3 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> -> vector<8x1x2xf16>
+ // The f32 accumulator tile is described over an f32 memref so the element types agree.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %5 = xegpu.create_nd_tdesc %C[%c2, %c4]
+ : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}} : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> -> vector<8x1xf32>
+
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xf16>
+ // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %7 = xegpu.create_nd_tdesc %A[%c2, %c4]
+ : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_d>
+ // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 1 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>> -> vector<4x1x2xf16>
+ %8 = xegpu.load_nd %7 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_d> -> vector<4x1x2xf16>
+
+ return
+}
+
+#sg_map_bf16_a = #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>
+#sg_map_bf16_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#sg_map_bf16_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+// CHECK-LABEL: func @test_load_nd_bf16({{.*}}) {
+func.func @test_load_nd_bf16(%A: memref<24x32xbf16>, %B : memref<24x32xbf16>, %C : memref<24x32xf32>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // Each load yields one work item's fragment: tile shape / wi_layout, split along vnni_axis when set.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xbf16>
+ // CHECK-SAME: -> !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %1 = xegpu.create_nd_tdesc %A[%c2, %c4] : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 1 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>> -> vector<4x1x2xbf16>
+ %2 = xegpu.load_nd %1 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a> -> vector<4x1x2xbf16>
+
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xbf16>
+ // CHECK-SAME: -> !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %3 = xegpu.create_nd_tdesc %B[%c2, %c4] : memref<24x32xbf16> -> !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 0 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1x2xbf16>
+ %4 = xegpu.load_nd %3 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b> -> vector<8x1x2xbf16>
+ // The f32 accumulator tile uses this test's own C layout and an f32 memref so element types agree.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %5 = xegpu.create_nd_tdesc %C[%c2, %c4] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #sg_map_bf16_c>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}} : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf32, #sg_map_bf16_c> -> vector<8x1xf32>
+
+ return
+}
+
+#sg_map_i8_a = #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 4]>
+#sg_map_i8_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#sg_map_i8_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+// CHECK-LABEL: func @test_load_nd_i8({{.*}}) {
+func.func @test_load_nd_i8(%A: memref<64x64xi8>, %B : memref<64x64xi8>, %C : memref<64x64xi32>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // i8 fragments are split by 4 along vnni_axis: 8x32 / [2, 8] = 4x4 -> 4x1x4, 32x16 / [1, 16] = 32x1 -> 8x1x4.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<64x64xi8>
+ // CHECK-SAME: -> !xegpu.tensor_desc<8x32xi8, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 4]>>
+ %1 = xegpu.create_nd_tdesc %A[%c2, %c4] : memref<64x64xi8> -> !xegpu.tensor_desc<8x32xi8, #sg_map_i8_a>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 1 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<8x32xi8, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 4]>> -> vector<4x1x4xi8>
+ %2 = xegpu.load_nd %1 {vnni_axis = 1} : !xegpu.tensor_desc<8x32xi8, #sg_map_i8_a> -> vector<4x1x4xi8>
+
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<64x64xi8>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %3 = xegpu.create_nd_tdesc %B[%c2, %c4] : memref<64x64xi8> -> !xegpu.tensor_desc<32x16xi8, #sg_map_i8_b>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 0 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1x4xi8>
+ %4 = xegpu.load_nd %3 {vnni_axis = 0} : !xegpu.tensor_desc<32x16xi8, #sg_map_i8_b> -> vector<8x1x4xi8>
+ // The i32 accumulator tile is described over an i32 memref so the element types agree.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<64x64xi32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<8x16xi32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %5 = xegpu.create_nd_tdesc %C[%c2, %c4] : memref<64x64xi32> -> !xegpu.tensor_desc<8x16xi32, #sg_map_i8_c>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xi32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xi32>
+ %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xi32, #sg_map_i8_c> -> vector<8x1xi32>
+
+ return
+}
+
+#sg_map_f64_a = #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>
+#sg_map_f64_b = #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>
+#sg_map_f64_c = #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>
+// CHECK-LABEL: func @test_load_nd_f64({{.*}}) {
+func.func @test_load_nd_f64(%A: memref<64x64xf64>, %B : memref<64x64xf64>, %C : memref<64x64xf64>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // f64 uses wi_data = [1, 1] and no vnni_axis: 4x8 / [2, 8] = 2x1 and 8x8 / [2, 8] = 4x1 per work item.
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<64x64xf64>
+ // CHECK-SAME: -> !xegpu.tensor_desc<4x8xf64, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %A[%c2, %c4]
+ : memref<64x64xf64> -> !xegpu.tensor_desc<4x8xf64, #sg_map_f64_a>
+
+ // CHECK: xegpu.load_nd
+ // CHECK-SAME: !xegpu.tensor_desc<4x8xf64, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>>
+ // CHECK-SAME: -> vector<2x1xf64>
+ %2 = xegpu.load_nd %1 : !xegpu.tensor_desc<4x8xf64, #sg_map_f64_a> -> vector<2x1xf64>
+
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<64x64xf64>
+ // CHECK-SAME: -> !xegpu.tensor_desc<8x8xf64, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>>
+ %3 = xegpu.create_nd_tdesc %B[%c2, %c4]
+ : memref<64x64xf64> -> !xegpu.tensor_desc<8x8xf64, #sg_map_f64_b>
+
+ // CHECK: xegpu.load_nd
+ // CHECK-SAME: !xegpu.tensor_desc<8x8xf64, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>>
+ // CHECK-SAME: -> vector<4x1xf64>
+ %4 = xegpu.load_nd %3 : !xegpu.tensor_desc<8x8xf64, #sg_map_f64_b> -> vector<4x1xf64>
+
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<64x64xf64>
+ // CHECK-SAME: -> !xegpu.tensor_desc<4x8xf64, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>>
+ %5 = xegpu.create_nd_tdesc %C[%c2, %c4]
+ : memref<64x64xf64> -> !xegpu.tensor_desc<4x8xf64, #sg_map_f64_c>
+
+ // CHECK: xegpu.load_nd
+ // CHECK-SAME: !xegpu.tensor_desc<4x8xf64, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>>
+ // CHECK-SAME: -> vector<2x1xf64>
+ %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<4x8xf64, #sg_map_f64_c> -> vector<2x1xf64>
+
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir b/mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir
new file mode 100644
index 00000000000000..78980b551c0677
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// --- SIMD (vc mode) ---
+// CHECK-LABEL: func @test_load_nd_simd_f32({{.*}}) {
+func.func @test_load_nd_simd_f32(%src: memref<24x32xf32>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // In vc mode a load returns the whole 8x16 tile; transpose = [1, 0] returns it as 16x8.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}]
+ // CHECK-SAME: {mode = #xegpu<mode_kind vc>} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4] {mode = vc}
+ : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}}
+ // CHECK-SAME: {mode = #xegpu<mode_kind vc>} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ %2 = xegpu.load_nd %1 {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, l3_hint = #xegpu<cache_kind streaming>, mode = #xegpu<mode_kind vc>, transpose = array<i64: 1, 0>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
+ %3 = xegpu.load_nd %1 {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint = streaming} : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_load_nd_simd_f16({{.*}}) {
+func.func @test_load_nd_simd_f16(%src: memref<24x32xf16>, %row : index, %col : index) {
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}]
+ // CHECK-SAME: {mode = #xegpu<mode_kind vc>} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %tdesc = xegpu.create_nd_tdesc %src[%row, %col] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+ // vnni_axis = 0 splits the 8 rows into pairs: the 8x16xf16 tile comes back as 4x16x2xf16.
+ // CHECK: xegpu.load_nd %{{[0-9]+}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>, vnni_axis = 0 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
+ %packed = xegpu.load_nd %tdesc {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
+ return
+}
+
+// CHECK-LABEL: func @test_load_nd_simd_bf16({{.*}}) {
+func.func @test_load_nd_simd_bf16(%src: ui64, %width : index, %height : index, %row : index, %col : index) {
+ %unit_stride = arith.constant 1 : index
+ // A raw ui64 base has no type-level layout, so the extra operand lists (presumably shape [h, w], strides [w, 1]) are explicit.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}]
+ // CHECK-SAME: {mode = #xegpu<mode_kind vc>} : ui64 -> !xegpu.tensor_desc<8x16xbf16>
+ %tdesc = xegpu.create_nd_tdesc %src[%row, %col], [%height, %width], [%width, %unit_stride] {mode = vc} : ui64 -> !xegpu.tensor_desc<8x16xbf16>
+ // CHECK: xegpu.load_nd %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>, vnni_axis = 1 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16>
+ %packed = xegpu.load_nd %tdesc {mode = vc, vnni_axis = 1, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16>
+ return
+}
+
+// CHECK-LABEL: func @test_load_nd_block_array_simd_f16({{.*}}) {
+func.func @test_load_nd_block_array_simd_f16(%src: memref<8x32xf16>) {
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[0, 0] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<array_length = 2>>
+ %blocks_desc = xegpu.create_nd_tdesc %src[0, 0] {mode = vc}
+ : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<array_length = 2>>
+ // array_length = 2 asks for two 8x16 blocks at once, so the result gains a leading dimension of 2.
+ // CHECK: xegpu.load_nd %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<array_length = 2>> -> vector<2x8x16xf16>
+ %blocks = xegpu.load_nd %blocks_desc {mode = vc, l1_hint = cached, l2_hint = uncached}
+ : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<array_length = 2>> -> vector<2x8x16xf16>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir b/mlir/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir
new file mode 100644
index 00000000000000..6e2cb4de4ce1d4
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+// CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_0({{.*}}) {
+func.func @test_prefetch_nd_tdesc_vc_0(%src: memref<24x32xf32>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // Prefetch without cache hints: the printed attribute dictionary holds only the mode.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: xegpu.prefetch_nd %{{[0-9]}} {mode = #xegpu<mode_kind vc>} : !xegpu.tensor_desc<8x16xf32>
+ xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<8x16xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_1({{.*}}) {
+func.func @test_prefetch_nd_tdesc_vc_1(%src: memref<24x32xf16>, %row : index, %col : index) {
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}]
+ // CHECK-SAME: {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %tdesc = xegpu.create_nd_tdesc %src[%row, %col] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
+ // Explicit L1/L2 hints are expected back, key-sorted, in the printed attribute dictionary.
+ // CHECK: xegpu.prefetch_nd %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>
+ xegpu.prefetch_nd %tdesc {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16>
+ return
+}
+
+
+// CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_i8({{.*}}) {
+func.func @test_prefetch_nd_tdesc_vc_i8(%src: memref<24x32xi8>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // Same hint-free prefetch as the f32 case, on an i8 tile.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4] {mode = vc} : memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
+
+ // CHECK: xegpu.prefetch_nd %{{[0-9]}} {mode = #xegpu<mode_kind vc>} : !xegpu.tensor_desc<8x16xi8>
+ xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<8x16xi8>
+
+ return
+}
+
+// CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_bf16({{.*}}) {
+func.func @test_prefetch_nd_tdesc_vc_bf16(%src: memref<24x32xbf16>, %row : index, %col : index) {
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}]
+ // CHECK-SAME: {mode = #xegpu<mode_kind vc>} : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+ %tdesc = xegpu.create_nd_tdesc %src[%row, %col] {mode = vc} : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+ // Exercises the reversed hint pair: L1 uncached, L2 cached.
+ // CHECK: xegpu.prefetch_nd %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind uncached>, l2_hint = #xegpu<cache_kind cached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16>
+ xegpu.prefetch_nd %tdesc {mode = vc, l1_hint = uncached, l2_hint = cached} : !xegpu.tensor_desc<8x16xbf16>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir b/mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir
new file mode 100644
index 00000000000000..ff6f31c77064af
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// ---- BF16 ------
+
+#sg_map_bf16_a = #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>
+#sg_map_bf16_b = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#sg_map_bf16_c = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+// CHECK-LABEL: func @test_gemm_bf16({{.*}}) {
+func.func @test_gemm_bf16(%a : memref<1024x1024xbf16>, %b: memref<1024x1024xbf16>, %c: memref<1024x1024xf32>) {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c1024 = arith.constant 1024 : index
+
+ // Per-work-item fragments: A 4x1x2 and B 8x1x2 (vnni-packed bf16), accumulator 8x1xf32; K steps by 16.
+ scf.for %i = %c0 to %c1024 step %c8 {
+ scf.for %j = %c0 to %c1024 step %c16 {
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<1024x1024xbf16>
+ // CHECK-SAME: -> !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %1 = xegpu.create_nd_tdesc %a[%i, %c0] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a>
+
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<1024x1024xbf16>
+ // CHECK-SAME: -> !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %b[%c0, %j] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b>
+
+ %3 = arith.constant dense<0.0> : vector<8x1xf32>
+
+ %tmp0, %tmp1, %result = scf.for %k = %c0 to %c1024 step %c16 iter_args(%subA = %1, %subB = %2, %subC = %3)
+ -> (!xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a>, !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b>, vector<8x1xf32>) {
+ // CHECK: xegpu.load_nd
+ // CHECK-SAME: vector<4x1x2xbf16>
+ %4 = xegpu.load_nd %subA {vnni_axis = 1} : !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a> -> vector<4x1x2xbf16>
+
+ // CHECK: xegpu.load_nd
+ // CHECK-SAME: vector<8x1x2xbf16>
+ %5 = xegpu.load_nd %subB {vnni_axis = 0} : !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b> -> vector<8x1x2xbf16>
+
+ // CHECK: xegpu.dpas
+ // CHECK-SAME: vector<4x1x2xbf16>, vector<8x1x2xbf16>, vector<8x1xf32> -> vector<8x1xf32>
+ %6 = xegpu.dpas %4, %5, %subC : vector<4x1x2xbf16>, vector<8x1x2xbf16>, vector<8x1xf32> -> vector<8x1xf32>
+
+ // CHECK: xegpu.update_nd_offset
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %7 = xegpu.update_nd_offset %subA, [%c0, %c16] : !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a>
+ -> !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a>
+
+ // CHECK: xegpu.update_nd_offset
+ // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %8 = xegpu.update_nd_offset %subB, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b>
+ -> !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b>
+
+ scf.yield %7, %8, %6: !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a>, !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b>, vector<8x1xf32>
+ }
+
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<1024x1024xf32>
+ %9 = xegpu.create_nd_tdesc %c[%i, %j] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #sg_map_bf16_c>
+
+ // CHECK: xegpu.store_nd
+ // CHECK-SAME: vector<8x1xf32>
+ xegpu.store_nd %result, %9 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32, #sg_map_bf16_c>
+ }
+ }
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir b/mlir/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir
new file mode 100644
index 00000000000000..794a6b6f1afb9c
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// ---- BF16 VC ------
+
+// CHECK-LABEL: func @test_gemm_vc_bf16({{.*}}) {
+func.func @test_gemm_vc_bf16(%a : memref<1024x1024xbf16>, %b: memref<1024x1024xbf16>, %c: memref<1024x1024xf32>) {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c1024 = arith.constant 1024 : index
+
+ // vc mode: ops see whole tiles (A 8x8x2 and B 8x16x2 after vnni packing, accumulator 8x16xf32); K steps by 16.
+ scf.for %i = %c0 to %c1024 step %c8 {
+ scf.for %j = %c0 to %c1024 step %c16 {
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+ %1 = xegpu.create_nd_tdesc %a[%i, %c0] {mode = vc} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+ %2 = xegpu.create_nd_tdesc %b[%c0, %j] {mode = vc} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+
+ %3 = arith.constant dense<0.0> : vector<8x16xf32>
+
+ %tmp0, %tmp1, %result = scf.for %k = %c0 to %c1024 step %c16
+ iter_args(%subA = %1, %subB = %2, %subC = %3)
+ -> (!xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>, vector<8x16xf32>) {
+ // CHECK: xegpu.load_nd
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16>
+ %4 = xegpu.load_nd %subA {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16>
+
+ // CHECK: xegpu.load_nd
+ // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16> -> vector<8x16x2xbf16>
+ %5 = xegpu.load_nd %subB {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xbf16> -> vector<8x16x2xbf16>
+
+ // CHECK: xegpu.dpas
+ // CHECK-SAME: vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %6 = xegpu.dpas %4, %5, %subC {mode = vc} : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+
+ // CHECK: xegpu.update_nd_offset
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16>
+ %7 = xegpu.update_nd_offset %subA, [%c0, %c16] {mode = vc} : !xegpu.tensor_desc<8x16xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+
+ // CHECK: xegpu.update_nd_offset
+ // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16>
+ %8 = xegpu.update_nd_offset %subB, [%c16, %c0] {mode = vc} : !xegpu.tensor_desc<16x16xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+
+ scf.yield %7, %8, %6: !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>, vector<8x16xf32>
+ }
+
+ // CHECK: xegpu.create_nd_tdesc
+ // CHECK-SAME: memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %9 = xegpu.create_nd_tdesc %c[%i, %j] {mode = vc} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: xegpu.store_nd
+ // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %result, %9 {mode = vc}: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ }
+ }
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir b/mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir
new file mode 100644
index 00000000000000..170b3a9fe81474
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @test_store_nd_vc_bf16({{.*}}) {
+func.func @test_store_nd_vc_bf16(%src: memref<24x32xbf16>, %dst: memref<24x32xbf16>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // Copy one 8x16 tile: load it from %src, then store it to %dst at the same offsets.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4] {mode = vc} : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+ %2 = xegpu.create_nd_tdesc %dst[%c2, %c4] {mode = vc} : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
+ %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
+
+ // CHECK: xegpu.store_nd %{{[0-9]}}, %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind write_back>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16>
+ xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16>
+ return
+}
+
+// CHECK-LABEL: func @test_store_nd_vc_f64({{.*}}) {
+func.func @test_store_nd_vc_f64(%src: memref<24x32xf64>, %dst: memref<24x32xf64>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // Copy one 8x16 f64 tile: load it from %src, then store it to %dst at the same offsets.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64>
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4] {mode = vc} : memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64>
+
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64>
+ %2 = xegpu.create_nd_tdesc %dst[%c2, %c4] {mode = vc}
+ : memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf64> -> vector<8x16xf64>
+ %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf64> -> vector<8x16xf64>
+
+ // CHECK: xegpu.store_nd %{{[0-9]}}, %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind write_back>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: vector<8x16xf64>, !xegpu.tensor_desc<8x16xf64>
+ xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xf64>, !xegpu.tensor_desc<8x16xf64>
+ return
+}
+
+// CHECK-LABEL: func @test_store_nd_vc_i8({{.*}}) {
+func.func @test_store_nd_vc_i8(%src: memref<24x32xi8>, %dst: memref<24x32xi8>) {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ // Copy one 8x16 i8 tile: load it from %src, then store it to %dst at the same offsets.
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4] {mode = vc}
+ : memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
+
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
+ %2 = xegpu.create_nd_tdesc %dst[%c2, %c4] {mode = vc}
+ : memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
+
+ // CHECK: xegpu.load_nd %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xi8> -> vector<8x16xi8>
+ %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xi8> -> vector<8x16xi8>
+
+ // CHECK: xegpu.store_nd %{{[0-9]}}, %{{[0-9]}}
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind write_back>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: vector<8x16xi8>, !xegpu.tensor_desc<8x16xi8>
+ xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xi8>, !xegpu.tensor_desc<8x16xi8>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir b/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir
new file mode 100644
index 00000000000000..6d98ac3950c31f
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @test_store_scatter({{.*}}) {
+func.func @test_store_scatter(%src: ui64, %offsets : vector<16xindex>, %dst: ui64) {
+ // Gather 16 f32 lanes from %src under an all-true mask, then scatter them to %dst at the same offsets.
+ %mask = arith.constant dense<true> : vector<16xi1>
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %src_tdesc = xegpu.create_tdesc %src, %offsets {mode = vc}
+ : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %dst_tdesc = xegpu.create_tdesc %dst, %offsets {mode = vc}
+ : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+
+ // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+ %vals = xegpu.load %src_tdesc, %mask {mode = vc, l1_hint = cached, l2_hint = uncached}
+ : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+ // CHECK: xegpu.store %{{[0-9]}}, %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu<cache_kind write_back>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
+ xegpu.store %vals, %dst_tdesc, %mask {mode = vc, l1_hint = write_back, l2_hint = uncached}
+ : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/store_scatter_vc.mlir b/mlir/test/Dialect/XeGPU/IR/store_scatter_vc.mlir
new file mode 100644
index 00000000000000..c1a51712e70037
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/store_scatter_vc.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @test_store_scatter_vc({{.*}}) {
+func.func @test_store_scatter_vc(%src: ui64, %offsets : vector<16 x index>, %dst: ui64) {
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<16xi1>
+ %0 = arith.constant dense<true>: vector<16xi1>
+ // CHECK: %[[SRC_TDESC:[0-9]+]] = xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+
+ // CHECK: %[[DST_TDESC:[0-9]+]] = xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %2 = xegpu.create_tdesc %dst, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+
+ // Captured names tie the load to the %src descriptor, and the store to the loaded value, the %dst descriptor and the shared mask.
+ // CHECK: %[[DATA:[0-9]+]] = xegpu.load %[[SRC_TDESC]], %[[MASK]] {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+ %3 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached}
+ : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+ // CHECK: xegpu.store %[[DATA]], %[[DST_TDESC]], %[[MASK]] {l1_hint = #xegpu<cache_kind write_back>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
+ xegpu.store %3, %2, %0 {mode = vc, l1_hint = write_back, l2_hint = uncached}
+ : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/update_nd_offset.mlir b/mlir/test/Dialect/XeGPU/IR/update_nd_offset.mlir
new file mode 100644
index 00000000000000..1b97be77a2d79f
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/update_nd_offset.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+// CHECK-LABEL: func @test_update_nd_offset_vc_0({{.*}}) {
+func.func @test_update_nd_offset_vc_0(%src: memref<24x32xf32>) {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+
+ // CHECK: %[[TDESC:[0-9]+]] = xegpu.create_nd_tdesc %{{arg[0-9]}}[%[[C2]], %[[C4]]]
+ // CHECK-SAME: {mode = #xegpu<mode_kind vc>} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%c2, %c4] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: xegpu.load_nd %[[TDESC]]
+ // CHECK-SAME: {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ %2 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+
+ // Update the descriptor with the same (2, 4) offsets; captured names pin both the descriptor and the offset order.
+ // CHECK: xegpu.update_nd_offset %[[TDESC]], [%[[C2]], %[[C4]]] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %3 = xegpu.update_nd_offset %1, [%c2, %c4] {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/IR/update_offset_vc.mlir b/mlir/test/Dialect/XeGPU/IR/update_offset_vc.mlir
new file mode 100644
index 00000000000000..05b0092d2379b7
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/IR/update_offset_vc.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @test_update_offset_VC({{.*}}) {
+func.func @test_update_offset_VC(%src: ui64, %offsets : vector<16 x index>) {
+ %0 = arith.constant dense<true>: vector<16xi1>
+ // CHECK: %[[TDESC:[0-9]+]] = xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+
+ // CHECK: xegpu.load %[[TDESC]], %{{.*}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+ %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached}
+ : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
+
+ // Add 16 to every lane offset; update_offset must consume exactly this sum and the original descriptor.
+ // CHECK: %[[NEW_OFFSETS:[0-9]+]] = arith.addi %{{arg[0-9]}}
+ %3 = arith.constant dense<16>: vector<16 x index>
+ %4 = arith.addi %offsets, %3: vector<16 x index>
+
+ // CHECK: xegpu.update_offset %[[TDESC]], %[[NEW_OFFSETS]] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ %5 = xegpu.update_offset %1, %4 {mode = vc}
+ : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+ return
+}
>From 9cac285ed21833ac88773809816515156d7fcb89 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 18 Jan 2024 10:15:30 -0600
Subject: [PATCH 2/2] update testcases
---
mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir | 43 -------------------
mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir | 12 ++++--
.../Dialect/XeGPU/IR/create_nd_tdesc.mlir | 22 +++++-----
.../Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir | 31 ++++++-------
mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir | 11 -----
mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir | 32 +++++++-------
.../test/Dialect/XeGPU/IR/simple_gemm_vc.mlir | 18 +++++---
mlir/test/Dialect/XeGPU/IR/store_scatter.mlir | 29 -------------
8 files changed, 60 insertions(+), 138 deletions(-)
delete mode 100644 mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir
delete mode 100644 mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir
delete mode 100644 mlir/test/Dialect/XeGPU/IR/store_scatter.mlir
diff --git a/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir b/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir
deleted file mode 100644
index f80df161a543ac..00000000000000
--- a/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir
+++ /dev/null
@@ -1,43 +0,0 @@
-// RUN: mlir-opt %s | FileCheck %s
-// Verify the printed output can be parsed.
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// Verify the generic form can be parsed.
-// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
-
-// CHECK-LABEL: func @test_atomic_rmw({{.*}}) {
-func.func @test_atomic_rmw(%src: ui64, %offsets : vector<16 x index>, %value : vector<16xf32>, %mask : vector<16xi1>) {
- %1 = xegpu.create_tdesc %src, %offsets {mode=vc}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-
- // CHECK: xegpu.atomic_rmw
- // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16xf32>
- xegpu.atomic_rmw #xegpu<atomic_rmw_kind addf> %1, %mask, %value {mode=vc}
- : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
-
- return
-}
-
-// CHECK-LABEL: func @test_atomic_rmw_0({{.*}}) {
-func.func @test_atomic_rmw_0(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xf32>, %mask : vector<16xi1>) {
- %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 2, mode=vc}
- : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>
-
- // CHECK: xegpu.atomic_rmw
- // CHECK-SAME: tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32>
- xegpu.atomic_rmw mulf %1, %mask, %value {mode=vc}
- : !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32>
-
- return
-}
-
-// CHECK-LABEL: func @test_atomic_rmw_1({{.*}}) {
-func.func @test_atomic_rmw_1(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xi32>, %mask : vector<16xi1>) {
- %1 = xegpu.create_tdesc %src, %offsets {chunk_size_per_lane = 2, mode=vc}
- : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>
-
- // CHECK: xegpu.atomic_rmw
- // CHECK-SAME: !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xf32>
- xegpu.atomic_rmw andi %1, %mask, %value {mode=vc}
- : !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xf32>
-
- return
-}
diff --git a/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir b/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir
index 0f7229a02aa180..90df2a7c80ac5a 100644
--- a/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir
@@ -6,9 +6,11 @@
// CHECK-LABEL: func @test_atomic_rmw({{.*}}) {
func.func @test_atomic_rmw(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x1xf32>, %mask : vector<16xi1>) {
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
%1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
- // CHECK: xegpu.atomic_rmw
+ // CHECK: xegpu.atomic_rmw addf %{{[0-9]}}, %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32>
xegpu.atomic_rmw addf %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32> -> vector<16x1xf32>
@@ -17,9 +19,11 @@ func.func @test_atomic_rmw(%src: ui64, %offsets : vector<16 x index>, %value : v
// CHECK-LABEL: func @test_atomic_rmw_0({{.*}}) {
func.func @test_atomic_rmw_0(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xf32>, %mask : vector<16xi1>) {
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 2 : i64, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>
%1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>
- // CHECK: xegpu.atomic_rmw
+ // CHECK: xegpu.atomic_rmw mulf %{{[0-9]}}, %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32>
xegpu.atomic_rmw mulf %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32>
@@ -28,9 +32,11 @@ func.func @test_atomic_rmw_0(%src: ui64, %offsets : vector<16 x index>, %value :
// CHECK-LABEL: func @test_atomic_rmw_1({{.*}}) {
func.func @test_atomic_rmw_1(%src: ui64, %offsets : vector<16 x index>, %value : vector<16x2xi32>, %mask : vector<16xi1>) {
+ // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 2 : i64, mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>
%1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>
- // CHECK: xegpu.atomic_rmw
+ // CHECK: xegpu.atomic_rmw andi %{{[0-9]}}, %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32>
xegpu.atomic_rmw andi %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xf32>
diff --git a/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir b/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir
index cebf59f12939da..8284d730d4089c 100644
--- a/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc.mlir
@@ -10,12 +10,12 @@ func.func @test_create_nd_tdesc_0(%src: memref<24x32xf16>) {
%c0 = arith.constant 2 : index
%c1 = arith.constant 4 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}]
// CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
%1 = xegpu.create_nd_tdesc %src[%c0, %c1]
: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[2, 4]
// CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
%2 = xegpu.create_nd_tdesc %src[2, 4]
: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
@@ -25,7 +25,7 @@ func.func @test_create_nd_tdesc_0(%src: memref<24x32xf16>) {
// CHECK-LABEL: func @test_create_nd_tdesc_1({{.*}}) {
func.func @test_create_nd_tdesc_1(%src: memref<24x32xf16>, %x : index, %y : index) {
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}]
// CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
%1 = xegpu.create_nd_tdesc %src[%x, %y]
: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
@@ -35,7 +35,7 @@ func.func @test_create_nd_tdesc_1(%src: memref<24x32xf16>, %x : index, %y : inde
// CHECK-LABEL: func @test_create_nd_tdesc_2({{.*}}) {
func.func @test_create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}]
// CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
return
@@ -44,7 +44,7 @@ func.func @test_create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index
// CHECK-LABEL: func @test_create_nd_tdesc_3({{.*}}) {
func.func @test_create_nd_tdesc_3(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}]
// CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
return
@@ -54,7 +54,7 @@ func.func @test_create_nd_tdesc_3(%src: memref<?x?xf16>, %w : index, %h : index,
// CHECK-LABEL: func @test_create_nd_tdesc_4({{.*}}) {
func.func @test_create_nd_tdesc_4(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}]
// CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1]
: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16>
@@ -64,7 +64,7 @@ func.func @test_create_nd_tdesc_4(%src: memref<?x?xf16>, %w : index, %h : index,
// CHECK-LABEL: func @test_create_nd_tdesc_5({{.*}}) {
func.func @test_create_nd_tdesc_5(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}]
// CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = <wi_layout = [2, 8], wi_data = [1, 2]>>>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1]
: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = #sg_map_fp16>>
@@ -74,7 +74,7 @@ func.func @test_create_nd_tdesc_5(%src: memref<?x?xf16>, %w : index, %h : index,
// CHECK-LABEL: func @test_create_nd_tdesc_6({{.*}}) {
func.func @test_create_nd_tdesc_6(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}]
// CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = <wi_layout = [2, 8], wi_data = [1, 2]>>>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1]
: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = #sg_map_fp16>>
@@ -83,7 +83,7 @@ func.func @test_create_nd_tdesc_6(%src: memref<?x?xf16>, %w : index, %h : index,
// CHECK-LABEL: func @test_create_nd_tdesc_7({{.*}}) {
func.func @test_create_nd_tdesc_7(%src: memref<1024xf16>, %offset : index) {
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}]
// CHECK-SAME: memref<1024xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
%1 = xegpu.create_nd_tdesc %src[%offset] : memref<1024xf16> -> !xegpu.tensor_desc<16xf16, #sg_map_fp16>
return
@@ -93,7 +93,7 @@ func.func @test_create_nd_tdesc_7(%src: memref<1024xf16>, %offset : index) {
// CHECK-LABEL: func @test_create_nd_tdesc_8({{.*}}) {
func.func @test_create_nd_tdesc_8(%src: memref<?x?xf16>, %w : index, %h : index, %x : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[8, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %c1]
// CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = <wi_layout = [2, 8], wi_data = [1, 2]>>>
%1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1]
: memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<memory_scope = slm, map = #sg_map_fp16>>
@@ -103,7 +103,7 @@ func.func @test_create_nd_tdesc_8(%src: memref<?x?xf16>, %w : index, %h : index,
// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}}) {
func.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[8, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %c1]
// CHECK-SAME: memref<?x?xf16> -> !xegpu.tensor_desc<64x128xf16, #xegpu.tdesc_attr<memory_scope = slm, map = <wi_layout = [2, 8], wi_data = [1, 2]>>>
%1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] : memref<?x?xf16>
-> !xegpu.tensor_desc<64x128xf16, #xegpu.tdesc_attr<memory_scope = slm, map = #sg_map_fp16>>
diff --git a/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir b/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir
index a21bf792fe0792..34cd66c9c69a4e 100644
--- a/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/create_nd_tdesc_vc.mlir
@@ -10,12 +10,12 @@ func.func @test_create_nd_tdesc_vc_0(%src: memref<24x32xf32>) {
%c0 = arith.constant 2 : index
%c1 = arith.constant 4 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc}
: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[2, 4] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%2 = xegpu.create_nd_tdesc %src[2, 4] {mode = vc}
: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -25,19 +25,16 @@ func.func @test_create_nd_tdesc_vc_0(%src: memref<24x32xf32>) {
// CHECK-LABEL: func @test_create_nd_tdesc_vc_1({{.*}}) {
func.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>, %x : index, %y : index) {
- // CHECK: xegpu.create_nd_tdesc
- // CHECK-SAME: %arg0[%arg1, %arg2]
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
- %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc}
- : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
return
}
// CHECK-LABEL: func @test_create_nd_tdesc_vc_2({{.*}}) {
func.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
- // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : ui64 -> !xegpu.tensor_desc<8x16xf32>
return
@@ -46,8 +43,7 @@ func.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : in
// CHECK-LABEL: func @test_create_nd_tdesc_vc_3({{.*}}) {
func.func @test_create_nd_tdesc_vc_3(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
- // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
return
@@ -57,8 +53,7 @@ func.func @test_create_nd_tdesc_vc_3(%src: memref<?x?xf32>, %w : index, %h : ind
// CHECK-LABEL: func @test_create_nd_tdesc_vc_4({{.*}}) {
func.func @test_create_nd_tdesc_vc_4(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
- // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
return
@@ -67,8 +62,7 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<?x?xf32>, %w : index, %h : ind
// CHECK-LABEL: func @test_create_nd_tdesc_vc_5({{.*}}) {
func.func @test_create_nd_tdesc_vc_5(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
- // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc}
: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
@@ -78,8 +72,7 @@ func.func @test_create_nd_tdesc_vc_5(%src: memref<?x?xf32>, %w : index, %h : ind
// CHECK-LABEL: func @test_create_nd_tdesc_vc_6({{.*}}) {
func.func @test_create_nd_tdesc_vc_6(%src: memref<?x?xf32>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
- // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1]
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc}
: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
@@ -89,7 +82,7 @@ func.func @test_create_nd_tdesc_vc_6(%src: memref<?x?xf32>, %w : index, %h : ind
// CHECK-LABEL: func @test_create_nd_tdesc_vc_7({{.*}}) {
func.func @test_create_nd_tdesc_vc_7(%src: memref<1024xf32>, %offset : index) {
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
%1 = xegpu.create_nd_tdesc %src[%offset] {mode = vc} : memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
return
@@ -99,7 +92,7 @@ func.func @test_create_nd_tdesc_vc_7(%src: memref<1024xf32>, %offset : index) {
// CHECK-LABEL: func @test_create_nd_tdesc_vc_8({{.*}}) {
func.func @test_create_nd_tdesc_vc_8(%src: memref<?x?xf32>, %w : index, %h : index, %x : index) {
%c1 = arith.constant 1 : index
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[8, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %c1] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
%1 = xegpu.create_nd_tdesc %src[8, %x], [%h, %w], [%w, %c1] {mode = vc}
: memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
@@ -108,7 +101,7 @@ func.func @test_create_nd_tdesc_vc_8(%src: memref<?x?xf32>, %w : index, %h : ind
// CHECK-LABEL: func @test_create_nd_tdesc_vc_9({{.*}}) {
func.func @test_create_nd_tdesc_vc_9(%src: memref<8x32xf32>) {
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[0, 0]
// CHECK-SAME: memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm, array_length = 2>>
%1 = xegpu.create_nd_tdesc %src[0, 0] {mode = vc} : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm, array_length = 2>>
return
diff --git a/mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir b/mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir
deleted file mode 100644
index 8fb5ac824ddb27..00000000000000
--- a/mlir/test/Dialect/XeGPU/IR/create_tdesc.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: mlir-opt %s | FileCheck %s
-// Verify the printed output can be parsed.
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// Verify the generic form can be parsed.
-// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
-
-// CHECK-LABEL: func @test_create_tdesc_vc({{.*}}) {
-func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) {
- %1 = xegpu.create_tdesc %src, %offsets {mode=vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
- return
-}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir b/mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir
index ff6f31c77064af..8df22fb78996a5 100644
--- a/mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/simple_gemm.mlir
@@ -23,12 +23,12 @@ func.func @test_gemm_bf16(%a : memref<1024x1024xbf16>, %b: memref<1024x1024xbf16
scf.for %i= %c0 to %c1024 step %c8 {
scf.for %j= %c0 to %c1024 step %c16 {
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{c[0-9]}}]
// CHECK-SAME: memref<1024x1024xbf16>
// CHECK-SAME: -> !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
%1 = xegpu.create_nd_tdesc %a[%i, %c0] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a>
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{arg[0-9]}}]
// CHECK-SAME: memref<1024x1024xbf16>
// CHECK-SAME: -> !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%2 = xegpu.create_nd_tdesc %b[%c0, %j] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b>
@@ -37,33 +37,35 @@ func.func @test_gemm_bf16(%a : memref<1024x1024xbf16>, %b: memref<1024x1024xbf16
%tmp0, %tmp1, %result = scf.for %k= %c0 to %c1024 step %c16 iter_args(%subA = %1, %subB = %2, %subC = %3)
-> (!xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a>, !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b>, vector<8x1xf32>) {
- // CHECK: xegpu.load_nd
- // CHECK-SAME: vector<4x1x2xbf16>
+ // CHECK: xegpu.load_nd %{{arg[0-9]}} {vnni_axis = 1 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>> -> vector<4x1x2xbf16>
%4 = xegpu.load_nd %subA {vnni_axis = 1} : !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a> -> vector<4x1x2xbf16>
- // CHECK: xegpu.load_nd
- // CHECK-SAME: vector<8x1x2xbf16>
+ // CHECK: xegpu.load_nd %{{arg[0-9]}} {vnni_axis = 0 : i64}
+ // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1x2xbf16>
%5 = xegpu.load_nd %subB {vnni_axis = 0} : !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b> -> vector<8x1x2xbf16>
- // CHECK: xegpu.dpas
+ // CHECK: xegpu.dpas %{{[0-9]}}, %{{[0-9]}}, %{{arg[0-9]}}
// CHECK-SAME: vector<4x1x2xbf16>, vector<8x1x2xbf16>, vector<8x1xf32> -> vector<8x1xf32>
%6 = xegpu.dpas %4, %5, %subC : vector<4x1x2xbf16>, vector<8x1x2xbf16>, vector<8x1xf32> -> vector<8x1xf32>
- %7 = xegpu.update_nd_offset %subA, [%c0, %c16] : !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a>
- -> !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a>
+ // CHECK: xegpu.update_nd_offset %{{arg[0-9]}}, [%{{c[0-9]}}, %{{c[0-9]+}}]
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>
+ %7 = xegpu.update_nd_offset %subA, [%c0, %c16] : !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a> -> !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a>
- %8 = xegpu.update_nd_offset %subB, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b>
- -> !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b>
+ // CHECK: xegpu.update_nd_offset %{{arg[0-9]}}, [%{{c[0-9]+}}, %{{c[0-9]}}]
+ // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %8 = xegpu.update_nd_offset %subB, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b> -> !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b>
scf.yield %7, %8, %6: !xegpu.tensor_desc<8x16xbf16, #sg_map_fp16_a>, !xegpu.tensor_desc<16x16xbf16, #sg_map_fp16_b>, vector<8x1xf32>
}
- // CHECK: xegpu.create_nd_tdesc
- // CHECK-SAME: memref<1024x1024xf32>
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[{{%arg[0-9]}}, %{{arg[0-9]}}]
+ // CHECK-SAME: memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
%9 = xegpu.create_nd_tdesc %c[%i, %j] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c>
- // CHECK: xegpu.store_nd
- // CHECK-SAME: vector<8x1xf32>
+ // CHECK: xegpu.store_nd %{{[0-9]#2}}, %{{[0-9]}}
+ // CHECK-SAME: vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
xegpu.store_nd %result, %9 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c>
}
}
diff --git a/mlir/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir b/mlir/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir
index 794a6b6f1afb9c..62b972ad189fde 100644
--- a/mlir/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/simple_gemm_vc.mlir
@@ -20,11 +20,11 @@ func.func @test_gemm_vc_bf16(%a : memref<1024x1024xbf16>, %b: memref<1024x1024xb
scf.for %i= %c0 to %c1024 step %c8 {
scf.for %j= %c0 to %c1024 step %c16 {
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
%1 = xegpu.create_nd_tdesc %a[%i, %c0] {mode = vc} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{arg[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
%2 = xegpu.create_nd_tdesc %b[%c0, %j] {mode = vc} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
@@ -33,30 +33,34 @@ func.func @test_gemm_vc_bf16(%a : memref<1024x1024xbf16>, %b: memref<1024x1024xb
%tmp0, %tmp1, %result = scf.for %k= %c0 to %c1024 step %c16
iter_args(%subA = %1, %subB = %2, %subC = %3)
-> (!xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>, vector<8x16xf32>) {
- // CHECK: xegpu.load_nd
+ // CHECK: xegpu.load_nd %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>, vnni_axis = 1 : i64}
// CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16>
%4 = xegpu.load_nd %subA {mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16>
- // CHECK: xegpu.load_nd
+ // CHECK: xegpu.load_nd %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>, vnni_axis = 0 : i64}
// CHECK-SAME: !xegpu.tensor_desc<16x16xbf16> -> vector<8x16x2xbf16>
%5 = xegpu.load_nd %subB {mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xbf16> -> vector<8x16x2xbf16>
- // CHECK: xegpu.dpas
+ // CHECK: xegpu.dpas %{{[0-9]}}, %{{[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32>
%6 = xegpu.dpas %4, %5, %subC {mode = vc} : vector<8x8x2xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+ // CHECK: xegpu.update_nd_offset %{{arg[0-9]}}, [%{{c[0-9]}}, %{{c[0-9]+}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> !xegpu.tensor_desc<8x16xbf16>
%7 = xegpu.update_nd_offset %subA, [%c0, %c16] {mode = vc} : !xegpu.tensor_desc<8x16xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+ // CHECK: xegpu.update_nd_offset %{{arg[0-9]}}, [%{{c[0-9]+}}, %{{c[0-9]}}] {mode = #xegpu<mode_kind vc>}
+ // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16> -> !xegpu.tensor_desc<16x16xbf16>
%8 = xegpu.update_nd_offset %subB, [%c16, %c0] {mode = vc} : !xegpu.tensor_desc<16x16xbf16> -> !xegpu.tensor_desc<16x16xbf16>
scf.yield %7, %8, %6: !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>, vector<8x16xf32>
}
- // CHECK: xegpu.create_nd_tdesc
+ // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[{{%arg[0-9]}}, %{{arg[0-9]}}] {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
%9 = xegpu.create_nd_tdesc %c[%i, %j] {mode = vc} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
- // CHECK: xegpu.store_nd
+ // CHECK: xegpu.store_nd %{{[0-9]#2}}, %{{[0-9]}} {mode = #xegpu<mode_kind vc>}
// CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %result, %9 {mode = vc}: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
}
diff --git a/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir b/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir
deleted file mode 100644
index 6d98ac3950c31f..00000000000000
--- a/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir
+++ /dev/null
@@ -1,29 +0,0 @@
-// RUN: mlir-opt %s | FileCheck %s
-// Verify the printed output can be parsed.
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// Verify the generic form can be parsed.
-// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
-
-// CHECK-LABEL: func @test_store_scatter({{.*}}) {
-func.func @test_store_scatter(%src: ui64, %offsets : vector<16xindex>, %dst: ui64) {
- %0 = arith.constant dense<true>: vector<16xi1>
- // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
- // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
- %1 = xegpu.create_tdesc %src, %offsets {mode = vc}
- : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-
- // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu<mode_kind vc>}
- // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
- %2 = xegpu.create_tdesc %dst, %offsets {mode = vc}
- : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-
- // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu<cache_kind cached>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
- // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
- %3 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached}
- : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
- // CHECK: xegpu.store %{{[0-9]}}, %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu<cache_kind write_back>, l2_hint = #xegpu<cache_kind uncached>, mode = #xegpu<mode_kind vc>}
- // CHECK-SAME: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
- xegpu.store %3, %2, %0 {mode = vc, l1_hint = write_back, l2_hint = uncached}
- : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
- return
-}
More information about the Mlir-commits
mailing list