[Mlir-commits] [mlir] [MLIR][Dialect] Add XeVM dialect (PR #144811)

Sang Ik Lee llvmlistbot at llvm.org
Wed Jun 18 15:51:51 PDT 2025


https://github.com/silee2 created https://github.com/llvm/llvm-project/pull/144811

None

From 75bc544d5ff42cefe7dceddf021450881689ce9b Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 18 Jun 2025 22:48:04 +0000
Subject: [PATCH] Add XeVM dialect.

---
 .../mlir/Dialect/LLVMIR/CMakeLists.txt        |  10 +
 .../include/mlir/Dialect/LLVMIR/XeVMDialect.h |  28 +
 mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td   | 550 ++++++++++++++++++
 mlir/include/mlir/InitAllDialects.h           |   4 +-
 mlir/lib/Dialect/LLVMIR/CMakeLists.txt        |  22 +
 mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp    | 381 ++++++++++++
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  45 +-
 mlir/test/Dialect/LLVMIR/xevm.mlir            |  76 +++
 mlir/test/lib/Dialect/GPU/CMakeLists.txt      |   1 +
 9 files changed, 1114 insertions(+), 3 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
 create mode 100644 mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
 create mode 100644 mlir/test/Dialect/LLVMIR/xevm.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 9c5bbae1022f7..cfad07e57021f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -87,3 +87,13 @@ mlir_tablegen(VCIXConversions.inc -gen-llvmir-conversions)
 mlir_tablegen(VCIXOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=vcix)
 mlir_tablegen(VCIXOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=vcix)
 add_public_tablegen_target(MLIRVCIXConversionsIncGen)
+
+add_mlir_dialect(XeVMOps xevm)
+add_mlir_doc(XeVMOps XeVMDialect Dialects/ -gen-dialect-doc -dialect=xevm)
+set(LLVM_TARGET_DEFINITIONS XeVMOps.td)
+mlir_tablegen(XeVMConversions.inc -gen-llvmir-conversions)
+mlir_tablegen(XeVMOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(XeVMOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(XeVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=xevm)
+mlir_tablegen(XeVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xevm)
+add_public_tablegen_target(MLIRXeVMConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h
new file mode 100644
index 0000000000000..a83d4248c862c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h
@@ -0,0 +1,28 @@
+//===-- XeVMDialect.h - MLIR XeVM target definitions ------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include <mlir/Dialect/LLVMIR/XeVMOpsEnums.h.inc>
+
+#define GET_ATTRDEF_CLASSES
+#include <mlir/Dialect/LLVMIR/XeVMOpsAttributes.h.inc>
+
+#define GET_OP_CLASSES
+#include <mlir/Dialect/LLVMIR/XeVMOps.h.inc>
+
+#include <mlir/Dialect/LLVMIR/XeVMOpsDialect.h.inc>
+
+#endif /* MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_ */
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
new file mode 100644
index 0000000000000..9525c4a731efa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -0,0 +1,550 @@
+//===-- XeVMOps.td - XeVM dialect definition ---------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef XEVMIR_OPS
+#define XEVMIR_OPS
+
+include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+
+def XeVM_Dialect : Dialect {
+  let name = "xevm";
+  let cppNamespace = "::mlir::xevm";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+
+  let extraClassDeclaration = [{
+    /// Get the name for the attribute used to specify cache control
+    /// decorations.
+    static constexpr ::llvm::StringRef getCacheControlsAttrName() {
+      return ::llvm::StringLiteral("xevm.DecorationCacheControl");
+    }
+  }];
+
+  let useDefaultAttributePrinterParser = 1;
+}
+
+class XeVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<XeVM_Dialect, attrName, traits> {
+  let mnemonic = attrMnemonic;
+}
+
+class XeVM_Op<string mnemonic, list<Trait> traits = []>
+    : Op<XeVM_Dialect, mnemonic, traits> {
+
+  code extraBaseClassDeclaration = [{
+    void printProperties(::mlir::MLIRContext *ctx,
+            ::mlir::OpAsmPrinter &p, const Properties &prop,
+            ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
+      Attribute propAttr = getPropertiesAsAttr(ctx, prop);
+      if (propAttr)
+        p << "<" << propAttr << ">";
+    }
+
+    static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
+                                     ::mlir::OperationState &result) {
+      if (mlir::succeeded(parser.parseOptionalLess())) {
+        if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
+          return failure();
+      }
+      return success();
+    }
+
+  }];
+}
+
+def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
+
+def LoadCacheControlDefault : I32EnumAttrCase<"DEFAULT", 0, "Default">;
+def LoadCacheControl_L1uc_L2uc_L3uc
+    : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
+def LoadCacheControl_L1uc_L2uc_L3c
+    : I32EnumAttrCase<"L1UC_L2UC_L3C", 2, "L1uc_L2uc_L3c">;
+def LoadCacheControl_L1uc_L2c_L3uc
+    : I32EnumAttrCase<"L1UC_L2C_L3UC", 3, "L1uc_L2c_L3uc">;
+def LoadCacheControl_L1uc_L2c_L3c
+    : I32EnumAttrCase<"L1UC_L2C_L3C", 4, "L1uc_L2c_L3c">;
+def LoadCacheControl_L1c_L2uc_L3uc
+    : I32EnumAttrCase<"L1C_L2UC_L3UC", 5, "L1c_L2uc_L3uc">;
+def LoadCacheControl_L1c_L2uc_L3c
+    : I32EnumAttrCase<"L1C_L2UC_L3C", 6, "L1c_L2uc_L3c">;
+def LoadCacheControl_L1c_L2c_L3uc
+    : I32EnumAttrCase<"L1C_L2C_L3UC", 7, "L1c_L2c_L3uc">;
+def LoadCacheControl_L1c_L2c_L3c
+    : I32EnumAttrCase<"L1C_L2C_L3C", 8, "L1c_L2c_L3c">;
+def LoadCacheControl_L1s_L2uc_L3uc
+    : I32EnumAttrCase<"L1S_L2UC_L3UC", 9, "L1s_L2uc_L3uc">;
+def LoadCacheControl_L1s_L2uc_L3c
+    : I32EnumAttrCase<"L1S_L2UC_L3C", 10, "L1s_L2uc_L3c">;
+def LoadCacheControl_L1s_L2c_L3uc
+    : I32EnumAttrCase<"L1S_L2C_L3UC", 11, "L1s_L2c_L3uc">;
+def LoadCacheControl_L1s_L2c_L3c
+    : I32EnumAttrCase<"L1S_L2C_L3C", 12, "L1s_L2c_L3c">;
+def LoadCacheControlInvalidateRead
+    : I32EnumAttrCase<"INVALIDATE_READ", 13, "ir">;
+
+def XeVM_LoadCacheControl
+    : I32EnumAttr<
+          "LoadCacheControl", "XeVM load ops cache control",
+          [LoadCacheControlDefault, LoadCacheControl_L1uc_L2uc_L3uc,
+           LoadCacheControl_L1uc_L2uc_L3c, LoadCacheControl_L1uc_L2c_L3uc,
+           LoadCacheControl_L1uc_L2c_L3c, LoadCacheControl_L1c_L2uc_L3uc,
+           LoadCacheControl_L1c_L2uc_L3c, LoadCacheControl_L1c_L2c_L3uc,
+           LoadCacheControl_L1c_L2c_L3c, LoadCacheControl_L1s_L2uc_L3uc,
+           LoadCacheControl_L1s_L2uc_L3c, LoadCacheControl_L1s_L2c_L3uc,
+           LoadCacheControl_L1s_L2c_L3c, LoadCacheControlInvalidateRead]> {
+  let cppNamespace = "::mlir::xevm";
+  let genSpecializedAttr = 0;
+}
+
+def XeVM_LoadCacheControlAttr
+    : EnumAttr<XeVM_Dialect, XeVM_LoadCacheControl, "load_cache_control"> {
+  let summary = [{Describes the cache settings for load operations}];
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def StoreCacheControlDefault : I32EnumAttrCase<"DEFAULT", 0, "Default">;
+def StoreCacheControl_L1uc_L2uc_L3uc
+    : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
+def StoreCacheControl_L1uc_L2uc_L3wb
+    : I32EnumAttrCase<"L1UC_L2UC_L3WB", 2, "L1uc_L2uc_L3wb">;
+def StoreCacheControl_L1uc_L2wb_L3uc
+    : I32EnumAttrCase<"L1UC_L2WB_L3UC", 3, "L1uc_L2wb_L3uc">;
+def StoreCacheControl_L1uc_L2wb_L3wb
+    : I32EnumAttrCase<"L1UC_L2WB_L3WB", 4, "L1uc_L2wb_L3wb">;
+def StoreCacheControl_L1wt_L2uc_L3uc
+    : I32EnumAttrCase<"L1WT_L2UC_L3UC", 5, "L1wt_L2uc_L3uc">;
+def StoreCacheControl_L1wt_L2uc_L3wb
+    : I32EnumAttrCase<"L1WT_L2UC_L3WB", 6, "L1wt_L2uc_L3wb">;
+def StoreCacheControl_L1wt_L2wb_L3uc
+    : I32EnumAttrCase<"L1WT_L2WB_L3UC", 7, "L1wt_L2wb_L3uc">;
+def StoreCacheControl_L1wt_L2wb_L3wb
+    : I32EnumAttrCase<"L1WT_L2WB_L3WB", 8, "L1wt_L2wb_L3wb">;
+def StoreCacheControl_L1s_L2uc_L3uc
+    : I32EnumAttrCase<"L1S_L2UC_L3UC", 9, "L1s_L2uc_L3uc">;
+def StoreCacheControl_L1s_L2uc_L3wb
+    : I32EnumAttrCase<"L1S_L2UC_L3WB", 10, "L1s_L2uc_L3wb">;
+def StoreCacheControl_L1s_L2wb_L3uc
+    : I32EnumAttrCase<"L1S_L2WB_L3UC", 11, "L1s_L2wb_L3uc">;
+def StoreCacheControl_L1s_L2wb_L3wb
+    : I32EnumAttrCase<"L1S_L2WB_L3WB", 12, "L1s_L2wb_L3wb">;
+def StoreCacheControl_L1wb_L2uc_L3uc
+    : I32EnumAttrCase<"L1WB_L2UC_L3UC", 13, "L1wb_L2uc_L3uc">;
+def StoreCacheControl_L1wb_L2wb_L3uc
+    : I32EnumAttrCase<"L1WB_L2WB_L3UC", 14, "L1wb_L2wb_L3uc">;
+def StoreCacheControl_L1wb_L2uc_L3wb
+    : I32EnumAttrCase<"L1WB_L2UC_L3WB", 15, "L1wb_L2uc_L3wb">;
+
+def XeVM_StoreCacheControl
+    : I32EnumAttr<
+          "StoreCacheControl", "XeVM store ops cache control",
+          [StoreCacheControlDefault, StoreCacheControl_L1uc_L2uc_L3uc,
+           StoreCacheControl_L1uc_L2uc_L3wb, StoreCacheControl_L1uc_L2wb_L3uc,
+           StoreCacheControl_L1uc_L2wb_L3wb, StoreCacheControl_L1wt_L2uc_L3uc,
+           StoreCacheControl_L1wt_L2uc_L3wb, StoreCacheControl_L1wt_L2wb_L3uc,
+           StoreCacheControl_L1wt_L2wb_L3wb, StoreCacheControl_L1s_L2uc_L3uc,
+           StoreCacheControl_L1s_L2uc_L3wb, StoreCacheControl_L1s_L2wb_L3uc,
+           StoreCacheControl_L1s_L2wb_L3wb, StoreCacheControl_L1wb_L2uc_L3uc,
+           StoreCacheControl_L1wb_L2wb_L3uc,
+           StoreCacheControl_L1wb_L2uc_L3wb]> {
+  let cppNamespace = "::mlir::xevm";
+  let genSpecializedAttr = 0;
+}
+
+def XeVM_StoreCacheControlAttr
+    : EnumAttr<XeVM_Dialect, XeVM_StoreCacheControl, "store_cache_control"> {
+  let summary = [{Describes the cache settings for store operations}];
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def XeVM_BlockLoad2dOp
+    : XeVM_Op<"blockload2d">,
+      Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>,
+      Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
+          I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
+          I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
+          I32Attr:$v_blocks, I1Attr:$transpose, I1Attr:$pack_register,
+          OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
+
+  let summary = "2D block load";
+
+  let description = [{
+    The `xevm.blockload2d` operation loads a two dimensional matrix tile
+    from a base matrix residing in memory. The parameters are:
+      $ptr - the base address of the base matrix containing the tile to load
+      $base_width, $base_height, $base_pitch - the shape of the base matrix.
+      The pitch is the physical stride between the first column of the current
+      row and the first column of the subsequent row. All units are in bytes.
+      $x, $y, $tile_width, $tile_height - the starting offsets and shape of
+      the tile to load in number of elements.
+      $elem_size_in_bits - the size in bits of the matrix element type
+        - 32 for f32, tf32
+        - 16 for f16, int16, bf16
+        - 8 for int8
+      $v_blocks - the number of consecutive tiles along the innermost dimension to load
+      $transpose - transpose the tile in registers (useful for 32 bit element types)
+      $pack_register - pack element types narrower than the register bit width:
+        [M, N] => [M/factor, N, factor] where factor is
+        register_size_in_bits / elem_size_in_bits. For example, with 32 bit
+        registers and 16 bit elements the factor is 2, so an [8, 16] tile is
+        packed as [4, 16, 2].
+      $cache_control - an enumerator that sets the cache behaviour
+
+    Notes:
+      - the $transpose and $pack_register parameters are mutually exclusive
+      - transposing the loaded tile is typically used for the A matrix in the
+        backward path or for the B matrix operand (D = C + A * B), where A has
+        a row-major layout and B has a column-major layout in memory.
+      - if the loaded tile contains out-of-bounds elements of the matrix, they
+        are filled with 0.
+
+    Example:
+    ```mlir
+      %base_width_a = arith.constant 32 : i32
+      %base_height_a = arith.constant 8 : i32
+      %base_pitch_a = arith.constant 32 : i32
+      %x = arith.constant 0 : i32
+      %y = arith.constant 0 : i32
+      %loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    operands prop-dict attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    /// Get cache control or return default if not set.
+    ::mlir::xevm::LoadCacheControl getCacheControlOrDefault() {
+      if(getCacheControl())
+        return *getCacheControl();
+      return ::mlir::xevm::LoadCacheControl::DEFAULT;
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeVM_BlockStore2dOp
+    : XeVM_Op<"blockstore2d">,
+      Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr, I32:$base_width,
+          I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
+          I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
+          FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$stored_val,
+          OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
+
+  let summary = "2D block store";
+
+  let description = [{
+    The `xevm.blockstore2d` operation stores a two dimensional tile into a
+    larger matrix residing in memory. The parameters are:
+      $ptr - the base address of the target matrix where to store the tile
+      $base_width, $base_height, $base_pitch - the shape of the target matrix.
+      The pitch is the physical stride between the first column of the current
+      row and the first column of the subsequent row. All units are in bytes.
+      $x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to store
+      in number of elements.
+      $elem_size_in_bits - the size in bits of the matrix element
+        - 32 for f32, tf32
+        - 16 for f16, int16, bf16
+        - 8 for int8
+      $cache_control - an enumerator that sets the cache behaviour
+      $stored_val - the tile to store
+
+    Example:
+    ```mlir
+      %base_width_c = arith.constant 64 : i32
+      %base_height_c = arith.constant 8 : i32
+      %base_pitch_c = arith.constant 64 : i32
+      %x = arith.constant 0 : i32
+      %y = arith.constant 0 : i32
+      xevm.blockstore2d %dst, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %src <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+    ```
+  }];
+
+  let assemblyFormat = [{
+    operands prop-dict attr-dict `:` `(` type(operands) `)`
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    /// Get cache control or return default if not set.
+    ::mlir::xevm::StoreCacheControl getCacheControlOrDefault() {
+      if(getCacheControl())
+        return *getCacheControl();
+      return ::mlir::xevm::StoreCacheControl::DEFAULT;
+    }
+
+    /// Default value for v_blocks is 1.
+    constexpr uint32_t getVBlocks() {
+      return 1;
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def MemScopeLane : I32EnumAttrCase<"LANE", 0, "lane">;
+def MemScopeSg : I32EnumAttrCase<"SUBGROUP", 1, "subgroup">;
+def MemScopeWg : I32EnumAttrCase<"WORKGROUP", 2, "workgroup">;
+def MemScopeCluster : I32EnumAttrCase<"CLUSTER", 3, "cluster">;
+def MemScopeDevice : I32EnumAttrCase<"DEVICE", 4, "device">;
+def MemScopeSystem : I32EnumAttrCase<"SYSTEM", 5, "system">;
+
+def XeVM_MemScope
+    : I32EnumAttr<"MemScope", "XeVM memory scope",
+                  [MemScopeLane, MemScopeSg, MemScopeWg, MemScopeCluster,
+                   MemScopeDevice, MemScopeSystem]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xevm";
+}
+def XeVM_MemScopeAttr : EnumAttr<XeVM_Dialect, XeVM_MemScope, "mem_scope"> {
+  let summary = [{Describe memory scopes}];
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def AddrSpacePrivate : I32EnumAttrCase<"PRIVATE", 0, "private">;
+def AddrSpaceGlobal : I32EnumAttrCase<"GLOBAL", 1, "global">;
+def AddrSpaceConstant : I32EnumAttrCase<"CONSTANT", 2, "constant">;
+def AddrSpaceShared : I32EnumAttrCase<"SHARED", 3, "shared">;
+def AddrSpaceGeneric : I32EnumAttrCase<"GENERIC", 4, "generic">;
+
+def XeVM_AddrSpace
+    : I32EnumAttr<"AddrSpace", "Address spaces",
+                  [AddrSpacePrivate, AddrSpaceGlobal, AddrSpaceConstant,
+                   AddrSpaceShared, AddrSpaceGeneric]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "mlir::xevm";
+}
+def XeVM_AddrSpaceAttr : EnumAttr<XeVM_Dialect, XeVM_AddrSpace, "addr_space"> {
+  let summary = [{Describe address spaces}];
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def XeVM_MemfenceOp
+    : XeVM_Op<"memfence">,
+      Arguments<(ins XeVM_MemScopeAttr:$scope,
+          DefaultValuedAttr<XeVM_AddrSpaceAttr,
+                            "mlir::xevm::AddrSpace::GENERIC">:$addrspace)> {
+  let summary = "Work-item's memory fence.";
+  let description = [{
+    This operation ensures that all prior memory accesses of this
+    work-item to `addrspace` are visible to all other work-items in `scope`.
+    Parameters:
+    $scope - specifies the memory scope at which all other work-items should
+                observe memory operations prior to the fence.
+    $addrspace - specifies the address space of the work-item's memory accesses
+                to be affected by the fence.
+  }];
+  let assemblyFormat = [{prop-dict  attr-dict}];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+  }];
+}
+
+def XeVM_PrefetchOp
+    : XeVM_Op<"prefetch">,
+      Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
+          XeVM_AddrSpaceAttr:$addrspace,
+          OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
+  let summary = "Prefetch data into a cache subsystem.";
+  let description = [{
+    A work-item issues a prefetch from global memory into the cache:
+    $ptr - the memory pointer.
+    $addrspace - the address space of the pointer; must be global or generic.
+    $cache_control - specifies the caching options.
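+
+    Example (mirroring the usage in the dialect tests):
+    ```mlir
+      xevm.prefetch %ptr <{addrspace = #xevm.addr_space<global>, cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+    ```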
+  }];
+  let assemblyFormat = [{
+    operands prop-dict attr-dict `:` `(` type(operands) `)`
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    /// Get cache control or return default if not set.
+    ::mlir::xevm::LoadCacheControl getCacheControlOrDefault() {
+      if(getCacheControl())
+        return *getCacheControl();
+      return ::mlir::xevm::LoadCacheControl::DEFAULT;
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeVM_BlockPrefetch2dOp
+    : XeVM_Op<"blockprefetch2d">,
+      Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
+          I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
+          I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
+          I32Attr:$v_blocks,
+          OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
+
+  let summary = "2D block prefetch";
+
+  let description = [{
+    The `xevm.blockprefetch2d` operation prefetches a two dimensional tile
+    from a larger base matrix residing in memory. The parameters are:
+      $ptr - the base address of the base matrix containing the tile to prefetch
+      $base_width, $base_height, $base_pitch - the shape of the base matrix.
+      The pitch is the physical stride between the first column of the current
+      row and the first column of the subsequent row. All units are in bytes.
+      $x, $y, $tile_width, $tile_height - the starting offsets and shape of tile
+      to prefetch in number of elements.
+      $elem_size_in_bits - the size in bits of the matrix element
+      - 32 for f32, tf32
+      - 16 for f16, int16, bf16
+      - 8 for int8, int4, int2
+      $v_blocks - the number of consecutive tiles along the innermost dimension to prefetch
+      $cache_control - an enumerator that sets the cache behaviour
+
+    Example:
+    ```mlir
+      xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+    ```
+  }];
+
+  let assemblyFormat = [{
+    operands prop-dict attr-dict `:` `(` type(operands) `)`
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
+    /// Get cache control or return default if not set.
+    ::mlir::xevm::LoadCacheControl getCacheControlOrDefault() {
+      if(getCacheControl())
+        return *getCacheControl();
+      return ::mlir::xevm::LoadCacheControl::DEFAULT;
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeVM_MatrixElemType
+    : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
+
+/// Enum attribute of the different element types.
+def XeVM_ET_BF16 : I32EnumAttrCase<"BF16", 8, "bf16">;
+def XeVM_ET_F16 : I32EnumAttrCase<"F16", 9, "f16">;
+def XeVM_ET_S8 : I32EnumAttrCase<"S8", 10, "s8">;
+def XeVM_ET_U8 : I32EnumAttrCase<"U8", 11, "u8">;
+def XeVM_ET_S4 : I32EnumAttrCase<"S4", 12, "s4">;
+def XeVM_ET_U4 : I32EnumAttrCase<"U4", 13, "u4">;
+def XeVM_ET_TF32 : I32EnumAttrCase<"TF32", 14, "tf32">;
+def XeVM_ET_F32 : I32EnumAttrCase<"F32", 15, "f32">;
+def XeVM_ET_S32 : I32EnumAttrCase<"S32", 16, "s32">;
+
+def XeVM_ElemTypeAttr : I32EnumAttr<"ElemType", "XeVM element type",
+                                    [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8,
+                                     XeVM_ET_U8, XeVM_ET_S4, XeVM_ET_U4,
+                                     XeVM_ET_TF32, XeVM_ET_F32, XeVM_ET_S32]> {
+  let cppNamespace = "::mlir::xevm";
+}
+
+def XeVM_MMAShapeAttr : XeVM_Attr<"MMAShape", "mma_shape"> {
+  let description = [{
+    The MMA operation is represented as D = A x B + C, where
+      A has the shape MxK,
+      B has the shape KxN, and
+      D and C have the shape MxN.
+    This attribute encodes the shape of all matrices that participate in the MMA.
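+
+    Example (as used by `xevm.mma` in the tests): `shape = <m = 8, n = 16, k = 16>`
+    encodes a multiply-accumulate with M = 8, N = 16, and K = 16.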
+  }];
+  let parameters = (ins "int":$m, "int":$n, "int":$k);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeVM_MMATypesAttr : XeVM_Attr<"MMATypes", "mma_types"> {
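+  let description = [{
+    Encodes the element types of the D, A, B and (optionally) C matrices of an
+    `xevm.mma` operation, e.g. `types = <d = f32, a = f16, b = f16, c = f32>`.
+  }];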
+  let parameters = (ins "xevm::ElemType":$d, "xevm::ElemType":$a,
+      "xevm::ElemType":$b, OptionalParameter<"xevm::ElemType">:$c);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeVM_MMAOp
+    : XeVM_Op<"mma">,
+      Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
+      Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
+          FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
+          Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
+          XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
+
+  let summary = "Subgroup matrix multiply-add";
+
+  let description = [{
+    The `xevm.mma` operation is a cooperative operation in which all
+    threads/lanes in a subgroup participate to carry out a matrix
+    multiply-accumulate:
+
+      D = C + A x B
+
+      where the A, B, C input matrices and the result D have shapes:
+        D : MxN
+        C : MxN
+        A : MxK
+        B : KxN
+
+    Parameters:
+      `a` - vector of matrix A elements.
+      `b` - vector of matrix B elements.
+      `c` - (optional) vector of matrix C elements.
+      `shape` - the shape of the matrices, specified as `M`, `N`, and `K` values.
+      `types` - the data types of the matrices, specified as `D`, `A`, `B`, and optionally `C`.
+
+    Example:
+    ```mlir
+      %d = xevm.mma %a, %b, %c { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $a `,` $b (`,` $c^)? ` `
+    `{`
+      `shape` `=` $shape `,`
+      `types` `=` $types
+    `}` attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// XeVM target attribute.
+//===----------------------------------------------------------------------===//
+
+def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
+  let description = [{
+    GPU target attribute for controlling compilation of Intel GPU targets. All
+    parameters decay into default values if not present.
+
+    Examples:
+
+    1. Target with default values.
+    ```
+      gpu.module @mymodule [#xevm.target] attributes {...} {
+        ...
+      }
+    ```
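+
+    2. Target with an explicit optimization level and chip, as used in the
+       dialect tests:
+    ```
+      gpu.module @mymodule [#xevm.target<O = 3, chip = "pvc">] attributes {...} {
+        ...
+      }
+    ```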
+  }];
+  let parameters =
+      (ins DefaultValuedParameter<"int", "2",
+                                  "Optimization level to apply.">:$O,
+          StringRefParameter<"Target triple.",
+                             "\"spirv64-unknown-unknown\"">:$triple,
+          StringRefParameter<"Target chip.", "\"bmg\"">:$chip,
+          OptionalParameter<"::mlir::DictionaryAttr",
+                            "Target specific flags.">:$flags,
+          OptionalParameter<"::mlir::ArrayAttr",
+                            "Files to link to the LLVM module.">:$linkFiles);
+  let assemblyFormat = [{
+    (`<` struct($O, $triple, $chip, $flags, $linkFiles)^ `>`)?
+  }];
+  let builders = [AttrBuilder<
+      (ins CArg<"int", "2">:$optLevel,
+          CArg<"::llvm::StringRef", "\"spirv64-unknown-unknown\"">:$triple,
+          CArg<"::llvm::StringRef", "\"bmg\"">:$chip,
+          CArg<"::mlir::DictionaryAttr", "nullptr">:$targetFlags,
+          CArg<"::mlir::ArrayAttr", "nullptr">:$linkFiles),
+      [{
+      return Base::get($_ctxt, optLevel, triple, chip, targetFlags, linkFiles);
+    }]>];
+  let skipDefaultBuilders = 1;
+  let genVerifyDecl = 1;
+}
+
+#endif // XEVMIR_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 261b0e00bdf86..c6fcf1a0d510b 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -46,6 +46,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
@@ -152,7 +153,8 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   ub::UBDialect,
                   vector::VectorDialect,
                   x86vector::X86VectorDialect,
-                  xegpu::XeGPUDialect>();
+                  xegpu::XeGPUDialect,
+                  xevm::XeVMDialect>();
   // clang-format on
 
   // Register all external models.
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index d83fd3800eb91..67081ca61e6e5 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -110,3 +110,25 @@ add_mlir_dialect_library(MLIRVCIXDialect
   MLIRLLVMDialect
   MLIRSideEffectInterfaces
   )
+
+add_mlir_dialect_library(MLIRXeVMDialect
+  IR/XeVMDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+  DEPENDS
+  MLIRGPUCompilationAttrInterfacesIncGen
+  MLIRXeVMOpsIncGen
+  MLIRXeVMConversionsIncGen
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  AsmParser
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRSideEffectInterfaces
+)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
new file mode 100644
index 0000000000000..afb14666f06be
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -0,0 +1,381 @@
+//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::xevm;
+
+#include <mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc>
+#include <mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc>
+
+namespace {
+constexpr uint32_t subgroupSize = 16;
+
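+/// Common structural checks shared by the 2D block load/store/prefetch ops:
+/// the base pitch must cover the base width, elem_size_in_bits must be 8, 16,
+/// or 32, and tile_height and v_blocks must be powers of two within the
+/// supported range.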
+template <typename Op>
+LogicalResult verifyMatrixInput(Op op) {
+  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
+                                BlockPrefetch2dOp>::value,
+                "Unexpected template parameter");
+
+  std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
+  std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
+  if (pitch && width && *pitch < *width)
+    return op->emitOpError(
+        "4th operand (base pitch) should be >= 2nd operand (base width)");
+
+  uint32_t elemSize = op.getElemSizeInBits();
+  if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
+    return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
+
+  uint32_t tileHeight = op.getTileHeight();
+  if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
+    return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
+    return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
+
+  return success();
+}
+
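+/// Checks the tile shape / element size / v_blocks combinations allowed for a
+/// 2D block load, separately for the plain, transposed, and packed
+/// (pack_register) variants, and that the result vector size matches the
+/// requested tile.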
+LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
+  VectorType resTy = op.getRes().getType();
+  if (!resTy.getElementType().isIntOrFloat())
+    return op.emitOpError()
+           << "expecting result element type to be int or float";
+  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+  unsigned resSize = resTy.getNumElements() * resElemTySize;
+  unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
+                          op.getTileWidth() * op.getVBlocks() / subgroupSize;
+  if (resSize != expectedSize)
+    return op.emitOpError() << "result size of " << resSize
+                            << " bits does not match the expected size of "
+                            << expectedSize << " bits";
+
+  if (op.getTranspose() && op.getPackRegister())
+    return op.emitOpError(
+        "transpose and vnni_transform are mutually exclusive");
+
+  if (!op.getTranspose() && !op.getPackRegister()) {
+    uint32_t tileHeight = op.getTileHeight();
+    if (tileHeight < 1 || tileHeight > 32)
+      return op.emitOpError("expecting tile_height to be between 1 and 32");
+
+    uint32_t tileWidth = op.getTileWidth();
+    uint32_t vBlocks = op.getVBlocks();
+    switch (op.getElemSizeInBits()) {
+    case 8:
+      if (tileWidth < 4 || tileWidth > 64)
+        return op.emitOpError("expecting tile_width to be between 4 and 64");
+      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+        return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+      if (tileWidth * vBlocks > 64)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 64 for 8 bit elements");
+      break;
+    case 16:
+      if (tileWidth < 2 || tileWidth > 32)
+        return op.emitOpError("expecting tile_width to be between 2 and 32");
+      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+        return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+      if (tileWidth * vBlocks > 32)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 32 for 16 bit elements");
+      break;
+    case 32:
+      if (tileWidth < 1 || tileWidth > 16)
+        return op.emitOpError("expecting tile_width to be between 1 and 16");
+      if (vBlocks != 1 && vBlocks != 2)
+        return op.emitOpError("expecting v_blocks to be 1 or 2");
+      if (tileWidth * vBlocks > 16)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 16 for 32 bit elements");
+      break;
+    case 64:
+      if (tileWidth < 1 || tileWidth > 8)
+        return op.emitOpError("expecting tile_width to be between 1 and 8");
+      if (vBlocks != 1)
+        return op.emitOpError("expecting v_blocks to be 1");
+      break;
+    default:
+      return op.emitOpError(
+          "expecting elem_size_in_bits to be 8, 16, 32, or 64");
+    }
+
+    return success();
+  }
+
+  if (op.getTranspose()) {
+    assert(!op.getPackRegister() && "Expecting vnni_transform should be false");
+
+    uint32_t vBlocks = op.getVBlocks();
+    if (vBlocks != 1)
+      return op.emitOpError("expecting v_blocks to be 1");
+
+    uint32_t tileHeight = op.getTileHeight();
+    uint32_t tileWidth = op.getTileWidth();
+    switch (op.getElemSizeInBits()) {
+    case 32:
+      if (tileHeight < 1 || tileHeight > 32)
+        return op.emitOpError("expecting tile_height to be between 1 and 32");
+      if (tileWidth < 1 || tileWidth > 8)
+        return op.emitOpError("expecting tile_width to be between 1 and 8");
+      break;
+    case 64:
+      if (tileHeight != 8)
+        return op.emitOpError(
+            "expecting tile_height to be 8 for 64 bit elements");
+      if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
+        return op.emitOpError("expecting tile_width to be 1, 2, or 4");
+      break;
+    default:
+      return op.emitOpError("transpose is only supported for 32 and 64 bit "
+                            "elements");
+    }
+
+    return success();
+  }
+
+  assert(op.getPackRegister() && !op.getTranspose() &&
+         "Expecting vnni_transform should be true and transpose should be "
+         "false");
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+    return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+
+  uint32_t tileHeight = op.getTileHeight();
+  uint32_t tileWidth = op.getTileWidth();
+  switch (op.getElemSizeInBits()) {
+  case 8:
+    if (tileHeight < 4 || tileHeight > 32)
+      return op.emitOpError("expecting tile_height to be between 4 and 32");
+    if (tileWidth < 4 || tileWidth > 16)
+      return op.emitOpError("expecting tile_width to be between 4 and 16");
+    break;
+  case 16:
+    if (tileHeight < 2 || tileHeight > 32)
+      return op.emitOpError("expecting tile_height to be between 2 and 32");
+    if (tileWidth < 2 || tileWidth > 16)
+      return op.emitOpError("expecting tile_width to be between 2 and 16");
+    if (tileWidth * vBlocks > 32)
+      return op.emitOpError(
+          "tile_width * v_blocks should be less than or equal "
+          "to 32 for 16 bit elements");
+    break;
+  default:
+    return op.emitOpError("vnni_transform is only supported for 8 and 16 bit "
+                          "elements");
+  }
+
+  return success();
+}
+
+static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
+  uint32_t tileHeight = op.getTileHeight();
+  if (tileHeight < 1 || tileHeight > 8)
+    return op.emitOpError("expecting tile_height to be between 1 and 8");
+
+  uint32_t tileWidth = op.getTileWidth();
+  switch (op.getElemSizeInBits()) {
+  case 8:
+    if (tileWidth < 4 || tileWidth > 64)
+      return op.emitOpError("expecting tile_width to be between 4 and 64");
+    break;
+  case 16:
+    if (tileWidth < 2 || tileWidth > 32)
+      return op.emitOpError("expecting tile_width to be between 2 and 32");
+    break;
+  case 32:
+    if (tileWidth < 1 || tileWidth > 16)
+      return op.emitOpError("expecting tile_width to be between 1 and 16");
+    break;
+  case 64:
+    if (tileWidth < 1 || tileWidth > 8)
+      return op.emitOpError("expecting tile_width to be between 1 and 8");
+    break;
+  default:
+    return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
+  }
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks != 1)
+    return op.emitOpError("expecting v_blocks to be 1");
+  return success();
+}
+
+} // namespace
+
+LogicalResult BlockLoad2dOp::verify() {
+  if (verify2DBlockLoadRestriction(*this).failed())
+    return failure();
+
+  if (verifyMatrixInput(*this).failed())
+    return failure();
+
+  VectorType resTy = getRes().getType();
+  if (!resTy.getElementType().isIntOrFloat())
+    return emitOpError() << "expecting result element type to be int of float";
+  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+  if (getElemSizeInBits() == 32 || getPackRegister()) {
+    if (resElemTySize != 32)
+      return emitOpError() << "expecting result element type to be 32 bits";
+  }
+
+  uint32_t tileWidth = getTileWidth();
+  if (getPackRegister()) {
+    if (tileWidth != 16)
+      return emitOpError(
+          "tile_width when vnni_transform is true should be equal "
+          "to subgroup size (16 elements)");
+    return success();
+  }
+
+  return success();
+}
+
+LogicalResult BlockStore2dOp::verify() {
+  if (verify2DBlockStoreRestriction(*this).failed())
+    return failure();
+
+  if (verifyMatrixInput(*this).failed())
+    return failure();
+
+  uint32_t tileWidth = getTileWidth();
+  switch (getElemSizeInBits()) {
+  case 8:
+    if (tileWidth != 16 && tileWidth != 32)
+      return emitOpError("tile_width for 8 bit elements should be equal to "
+                         "16 or 32");
+    break;
+  case 16:
+    if (tileWidth != 16)
+      return emitOpError("tile_width for 16 bit elements should be equal "
+                         "to 16");
+    break;
+  case 32:
+    if (tileWidth != 16)
+      return emitOpError("tile_width for 32 bit elements should be equal "
+                         "to 16");
+    break;
+  default:
+    llvm_unreachable("unexpected element size");
+  }
+
+  return success();
+}
+
+LogicalResult BlockPrefetch2dOp::verify() {
+  if (verifyMatrixInput(*this).failed())
+    return failure();
+
+  uint32_t tileWidth = getTileWidth();
+  switch (getElemSizeInBits()) {
+  case 8:
+    if (tileWidth != 16 && tileWidth != 32)
+      return emitOpError("tile_width for 8 bit elements should be equal to "
+                         "16 or 32");
+    break;
+  case 16:
+    if (tileWidth != 16)
+      return emitOpError("tile_width for 16 bit elements should be equal "
+                         "to 16");
+    break;
+  case 32:
+    if (tileWidth != 8 && tileWidth != 16)
+      return emitOpError(
+          "tile_width for 32 bit elements should be equal to 8 or 16");
+    break;
+  default:
+    llvm_unreachable("unexpected element size");
+  }
+
+  return success();
+}
+
+LogicalResult MMAOp::verify() {
+  if (getC()) {
+    if (getResult().getType() != getC().getType())
+      return emitOpError("type of C operand must match result type");
+  }
+  return success();
+}
+
+LogicalResult PrefetchOp::verify() {
+  auto addrSpace = getAddrspace();
+  if (addrSpace != AddrSpace::GLOBAL && addrSpace != AddrSpace::GENERIC) {
+    return emitOpError("address space must be global or generic");
+  }
+  return success();
+}
+
+LogicalResult
+XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
+                       StringRef triple, StringRef chip, DictionaryAttr flags,
+                       ArrayAttr linkFiles) {
+  if (O < 0 || O > 3) {
+    emitError() << "The optimization level must be a number between 0 and 3.";
+    return failure();
+  }
+  if (triple.empty()) {
+    emitError() << "The target triple cannot be empty.";
+    return failure();
+  }
+  if (chip.empty()) {
+    emitError() << "The target chip cannot be empty.";
+    return failure();
+  }
+  if (linkFiles) {
+    for (Attribute fileAttr : linkFiles) {
+      if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
+        StringRef filePath = fileStrAttr.getValue();
+        if (filePath.empty()) {
+          emitError() << "File paths in linkFiles cannot be empty.";
+          return failure();
+        }
+        if (!llvm::sys::fs::exists(filePath)) {
+          emitError() << "File '" << filePath << "' does not exist.";
+          return failure();
+        }
+      }
+    }
+  }
+  return success();
+}
+
+void XeVMDialect::initialize() {
+  // NOLINTBEGIN
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
+      >();
+
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
+      >();
+  // NOLINTEND
+  declarePromisedInterface<mlir::gpu::TargetAttrInterface,
+                           mlir::xevm::XeVMTargetAttr>();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 251ca716c7a7a..3a17926f4778c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1875,7 +1875,7 @@ llvm.mlir.global @bad_struct_array_init_elements() : !llvm.array<1x!llvm.struct<
   llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>>
 }
 
-// ----
+// -----
 
 llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2 x f64> {
   // expected-error at below {{'llvm.mlir.constant' op for array with an array attribute must have a struct element type}}
@@ -1883,10 +1883,51 @@ llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2
   llvm.return %0 : !llvm.array<2 x f64>
 }
 
-// ----
+// -----
 
 llvm.func @inlineAsmMustTail(%arg0: i32, %arg1 : !llvm.ptr) {
   // expected-error at +1 {{op tail call kind 'musttail' is not supported}}
   %8 = llvm.inline_asm tail_call_kind = <musttail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
   llvm.return
 }
+
+// -----
+
+llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr<1>) {
+  // expected-error at +1 {{op address space must be global or generic}}
+  xevm.prefetch %arg0 <{addrspace = #xevm.addr_space<private>, cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_mma(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
+  // expected-error at +1 {{op type of C operand must match result type}}
+  %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>} : (vector<8xi16>, vector<8xi32>, vector<4xf32>) -> vector<8xf32>
+  llvm.return %c_result : vector<8xf32>
+}
+
+// -----
+
+llvm.func @invalid_xevm_matrix_1(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+  // expected-error at +1 {{op expecting tile_width to be between 1 and 8}}
+  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=64 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_matrix_2(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+  // expected-error at +1 {{op expecting elem_size_in_bits to be 8, 16, 32, or 64}}
+  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=18 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_matrix_3(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+  // expected-error at +1 {{op result size of 128 bits does not match the expected size of 208 bits}}
+  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=26 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+  llvm.return %loaded_a : vector<8xi16>
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir
new file mode 100644
index 0000000000000..3ea6768345f6e
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/xevm.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: func.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+  // CHECK: %[[VAR0:.*]] = xevm.blockload2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
+  // CHECK-DAG: elem_size_in_bits = 16 : i32
+  // CHECK-DAG: tile_width = 16 : i32
+  // CHECK-DAG: tile_height = 8 : i32
+  // CHECK-DAG: v_blocks = 1 : i32
+  // CHECK-DAG: transpose = false
+  // CHECK-DAG: pack_register = false
+  // CHECK-DAG: cache_control = #xevm.load_cache_control<Default>
+  // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+  return %loaded_a : vector<8xi16>
+}
+
+// -----
+// CHECK: func.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>)
+func.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+  // CHECK: xevm.blockstore2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]] 
+  // CHECK-DAG: elem_size_in_bits = 32 : i32
+  // CHECK-DAG: tile_width = 16 : i32
+  // CHECK-DAG: tile_height = 8 : i32
+  // CHECK-DAG: cache_control = #xevm.store_cache_control<Default>
+  // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+  return
+}
+
+// -----
+// CHECK: func.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) {
+  // CHECK: xevm.blockprefetch2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
+  // CHECK-DAG: elem_size_in_bits = 8 : i32
+  // CHECK-DAG: tile_width = 32 : i32
+  // CHECK-DAG: tile_height = 8 : i32
+  // CHECK-DAG: v_blocks = 1 : i32
+  // CHECK-DAG: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
+  // CHECK:  (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+  xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+  return
+}
+
+// -----
+// CHECK: func.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>)
+func.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
+  // CHECK: %0 = xevm.mma %[[ARG1]], %[[ARG2]], %[[ARG0]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+  %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+  return %c_result : vector<8xf32>
+}
+
+// -----
+func.func @memfence() {
+  // CHECK: xevm.memfence
+  // CHECK-DAG: addrspace = #xevm.addr_space<global>
+  // CHECK-DAG: scope = #xevm.mem_scope<workgroup>
+  xevm.memfence <{addrspace=#xevm.addr_space<global>, scope=#xevm.mem_scope<workgroup>}>
+  return
+}
+
+// -----
+// CHECK: func.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>)
+func.func @prefetch(%ptr: !llvm.ptr<1>) {
+  // CHECK: xevm.prefetch %[[ARG0]]
+  // CHECK-DAG: addrspace = #xevm.addr_space<global>
+  // CHECK-DAG: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
+  // CHECK: (!llvm.ptr<1>)
+  xevm.prefetch %ptr <{addrspace = #xevm.addr_space<global>, cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+  return
+}
+
+// -----
+// CHECK: @xevm_module [#xevm.target<O = 3, chip = "pvc">] {
gpu.module @xevm_module [#xevm.target<O = 3, chip = "pvc">] {
+}
diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
index 4ca5974ed5a49..418c884dc03b3 100644
--- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
@@ -29,6 +29,7 @@ set(LIBS
   MLIRTranslateLib
   MLIRVectorDialect
   MLIRVectorToLLVMPass
+  MLIRXeVMDialect
   )
 
 add_mlir_library(MLIRGPUTestPasses


