[Mlir-commits] [mlir] [MLIR][Dialect] Add XeVM dialect (PR #144811)
Sang Ik Lee
llvmlistbot at llvm.org
Thu Jun 26 13:25:29 PDT 2025
https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/144811
>From 75bc544d5ff42cefe7dceddf021450881689ce9b Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 18 Jun 2025 22:48:04 +0000
Subject: [PATCH 1/8] Add XeVM dialect.
---
.../mlir/Dialect/LLVMIR/CMakeLists.txt | 10 +
.../include/mlir/Dialect/LLVMIR/XeVMDialect.h | 28 +
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 550 ++++++++++++++++++
mlir/include/mlir/InitAllDialects.h | 4 +-
mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 22 +
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 381 ++++++++++++
mlir/test/Dialect/LLVMIR/invalid.mlir | 45 +-
mlir/test/Dialect/LLVMIR/xevm.mlir | 76 +++
mlir/test/lib/Dialect/GPU/CMakeLists.txt | 1 +
9 files changed, 1114 insertions(+), 3 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h
create mode 100644 mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
create mode 100644 mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
create mode 100644 mlir/test/Dialect/LLVMIR/xevm.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 9c5bbae1022f7..cfad07e57021f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -87,3 +87,13 @@ mlir_tablegen(VCIXConversions.inc -gen-llvmir-conversions)
mlir_tablegen(VCIXOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=vcix)
mlir_tablegen(VCIXOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=vcix)
add_public_tablegen_target(MLIRVCIXConversionsIncGen)
+
+add_mlir_dialect(XeVMOps xevm)
+add_mlir_doc(XeVMOps XeVMDialect Dialects/ -gen-dialect-doc -dialect=xevm)
+set(LLVM_TARGET_DEFINITIONS XeVMOps.td)
+mlir_tablegen(XeVMConversions.inc -gen-llvmir-conversions)
+mlir_tablegen(XeVMOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(XeVMOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(XeVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=xevm)
+mlir_tablegen(XeVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xevm)
+add_public_tablegen_target(MLIRXeVMConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h
new file mode 100644
index 0000000000000..a83d4248c862c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h
@@ -0,0 +1,28 @@
+//===-- XeVMDialect.h - MLIR XeVM target definitions ------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include <mlir/Dialect/LLVMIR/XeVMOpsEnums.h.inc>
+
+#define GET_ATTRDEF_CLASSES
+#include <mlir/Dialect/LLVMIR/XeVMOpsAttributes.h.inc>
+
+#define GET_OP_CLASSES
+#include <mlir/Dialect/LLVMIR/XeVMOps.h.inc>
+
+#include <mlir/Dialect/LLVMIR/XeVMOpsDialect.h.inc>
+
+#endif /* MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_ */
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
new file mode 100644
index 0000000000000..9525c4a731efa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -0,0 +1,550 @@
+//===-- XeVMOps.td - XeVM dialect definition ---------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef XEVMIR_OPS
+#define XEVMIR_OPS
+
+include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+
+def XeVM_Dialect : Dialect {
+ let name = "xevm";
+ let cppNamespace = "::mlir::xevm";
+ let dependentDialects = ["LLVM::LLVMDialect"];
+
+ let extraClassDeclaration = [{
+ /// Get the name for the attribute used to specify cache control
+ /// decorations.
+ static constexpr ::llvm::StringRef getCacheControlsAttrName() {
+ return ::llvm::StringLiteral("xevm.DecorationCacheControl");
+ }
+ }];
+
+ let useDefaultAttributePrinterParser = 1;
+}
+
+class XeVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+ : AttrDef<XeVM_Dialect, attrName, traits> {
+ let mnemonic = attrMnemonic;
+}
+
+class XeVM_Op<string mnemonic, list<Trait> traits = []>
+ : Op<XeVM_Dialect, mnemonic, traits> {
+
+ code extraBaseClassDeclaration = [{
+ void printProperties(::mlir::MLIRContext *ctx,
+ ::mlir::OpAsmPrinter &p, const Properties &prop,
+ ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
+ Attribute propAttr = getPropertiesAsAttr(ctx, prop);
+ if (propAttr)
+ p << "<" << propAttr << ">";
+ }
+
+ static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ if (mlir::succeeded(parser.parseOptionalLess())) {
+ if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
+ return failure();
+ }
+ return success();
+ }
+
+ }];
+}
+
+def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
+
+def LoadCacheControlDefault : I32EnumAttrCase<"DEFAULT", 0, "Default">;
+def LoadCacheControl_L1uc_L2uc_L3uc
+ : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
+def LoadCacheControl_L1uc_L2uc_L3c
+ : I32EnumAttrCase<"L1UC_L2UC_L3C", 2, "L1uc_L2uc_L3c">;
+def LoadCacheControl_L1uc_L2c_L3uc
+ : I32EnumAttrCase<"L1UC_L2C_L3UC", 3, "L1uc_L2c_L3uc">;
+def LoadCacheControl_L1uc_L2c_L3c
+ : I32EnumAttrCase<"L1UC_L2C_L3C", 4, "L1uc_L2c_L3c">;
+def LoadCacheControl_L1c_L2uc_L3uc
+ : I32EnumAttrCase<"L1C_L2UC_L3UC", 5, "L1c_L2uc_L3uc">;
+def LoadCacheControl_L1c_L2uc_L3c
+ : I32EnumAttrCase<"L1C_L2UC_L3C", 6, "L1c_L2uc_L3c">;
+def LoadCacheControl_L1c_L2c_L3uc
+ : I32EnumAttrCase<"L1C_L2C_L3UC", 7, "L1c_L2c_L3uc">;
+def LoadCacheControl_L1c_L2c_L3c
+ : I32EnumAttrCase<"L1C_L2C_L3C", 8, "L1c_L2c_L3c">;
+def LoadCacheControl_L1s_L2uc_L3uc
+ : I32EnumAttrCase<"L1S_L2UC_L3UC", 9, "L1s_L2uc_L3uc">;
+def LoadCacheControl_L1s_L2uc_L3c
+ : I32EnumAttrCase<"L1S_L2UC_L3C", 10, "L1s_L2uc_L3c">;
+def LoadCacheControl_L1s_L2c_L3uc
+ : I32EnumAttrCase<"L1S_L2C_L3UC", 11, "L1s_L2c_L3uc">;
+def LoadCacheControl_L1s_L2c_L3c
+ : I32EnumAttrCase<"L1S_L2C_L3C", 12, "L1s_L2c_L3c">;
+def LoadCacheControlInvalidateRead
+ : I32EnumAttrCase<"INVALIDATE_READ", 13, "ir">;
+
+def XeVM_LoadCacheControl
+ : I32EnumAttr<
+ "LoadCacheControl", "XeVM load ops cache control",
+ [LoadCacheControlDefault, LoadCacheControl_L1uc_L2uc_L3uc,
+ LoadCacheControl_L1uc_L2uc_L3c, LoadCacheControl_L1uc_L2c_L3uc,
+ LoadCacheControl_L1uc_L2c_L3c, LoadCacheControl_L1c_L2uc_L3uc,
+ LoadCacheControl_L1c_L2uc_L3c, LoadCacheControl_L1c_L2c_L3uc,
+ LoadCacheControl_L1c_L2c_L3c, LoadCacheControl_L1s_L2uc_L3uc,
+ LoadCacheControl_L1s_L2uc_L3c, LoadCacheControl_L1s_L2c_L3uc,
+ LoadCacheControl_L1s_L2c_L3c, LoadCacheControlInvalidateRead]> {
+ let cppNamespace = "::mlir::xevm";
+ let genSpecializedAttr = 0;
+}
+
+def XeVM_LoadCacheControlAttr
+ : EnumAttr<XeVM_Dialect, XeVM_LoadCacheControl, "load_cache_control"> {
+ let summary = [{Describe the cache settings for load operators}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def StoreCacheControlDefault : I32EnumAttrCase<"DEFAULT", 0, "Default">;
+def StoreCacheControl_L1uc_L2uc_L3uc
+ : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
+def StoreCacheControl_L1uc_L2uc_L3wb
+ : I32EnumAttrCase<"L1UC_L2UC_L3WB", 2, "L1uc_L2uc_L3wb">;
+def StoreCacheControl_L1uc_L2wb_L3uc
+ : I32EnumAttrCase<"L1UC_L2WB_L3UC", 3, "L1uc_L2wb_L3uc">;
+def StoreCacheControl_L1uc_L2wb_L3wb
+ : I32EnumAttrCase<"L1UC_L2WB_L3WB", 4, "L1uc_L2wb_L3wb">;
+def StoreCacheControl_L1wt_L2uc_L3uc
+ : I32EnumAttrCase<"L1WT_L2UC_L3UC", 5, "L1wt_L2uc_L3uc">;
+def StoreCacheControl_L1wt_L2uc_L3wb
+ : I32EnumAttrCase<"L1WT_L2UC_L3WB", 6, "L1wt_L2uc_L3wb">;
+def StoreCacheControl_L1wt_L2wb_L3uc
+ : I32EnumAttrCase<"L1WT_L2WB_L3UC", 7, "L1wt_L2wb_L3uc">;
+def StoreCacheControl_L1wt_L2wb_L3wb
+ : I32EnumAttrCase<"L1WT_L2WB_L3WB", 8, "L1wt_L2wb_L3wb">;
+def StoreCacheControl_L1s_L2uc_L3uc
+ : I32EnumAttrCase<"L1S_L2UC_L3UC", 9, "L1s_L2uc_L3uc">;
+def StoreCacheControl_L1s_L2uc_L3wb
+ : I32EnumAttrCase<"L1S_L2UC_L3WB", 10, "L1s_L2uc_L3wb">;
+def StoreCacheControl_L1s_L2wb_L3uc
+ : I32EnumAttrCase<"L1S_L2WB_L3UC", 11, "L1s_L2wb_L3uc">;
+def StoreCacheControl_L1s_L2wb_L3wb
+ : I32EnumAttrCase<"L1S_L2WB_L3WB", 12, "L1s_L2wb_L3wb">;
+def StoreCacheControl_L1wb_L2uc_L3uc
+ : I32EnumAttrCase<"L1WB_L2UC_L3UC", 13, "L1wb_L2uc_L3uc">;
+def StoreCacheControl_L1wb_L2wb_L3uc
+ : I32EnumAttrCase<"L1WB_L2WB_L3UC", 14, "L1wb_L2wb_L3uc">;
+def StoreCacheControl_L1wb_L2uc_L3wb
+ : I32EnumAttrCase<"L1WB_L2UC_L3WB", 15, "L1wb_L2uc_L3wb">;
+
+def XeVM_StoreCacheControl
+ : I32EnumAttr<
+ "StoreCacheControl", "XeVM store ops cache control",
+ [StoreCacheControlDefault, StoreCacheControl_L1uc_L2uc_L3uc,
+ StoreCacheControl_L1uc_L2uc_L3wb, StoreCacheControl_L1uc_L2wb_L3uc,
+ StoreCacheControl_L1uc_L2wb_L3wb, StoreCacheControl_L1wt_L2uc_L3uc,
+ StoreCacheControl_L1wt_L2uc_L3wb, StoreCacheControl_L1wt_L2wb_L3uc,
+ StoreCacheControl_L1wt_L2wb_L3wb, StoreCacheControl_L1s_L2uc_L3uc,
+ StoreCacheControl_L1s_L2uc_L3wb, StoreCacheControl_L1s_L2wb_L3uc,
+ StoreCacheControl_L1s_L2wb_L3wb, StoreCacheControl_L1wb_L2uc_L3uc,
+ StoreCacheControl_L1wb_L2wb_L3uc,
+ StoreCacheControl_L1wb_L2uc_L3wb]> {
+ let cppNamespace = "::mlir::xevm";
+ let genSpecializedAttr = 0;
+}
+
+def XeVM_StoreCacheControlAttr
+ : EnumAttr<XeVM_Dialect, XeVM_StoreCacheControl, "store_cache_control"> {
+ let summary = [{Describe the cache settings for store operators}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def XeVM_BlockLoad2dOp
+ : XeVM_Op<"blockload2d">,
+ Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>,
+ Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
+ I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
+ I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
+ I32Attr:$v_blocks, I1Attr:$transpose, I1Attr:$pack_register,
+ OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
+
+ let summary = "2D block load";
+
+ let description = [{
+ The `xevm.blockload2d` operation loads a two dimensional matrix tile
+ from a base matrix residing in memory. The parameters are:
+ $ptr - the base address of the base matrix containing the tile to load
+ $base_width, $base_height, $base_pitch - the shape of the base matrix.
+ pitch is the physical stride between the first columns of the current row
+ and the subsequent row. All units are in bytes.
+ $x, $y, $tile_width, $tile_height - the starting offsets and shape of
+ the tile to load in number of elements.
+ $elem_size_in_bits - the size in bits of the matrix element type
+ - 32 for f32, tf32
+ - 16 for f16, int16, bf16
+ - 8 for int8
+ $v_blocks - number of consecutive tiles in innermost dimension direction to load
+ $transpose - transpose the tile in registers (useful for 32 bit element type)
+ $pack_register - pack element types narrower than register bit width. [M, N] => [M/factor, N, factor] where factor is register_size_in_bits / elem_size_in_bits
+ $cache_control - an enumerator that sets the cache behaviour
+
+ Notes:
+ - the $transpose and $pack_register parameters are mutually exclusive
+ - transposing the tile loaded is used for A matrix in backward path or used for the B matrix operand
+ (D = C + A * B), where A has row-major layout and B should have column-major layout in memory.
+ - if the tile loaded contains out of bound elements of the matrix, they are filled with 0.
+
+ Example:
+ ```mlir
+ %base_width_a = arith.constant 32 : i32
+ %base_height_a = arith.constant 8 : i32
+ %base_pitch_a = arith.constant 32 : i32
+ %x = arith.constant 0 : i32
+ %y = arith.constant 0 : i32
+ %loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ operands prop-dict attr-dict `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ /// Get cache control or return default if not set.
+ ::mlir::xevm::LoadCacheControl getCacheControlOrDefault() {
+ if(getCacheControl())
+ return *getCacheControl();
+ return ::mlir::xevm::LoadCacheControl::DEFAULT;
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+def XeVM_BlockStore2dOp
+ : XeVM_Op<"blockstore2d">,
+ Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr, I32:$base_width,
+ I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
+ I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
+ FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$stored_val,
+ OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
+
+ let summary = "2D block store";
+
+ let description = [{
+ The `xevm.blockstore2d` operation stores a two dimensional tile into a
+ larger matrix residing in memory. The parameters are:
+ $ptr - the base address of the target matrix where to store the tile
+ $base_width, $base_height, $base_pitch - the shape of the target matrix. pitch is the
+ physical stride between the first columns of the current row and the subsequent row.
+ All units are in bytes.
+ $x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to store
+ in number of elements.
+ $elem_size_in_bits - the size in bits of the matrix element
+ - 32 for f32, tf32
+ - 16 for f16, int16, bf16
+ - 8 for int8
+ $cache_control - an enumerator that sets the cache behaviour
+ $stored_val - the tile to store
+
+ Example:
+ ```mlir
+ %base_width_c = arith.constant 64 : i32
+ %base_height_c = arith.constant 8 : i32
+ %base_pitch_c = arith.constant 64 : i32
+ %x = arith.constant 0 : i32
+ %y = arith.constant 0 : i32
+ xevm.blockstore2d %dst, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %src <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ ```
+ }];
+
+ let assemblyFormat = [{
+ operands prop-dict attr-dict `:` `(` type(operands) `)`
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ /// Get cache control or return default if not set.
+ ::mlir::xevm::StoreCacheControl getCacheControlOrDefault() {
+ if(getCacheControl())
+ return *getCacheControl();
+ return ::mlir::xevm::StoreCacheControl::DEFAULT;
+ }
+
+ /// Default value for v_blocks is 1.
+ constexpr uint32_t getVBlocks() {
+ return 1;
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+def MemScopeLane : I32EnumAttrCase<"LANE", 0, "lane">;
+def MemScopeSg : I32EnumAttrCase<"SUBGROUP", 1, "subgroup">;
+def MemScopeWg : I32EnumAttrCase<"WORKGROUP", 2, "workgroup">;
+def MemScopeCluster : I32EnumAttrCase<"CLUSTER", 3, "cluster">;
+def MemScopeDevice : I32EnumAttrCase<"DEVICE", 4, "device">;
+def MemScopeSystem : I32EnumAttrCase<"SYSTEM", 5, "system">;
+
+def XeVM_MemScope
+ : I32EnumAttr<"MemScope", "XeVM memory scope",
+ [MemScopeLane, MemScopeSg, MemScopeWg, MemScopeCluster,
+ MemScopeDevice, MemScopeSystem]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xevm";
+}
+def XeVM_MemScopeAttr : EnumAttr<XeVM_Dialect, XeVM_MemScope, "mem_scope"> {
+ let summary = [{Describe memory scopes}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def AddrSpacePrivate : I32EnumAttrCase<"PRIVATE", 0, "private">;
+def AddrSpaceGlobal : I32EnumAttrCase<"GLOBAL", 1, "global">;
+def AddrSpaceConstant : I32EnumAttrCase<"CONSTANT", 2, "constant">;
+def AddrSpaceShared : I32EnumAttrCase<"SHARED", 3, "shared">;
+def AddrSpaceGeneric : I32EnumAttrCase<"GENERIC", 4, "generic">;
+
+def XeVM_AddrSpace
+ : I32EnumAttr<"AddrSpace", "Address spaces",
+ [AddrSpacePrivate, AddrSpaceGlobal, AddrSpaceConstant,
+ AddrSpaceShared, AddrSpaceGeneric]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "mlir::xevm";
+}
+def XeVM_AddrSpaceAttr : EnumAttr<XeVM_Dialect, XeVM_AddrSpace, "addr_space"> {
+ let summary = [{Describe address spaces}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def XeVM_MemfenceOp
+ : XeVM_Op<"memfence">,
+ Arguments<(ins XeVM_MemScopeAttr:$scope,
+ DefaultValuedAttr<XeVM_AddrSpaceAttr,
+ "mlir::xevm::AddrSpace::GENERIC">:$addrspace)> {
+ let summary = "Work-item's memory fence.";
+ let description = [{
+ This operation ensures that all prior memory accesses of this
+ work-item to `addrspace` are visible to all other work-items in `scope`.
+ Parameters description:
+ $scope - specify the memory scope at which all other work-items should observe
+ memory operations prior to the fence.
+ $addrspace - specify the address space of work-item's memory accesses
+ to be affected by the fence.
+ }];
+ let assemblyFormat = [{prop-dict attr-dict}];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ }];
+}
+
+def XeVM_PrefetchOp
+ : XeVM_Op<"prefetch">,
+ Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
+ XeVM_AddrSpaceAttr:$addrspace,
+ OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
+ let summary = "Prefetch data into a cache subsystem.";
+ let description = [{
+ Work-item issues a prefetch from global memory to cache:
+ $ptr - memory pointer.
+ $addrspace - address space of a pointer, must be generic or global.
+ $cache_control - specify caching options
+ }];
+ let assemblyFormat = [{
+ operands prop-dict attr-dict `:` `(` type(operands) `)`
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ /// Get cache control or return default if not set.
+ ::mlir::xevm::LoadCacheControl getCacheControlOrDefault() {
+ if(getCacheControl())
+ return *getCacheControl();
+ return ::mlir::xevm::LoadCacheControl::DEFAULT;
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+def XeVM_BlockPrefetch2dOp
+ : XeVM_Op<"blockprefetch2d">,
+ Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
+ I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
+ I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
+ I32Attr:$v_blocks,
+ OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
+
+ let summary = "2D block prefetch";
+
+ let description = [{
+ The `xevm.blockprefetch2d` operation prefetches a two dimensional tile
+ from a larger base matrix residing in memory. The parameters are:
+ $ptr - the base address of the base matrix containing the tile to prefetch
+ $base_width, $base_height, $base_pitch - the shape of the base matrix.
+ pitch is the physical stride between the first columns of the current row
+ and the subsequent row. All units are in bytes.
+ $x, $y, $tile_width, $tile_height - the starting offsets and shape of tile
+ to prefetch in number of elements.
+ $elem_size_in_bits - the size in bits of the matrix element
+ - 32 for f32, tf32
+ - 16 for f16, int16, bf16
+ - 8 for int8, int4, int2
+ $v_blocks - number of tiles in innermost dimension direction to prefetch
+ $cache_control - an enumerator that sets the cache behaviour
+
+ Example:
+ ```mlir
+ xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+ ```
+ }];
+
+ let assemblyFormat = [{
+ operands prop-dict attr-dict `:` `(` type(operands) `)`
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ /// Get cache control or return default if not set.
+ ::mlir::xevm::LoadCacheControl getCacheControlOrDefault() {
+ if(getCacheControl())
+ return *getCacheControl();
+ return ::mlir::xevm::LoadCacheControl::DEFAULT;
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+def XeVM_MatrixElemType
+ : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
+
+/// Enum attribute of the different element types.
+def XeVM_ET_BF16 : I32EnumAttrCase<"BF16", 8, "bf16">;
+def XeVM_ET_F16 : I32EnumAttrCase<"F16", 9, "f16">;
+def XeVM_ET_S8 : I32EnumAttrCase<"S8", 10, "s8">;
+def XeVM_ET_U8 : I32EnumAttrCase<"U8", 11, "u8">;
+def XeVM_ET_S4 : I32EnumAttrCase<"S4", 12, "s4">;
+def XeVM_ET_U4 : I32EnumAttrCase<"U4", 13, "u4">;
+def XeVM_ET_TF32 : I32EnumAttrCase<"TF32", 14, "tf32">;
+def XeVM_ET_F32 : I32EnumAttrCase<"F32", 15, "f32">;
+def XeVM_ET_S32 : I32EnumAttrCase<"S32", 16, "s32">;
+
+def XeVM_ElemTypeAttr : I32EnumAttr<"ElemType", "XeVM element type",
+ [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8,
+ XeVM_ET_U8, XeVM_ET_S4, XeVM_ET_U4,
+ XeVM_ET_TF32, XeVM_ET_F32, XeVM_ET_S32]> {
+ let cppNamespace = "::mlir::xevm";
+}
+
+def XeVM_MMAShapeAttr : XeVM_Attr<"MMAShape", "mma_shape"> {
+ let description = [{
+ MMA operation is represented as D=AxB+C, where
+ A has the shape MxK.
+ B has the shape KxN.
+ D and C have the shape MxN.
+ This attribute encodes the shape of all matrices that participate in MMA.
+ }];
+ let parameters = (ins "int":$m, "int":$n, "int":$k);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeVM_MMATypesAttr : XeVM_Attr<"MMATypes", "mma_types"> {
+ let parameters = (ins "xevm::ElemType":$d, "xevm::ElemType":$a,
+ "xevm::ElemType":$b, OptionalParameter<"xevm::ElemType">:$c);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeVM_MMAOp
+ : XeVM_Op<"mma">,
+ Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
+ Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
+ FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
+ Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
+ XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
+
+ let summary = "Subgroup matrix multiply-add";
+
+ let description = [{
+ The `xevm.mma` is a cooperative operation where all threads/lanes in
+ a subgroup participate and carry out matrix multiplication plus accumulation:
+
+ D = C + A x B
+
+ where the A, B, C input matrices and the result D have shapes:
+ D : MxN
+ C : MxN
+ A : MxK
+ B : KxN
+
+ Parameters:
+ `a` - vector of matrix A elements.
+ `b` - vector of matrix B elements.
+ `c` - (optional) vector of matrix C elements.
+ `shape` - the shape of the matrices, specified as `M`, `N`, and `K` values.
+ `types` - the data types of the matrices, specified as `D`, `A`, `B`, and optionally `C`.
+
+ Example:
+ ```mlir
+ %d = xevm.mma %a, %b, %c { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ $a `,` $b (`,` $c^)? ` `
+ `{`
+ `shape` `=` $shape `,`
+ `types` `=` $types
+ `}` attr-dict `:` functional-type(operands, results)
+ }];
+
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// XeVM target attribute.
+//===----------------------------------------------------------------------===//
+
+def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
+ let description = [{
+ GPU target attribute for controlling compilation of Intel GPU targets. All
+ parameters decay into default values if not present.
+
+ Examples:
+
+ 1. Target with default values.
+ ```
+ gpu.module @mymodule [#xevm.target] attributes {...} {
+ ...
+ }
+ ```
+ }];
+ let parameters =
+ (ins DefaultValuedParameter<"int", "2",
+ "Optimization level to apply.">:$O,
+ StringRefParameter<"Target triple.",
+ "\"spirv64-unknown-unknown\"">:$triple,
+ StringRefParameter<"Target chip.", "\"bmg\"">:$chip,
+ OptionalParameter<"::mlir::DictionaryAttr",
+ "Target specific flags.">:$flags,
+ OptionalParameter<"::mlir::ArrayAttr",
+ "Files to link to the LLVM module.">:$linkFiles);
+ let assemblyFormat = [{
+ (`<` struct($O, $triple, $chip, $flags, $linkFiles)^ `>`)?
+ }];
+ let builders = [AttrBuilder<
+ (ins CArg<"int", "2">:$optLevel,
+ CArg<"::llvm::StringRef", "\"spirv64-unknown-unknown\"">:$triple,
+ CArg<"::llvm::StringRef", "\"bmg\"">:$chip,
+ CArg<"::mlir::DictionaryAttr", "nullptr">:$targetFlags,
+ CArg<"::mlir::ArrayAttr", "nullptr">:$linkFiles),
+ [{
+ return Base::get($_ctxt, optLevel, triple, chip, targetFlags, linkFiles);
+ }]>];
+ let skipDefaultBuilders = 1;
+ let genVerifyDecl = 1;
+}
+
+#endif // XEVMIR_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 261b0e00bdf86..c6fcf1a0d510b 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -46,6 +46,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
@@ -152,7 +153,8 @@ inline void registerAllDialects(DialectRegistry ®istry) {
ub::UBDialect,
vector::VectorDialect,
x86vector::X86VectorDialect,
- xegpu::XeGPUDialect>();
+ xegpu::XeGPUDialect,
+ xevm::XeVMDialect>();
// clang-format on
// Register all external models.
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index d83fd3800eb91..67081ca61e6e5 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -110,3 +110,25 @@ add_mlir_dialect_library(MLIRVCIXDialect
MLIRLLVMDialect
MLIRSideEffectInterfaces
)
+
+add_mlir_dialect_library(MLIRXeVMDialect
+ IR/XeVMDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+ DEPENDS
+ MLIRGPUCompilationAttrInterfacesIncGen
+ MLIRXeVMOpsIncGen
+ MLIRXeVMConversionsIncGen
+ intrinsics_gen
+
+ LINK_COMPONENTS
+ AsmParser
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRSideEffectInterfaces
+)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
new file mode 100644
index 0000000000000..afb14666f06be
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -0,0 +1,381 @@
+//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::xevm;
+
+#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
+#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
+
+namespace {
+constexpr uint32_t subgroupSize = 16;
+
+template <typename Op>
+LogicalResult verifyMatrixInput(Op op) {
+ static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
+ BlockPrefetch2dOp>::value,
+ "Unexpected template parameter");
+
+ std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
+ std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
+ if (pitch && width && *pitch < *width)
+ return op->emitOpError(
+ "4th operand (base pitch) should be >= 2nd operand (base width)");
+
+ uint32_t elemSize = op.getElemSizeInBits();
+ if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
+ return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
+
+ uint32_t tileHeight = op.getTileHeight();
+ if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
+ return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
+
+ uint32_t vBlocks = op.getVBlocks();
+ if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
+ return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
+
+ return success();
+}
+
+LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
+ VectorType resTy = op.getRes().getType();
+ if (!resTy.getElementType().isIntOrFloat())
+ return op.emitOpError()
+ << "expecting result element type to be int or float";
+ unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+ unsigned resSize = resTy.getNumElements() * resElemTySize;
+ unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
+ op.getTileWidth() * op.getVBlocks() / subgroupSize;
+ if (resSize != expectedSize)
+ return op.emitOpError() << "result size of " << resSize
+ << " bits does not match the expected size of "
+ << expectedSize << " bits";
+
+ if (op.getTranspose() && op.getPackRegister())
+ return op.emitOpError(
+ "transpose and vnni_transform are mutually exclusive");
+
+ if (!op.getTranspose() && !op.getPackRegister()) {
+ uint32_t tileHeight = op.getTileHeight();
+ if (tileHeight < 1 || tileHeight > 32)
+ return op.emitOpError("expecting tile_height to be between 1 and 32");
+
+ uint32_t tileWidth = op.getTileWidth();
+ uint32_t vBlocks = op.getVBlocks();
+ switch (op.getElemSizeInBits()) {
+ case 8:
+ if (tileWidth < 4 || tileWidth > 64)
+ return op.emitOpError("expecting tile_width to be between 4 and 64");
+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+ return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+ if (tileWidth * vBlocks > 64)
+ return op.emitOpError(
+ "tile_width * v_blocks should be less than or equal "
+ "to 64 for 8 bit elements");
+ break;
+ case 16:
+ if (tileWidth < 2 || tileWidth > 32)
+ return op.emitOpError("expecting tile_width to be between 2 and 32");
+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+ return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+ if (tileWidth * vBlocks > 32)
+ return op.emitOpError(
+ "tile_width * v_blocks should be less than or equal "
+ "to 32 for 16 bit elements");
+ break;
+ case 32:
+ if (tileWidth < 1 || tileWidth > 16)
+ return op.emitOpError("expecting tile_width to be between 1 and 16");
+ if (vBlocks != 1 && vBlocks != 2)
+ return op.emitOpError("expecting v_blocks to be 1 or 2");
+ if (tileWidth * vBlocks > 16)
+ return op.emitOpError(
+ "tile_width * v_blocks should be less than or equal "
+ "to 16 for 32 bit elements");
+ break;
+ case 64:
+ if (tileWidth < 1 || tileWidth > 8)
+ return op.emitOpError("expecting tile_width to be between 1 and 8");
+ if (vBlocks != 1)
+ return op.emitOpError("expecting v_blocks to be 1");
+ break;
+ default:
+ return op.emitOpError(
+ "expecting elem_size_in_bits to be 8, 16, 32, or 64");
+ }
+
+ return success();
+ }
+
+ if (op.getTranspose()) {
+ assert(!op.getPackRegister() && "Expecting vnni_transform should be false");
+
+ uint32_t vBlocks = op.getVBlocks();
+ if (vBlocks != 1)
+ return op.emitOpError("expecting v_blocks to be 1");
+
+ uint32_t tileHeight = op.getTileHeight();
+ uint32_t tileWidth = op.getTileWidth();
+ switch (op.getElemSizeInBits()) {
+ case 32:
+ if (tileHeight < 1 || tileHeight > 32)
+ return op.emitOpError("expecting tile_height to be between 1 and 32");
+ if (tileWidth < 1 || tileWidth > 8)
+ return op.emitOpError("expecting tile_width to be between 1 and 8");
+ break;
+ case 64:
+ if (tileHeight != 8)
+ return op.emitOpError(
+ "expecting tile_height to be 8 for 64 bit elements");
+ if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
+ return op.emitOpError("expecting tile_width to be 1, 2, or 4");
+ break;
+ default:
+ return op.emitOpError("transpose is only supported for 32 and 64 bit "
+ "elements");
+ }
+
+ return success();
+ }
+
+ assert(op.getPackRegister() && !op.getTranspose() &&
+ "Expecting vnni_transform should be true and transpose should be "
+ "false");
+
+ uint32_t vBlocks = op.getVBlocks();
+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+ return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+
+ uint32_t tileHeight = op.getTileHeight();
+ uint32_t tileWidth = op.getTileWidth();
+ switch (op.getElemSizeInBits()) {
+ case 8:
+ if (tileHeight < 4 || tileHeight > 32)
+ return op.emitOpError("expecting tile_height to be between 4 and 32");
+ if (tileWidth < 4 || tileWidth > 16)
+ return op.emitOpError("expecting tile_width to be between 4 and 16");
+ break;
+ case 16:
+ if (tileHeight < 2 || tileHeight > 32)
+ return op.emitOpError("expecting tile_height to be between 2 and 32");
+ if (tileWidth < 2 || tileWidth > 16)
+ return op.emitOpError("expecting tile_width to be between 2 and 16");
+ if (tileWidth * vBlocks > 32)
+ return op.emitOpError(
+ "tile_width * v_blocks should be less than or equal "
+ "to 32 for 16 bit elements");
+ break;
+ default:
+ return op.emitOpError("vnni_transform is only supported for 8 and 16 bit "
+ "elements");
+ }
+
+ return success();
+}
+
+static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
+ uint32_t tileHeight = op.getTileHeight();
+ if (tileHeight < 1 || tileHeight > 8)
+ return op.emitOpError("expecting tile_height to be between 1 and 8");
+
+ uint32_t tileWidth = op.getTileWidth();
+ switch (op.getElemSizeInBits()) {
+ case 8:
+ if (tileWidth < 4 || tileWidth > 64)
+ return op.emitOpError("expecting tile_width to be between 4 and 64");
+ break;
+ case 16:
+ if (tileWidth < 2 || tileWidth > 32)
+ return op.emitOpError("expecting tile_width to be between 2 and 32");
+ break;
+ case 32:
+ if (tileWidth < 1 || tileWidth > 16)
+ return op.emitOpError("expecting tile_width to be between 1 and 16");
+ break;
+ case 64:
+ if (tileWidth < 1 || tileWidth > 8)
+ return op.emitOpError("expecting tile_width to be between 1 and 8");
+ break;
+ default:
+ return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
+ }
+
+ uint32_t vBlocks = op.getVBlocks();
+ if (vBlocks != 1)
+ return op.emitOpError("expecting v_blocks to be 1");
+ return success();
+}
+
+} // namespace
+
+LogicalResult BlockLoad2dOp::verify() {
+ if (verify2DBlockLoadRestriction(*this).failed())
+ return failure();
+
+ if (verifyMatrixInput(*this).failed())
+ return failure();
+
+ VectorType resTy = getRes().getType();
+ if (!resTy.getElementType().isIntOrFloat())
+ return emitOpError() << "expecting result element type to be int or float";
+ unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+ if (getElemSizeInBits() == 32 || getPackRegister()) {
+ if (resElemTySize != 32)
+ return emitOpError() << "expecting result element type to be 32 bits";
+ }
+
+ uint32_t tileWidth = getTileWidth();
+ if (getPackRegister()) {
+ if (tileWidth != 16)
+ return emitOpError(
+ "tile_width when vnni_transform is true should be equal "
+ "to subgroup size (16 elements)");
+ return success();
+ }
+
+ return success();
+}
+
+LogicalResult BlockStore2dOp::verify() {
+ if (verify2DBlockStoreRestriction(*this).failed())
+ return failure();
+
+ if (verifyMatrixInput(*this).failed())
+ return failure();
+
+ uint32_t tileWidth = getTileWidth();
+ switch (getElemSizeInBits()) {
+ case 8:
+ if (tileWidth != 16 && tileWidth != 32)
+ return emitOpError("tile_width for 8 bit elements should be equal to "
+ "16 or 32");
+ break;
+ case 16:
+ if (tileWidth != 16)
+ return emitOpError("tile_width for 16 bit elements should be equal "
+ "to 16");
+ break;
+ case 32:
+ if (tileWidth != 16)
+ return emitOpError("tile_width for 32 bit elements should be equal "
+ "to 16");
+ break;
+ default:
+ llvm_unreachable("unexpected element size");
+ }
+
+ return success();
+}
+
+LogicalResult BlockPrefetch2dOp::verify() {
+ if (verifyMatrixInput(*this).failed())
+ return failure();
+
+ uint32_t tileWidth = getTileWidth();
+ switch (getElemSizeInBits()) {
+ case 8:
+ if (tileWidth != 16 && tileWidth != 32)
+ return emitOpError("tile_width for 8 bit elements should be equal to "
+ "16 or 32");
+ break;
+ case 16:
+ if (tileWidth != 16)
+ return emitOpError("tile_width for 16 bit elements should be equal "
+ "to 16");
+ break;
+ case 32:
+ if (tileWidth != 8 && tileWidth != 16)
+ return emitOpError(
+ "tile_width for 32 bit elements should be equal to 8 or 16");
+ break;
+ default:
+ llvm_unreachable("unexpected element size");
+ }
+
+ return success();
+}
+
+LogicalResult MMAOp::verify() {
+ if (getC()) {
+ if (getResult().getType() != getC().getType())
+ return emitOpError("type of C operand must match result type");
+ }
+ return success();
+}
+
+LogicalResult PrefetchOp::verify() {
+ auto addrSpace = getAddrspace();
+ if (addrSpace != AddrSpace::GLOBAL && addrSpace != AddrSpace::GENERIC) {
+ return emitOpError("address space must be global or generic");
+ }
+ return success();
+}
+
+LogicalResult
+XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
+ StringRef triple, StringRef chip, DictionaryAttr flags,
+ ArrayAttr linkFiles) {
+ if (O < 0 || O > 3) {
+ emitError() << "The optimization level must be a number between 0 and 3.";
+ return failure();
+ }
+ if (triple.empty()) {
+ emitError() << "The target triple cannot be empty.";
+ return failure();
+ }
+ if (chip.empty()) {
+ emitError() << "The target chip cannot be empty.";
+ return failure();
+ }
+ if (linkFiles) {
+ for (Attribute fileAttr : linkFiles) {
+ if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
+ StringRef filePath = fileStrAttr.getValue();
+ if (filePath.empty()) {
+ emitError() << "File paths in linkFiles cannot be empty.";
+ return failure();
+ }
+ if (!llvm::sys::fs::exists(filePath)) {
+ emitError() << "File '" << filePath << "' does not exist.";
+ return failure();
+ }
+ }
+ }
+ }
+ return success();
+}
+
+void XeVMDialect::initialize() {
+ // NOLINTBEGIN
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
+ >();
+
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
+ >();
+ // NOLINTEND
+ declarePromisedInterface<mlir::gpu::TargetAttrInterface,
+ mlir::xevm::XeVMTargetAttr>();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 251ca716c7a7a..3a17926f4778c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1875,7 +1875,7 @@ llvm.mlir.global @bad_struct_array_init_elements() : !llvm.array<1x!llvm.struct<
llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>>
}
-// ----
+// -----
llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2 x f64> {
// expected-error at below {{'llvm.mlir.constant' op for array with an array attribute must have a struct element type}}
@@ -1883,10 +1883,51 @@ llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2
llvm.return %0 : !llvm.array<2 x f64>
}
-// ----
+// -----
llvm.func @inlineAsmMustTail(%arg0: i32, %arg1 : !llvm.ptr) {
// expected-error at +1 {{op tail call kind 'musttail' is not supported}}
%8 = llvm.inline_asm tail_call_kind = <musttail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
llvm.return
}
+
+// -----
+
+llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr<1>) {
+ // expected-error at +1 {{op address space must be global or generic}}
+ xevm.prefetch %arg0 <{addrspace = #xevm.addr_space<private>, cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_mma(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
+ // expected-error at +1 {{op type of C operand must match result type}}
+ %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>} : (vector<8xi16>, vector<8xi32>, vector<4xf32>) -> vector<8xf32>
+ llvm.return %c_result : vector<8xf32>
+}
+
+// -----
+
+llvm.func @invalid_xevm_matrix_1(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+ // expected-error at +1 {{op expecting tile_width to be between 1 and 8}}
+ xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=64 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_matrix_2(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+ // expected-error at +1 {{op expecting elem_size_in_bits to be 8, 16, 32, or 64}}
+ xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=18 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_matrix_3(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+ // expected-error at +1 {{op result size of 128 bits does not match the expected size of 208 bits}}
+ %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=26 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ llvm.return %loaded_a : vector<8xi16>
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir
new file mode 100644
index 0000000000000..3ea6768345f6e
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/xevm.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: func.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+ // CHECK: %[[VAR0:.*]] = xevm.blockload2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
+ // CHECK-DAG: elem_size_in_bits = 16 : i32
+ // CHECK-DAG: tile_width = 16 : i32
+ // CHECK-DAG: tile_height = 8 : i32
+ // CHECK-DAG: v_blocks = 1 : i32
+ // CHECK-DAG: transpose = false
+ // CHECK-DAG: pack_register = false
+ // CHECK-DAG: cache_control = #xevm.load_cache_control<Default>
+ // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ return %loaded_a : vector<8xi16>
+}
+
+// -----
+// CHECK: func.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>)
+func.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+ // CHECK: xevm.blockstore2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]]
+ // CHECK-DAG: elem_size_in_bits = 32 : i32
+ // CHECK-DAG: tile_width = 16 : i32
+ // CHECK-DAG: tile_height = 8 : i32
+ // CHECK-DAG: cache_control = #xevm.store_cache_control<Default>
+ // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ return
+}
+
+// -----
+// CHECK: func.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) {
+ // CHECK: xevm.blockprefetch2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
+ // CHECK-DAG: elem_size_in_bits = 8 : i32
+ // CHECK-DAG: tile_width = 32 : i32
+ // CHECK-DAG: tile_height = 8 : i32
+ // CHECK-DAG: v_blocks = 1 : i32
+ // CHECK-DAG: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
+ // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+ xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+ return
+}
+
+// -----
+// CHECK: func.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>)
+func.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
+ // CHECK: %0 = xevm.mma %[[ARG1]], %[[ARG2]], %[[ARG0]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ return %c_result : vector<8xf32>
+}
+
+// -----
+func.func @memfence() {
+ // CHECK: xevm.memfence
+ // CHECK-DAG: addrspace = #xevm.addr_space<global>
+ // CHECK-DAG: scope = #xevm.mem_scope<workgroup>
+ xevm.memfence <{addrspace=#xevm.addr_space<global>, scope=#xevm.mem_scope<workgroup>}>
+ return
+}
+
+// -----
+// CHECK: func.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>)
+func.func @prefetch(%ptr: !llvm.ptr<1>) {
+ // CHECK: xevm.prefetch %[[ARG0]]
+ // CHECK-DAG: addrspace = #xevm.addr_space<global>
+ // CHECK-DAG: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
+ // CHECK: (!llvm.ptr<1>)
+ xevm.prefetch %ptr <{addrspace = #xevm.addr_space<global>, cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ return
+}
+
+// -----
+// CHECK: @xevm_module [#xevm.target<O = 3, chip = "pvc">] {
+gpu.module @xevm_module [#xevm.target<O = 3, chip = "pvc">]{
+}
diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
index 4ca5974ed5a49..418c884dc03b3 100644
--- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
@@ -29,6 +29,7 @@ set(LIBS
MLIRTranslateLib
MLIRVectorDialect
MLIRVectorToLLVMPass
+ MLIRXeVMDialect
)
add_mlir_library(MLIRGPUTestPasses
>From b58a57726d9738ff196c7dfa4aa6a48053823e60 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Jun 2025 20:24:26 +0000
Subject: [PATCH 2/8] Replace incorrect term vnni_transform with pack_register.
---
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index afb14666f06be..de73fcbb8fdaf 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -66,7 +66,7 @@ LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
if (op.getTranspose() && op.getPackRegister())
return op.emitOpError(
- "transpose and vnni_transform are mutually exclusive");
+ "transpose and pack_register are mutually exclusive");
if (!op.getTranspose() && !op.getPackRegister()) {
uint32_t tileHeight = op.getTileHeight();
@@ -121,7 +121,7 @@ LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
}
if (op.getTranspose()) {
- assert(!op.getPackRegister() && "Expecting vnni_transform should be false");
+ assert(!op.getPackRegister() && "Expecting pack_register should be false");
uint32_t vBlocks = op.getVBlocks();
if (vBlocks != 1)
@@ -152,7 +152,7 @@ LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
}
assert(op.getPackRegister() && !op.getTranspose() &&
- "Expecting vnni_transform should be true and transpose should be "
+ "Expecting pack_register should be true and transpose should be "
"false");
uint32_t vBlocks = op.getVBlocks();
@@ -179,7 +179,7 @@ LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
"to 32 for 16 bit elements");
break;
default:
- return op.emitOpError("vnni_transform is only supported for 8 and 16 bit "
+ return op.emitOpError("pack_register is only supported for 8 and 16 bit "
"elements");
}
@@ -241,7 +241,7 @@ LogicalResult BlockLoad2dOp::verify() {
if (getPackRegister()) {
if (tileWidth != 16)
return emitOpError(
- "tile_width when vnni_transform is true should be equal "
+ "tile_width when pack_register is true should be equal "
"to subgroup size (16 elements)");
return success();
}
>From 38bd98097a0cd485085970cfbcbccf4211df35f4 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Jun 2025 20:31:26 +0000
Subject: [PATCH 3/8] Fix typo. Remove unused dependency. Return emitError().
---
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 2 +-
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 15 +++++----------
mlir/test/lib/Dialect/GPU/CMakeLists.txt | 1 -
3 files changed, 6 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 9525c4a731efa..cc8650b8e4d22 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -444,7 +444,7 @@ def XeVM_MMAShapeAttr : XeVM_Attr<"MMAShape", "mma_shape"> {
MMA operation is represented as D=AxB+C, where
A has the shape MxK.
B has the shape KxN.
- D and C havethe shape MxN.
+ D and C have the shape MxN.
This attribute encodes the shape of all matrices that participate in MMA.
}];
let parameters = (ins "int":$m, "int":$n, "int":$k);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index de73fcbb8fdaf..edb8dce84321c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -329,28 +329,23 @@ XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
StringRef triple, StringRef chip, DictionaryAttr flags,
ArrayAttr linkFiles) {
if (O < 0 || O > 3) {
- emitError() << "The optimization level must be a number between 0 and 3.";
- return failure();
+ return emitError() << "The optimization level must be a number between 0 and 3.";
}
if (triple.empty()) {
- emitError() << "The target triple cannot be empty.";
- return failure();
+ return emitError() << "The target triple cannot be empty.";
}
if (chip.empty()) {
- emitError() << "The target chip cannot be empty.";
- return failure();
+ return emitError() << "The target chip cannot be empty.";
}
if (linkFiles) {
for (Attribute fileAttr : linkFiles) {
if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
StringRef filePath = fileStrAttr.getValue();
if (filePath.empty()) {
- emitError() << "File paths in linkFiles cannot be empty.";
- return failure();
+ return emitError() << "File paths in linkFiles cannot be empty.";
}
if (!llvm::sys::fs::exists(filePath)) {
- emitError() << "File '" << filePath << "' does not exist.";
- return failure();
+ return emitError() << "File '" << filePath << "' does not exist.";
}
}
}
diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
index 418c884dc03b3..4ca5974ed5a49 100644
--- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
@@ -29,7 +29,6 @@ set(LIBS
MLIRTranslateLib
MLIRVectorDialect
MLIRVectorToLLVMPass
- MLIRXeVMDialect
)
add_mlir_library(MLIRGPUTestPasses
>From 81b17d819c43f39f99bdc90facdb5050cb7706e5 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Jun 2025 20:43:44 +0000
Subject: [PATCH 4/8] Add line breaks in doc embedded examples.
---
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 22 ++++++++++++++++-----
1 file changed, 17 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index cc8650b8e4d22..8f8f822de2182 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -190,7 +190,8 @@ def XeVM_BlockLoad2dOp
- 8 for int8
$v_blocks - number of consecutive tiles in innermost dimension direction to load
$transpose - transpose the tile in registers (useful for 32 bit element type)
- $pack_register - pack element types narrower than register bit width. [M, N] => [M/factor, N, factor] where factor is register_size_in_bits / elem_size_in_bits
+ $pack_register - pack element types narrower than register bit width.
+ [M, N] => [M/factor, N, factor] where factor is register_size_in_bits / elem_size_in_bits
$cache_control - an enumerator that sets the cache behaviour
Notes:
@@ -206,7 +207,11 @@ def XeVM_BlockLoad2dOp
%base_pitch_a = arith.constant 32 : i32
%x = arith.constant 0 : i32
%y = arith.constant 0 : i32
- %loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false : i32, pack_register=false, cache_control=#xevm.load_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ %loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32,
+ v_blocks=1 : i32, transpose=false : i32, pack_register=false,
+ cache_control=#xevm.load_cache_control<Default>}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
```
}];
@@ -259,7 +264,10 @@ def XeVM_BlockStore2dOp
%base_pitch_c = arith.constant 64 : i32
%x = arith.constant 0 : i32
%y = arith.constant 0 : i32
- xevm.blockstore2d %dst, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %src <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.load_cache_control<Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ xevm.blockstore2d %dst, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %src
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32,
+ cache_control=#xevm.load_cache_control<Default>}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
```
}];
@@ -398,7 +406,10 @@ def XeVM_BlockPrefetch2dOp
Example:
```mlir
- xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+ xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y
+ <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32,
+ v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
```
}];
@@ -488,7 +499,8 @@ def XeVM_MMAOp
Example:
```mlir
- %d = xevm.mma %a, %b, %c { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ %d = xevm.mma %a, %b, %c { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> }
+ : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
```
}];
>From 3d521aab778cbe365cf2d6960c45b2c8ba4c89a0 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Jun 2025 20:54:57 +0000
Subject: [PATCH 5/8] Add top level comment on enums used for cache control.
---
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 8f8f822de2182..14fbb325cd4c5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -62,6 +62,16 @@ class XeVM_Op<string mnemonic, list<Trait> traits = []>
def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
+//===----------------------------------------------------------------------===//
+// XeVM Load Cache Control
+// L1, L2, L3 - cache levels
+// uc - uncached
+// c - cached
+// s - streaming
+// ir - invalidated after read
+// Default - default cache behavior for L1, L2 and L3 cache
+//===----------------------------------------------------------------------===//
+
def LoadCacheControlDefault : I32EnumAttrCase<"DEFAULT", 0, "Default">;
def LoadCacheControl_L1uc_L2uc_L3uc
: I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
@@ -110,6 +120,16 @@ def XeVM_LoadCacheControlAttr
let assemblyFormat = "`<` $value `>`";
}
+//===----------------------------------------------------------------------===//
+// XeVM Store Cache Control
+// L1, L2, L3 - cache levels
+// uc - uncached
+// wb - write-back
+// wt - write-through
+// s - streaming
+// Default - default cache behavior for L1, L2 and L3 cache
+//===----------------------------------------------------------------------===//
+
def StoreCacheControlDefault : I32EnumAttrCase<"DEFAULT", 0, "Default">;
def StoreCacheControl_L1uc_L2uc_L3uc
: I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
>From c0f7bbe6925a41198681cd6bc0b8c04a97c86b9b Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Jun 2025 21:23:50 +0000
Subject: [PATCH 6/8] Add summary and description for XeVM dialect.
---
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 14fbb325cd4c5..f165b547a3078 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -18,6 +18,13 @@ include "mlir/IR/EnumAttr.td"
def XeVM_Dialect : Dialect {
let name = "xevm";
let cppNamespace = "::mlir::xevm";
+ let summary = "The XeVM dialect that extends LLVM dialect and models Intel GPU's hardware features.";
+ let description = [{
+ The XeVM dialect is an extension to the LLVM dialect that models hardware
+ features of Intel GPUs. The dialect is designed to work with the Xe
+ architecture for Intel GPUs, supporting advanced operations like 2D block
+ loads, stores, prefetch and matrix multiply-add (MMA) operations.
+ }];
let dependentDialects = ["LLVM::LLVMDialect"];
let extraClassDeclaration = [{
>From 477967ce51303f199f5c914912d6896014532a43 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Jun 2025 21:24:41 +0000
Subject: [PATCH 7/8] Run clang-format.
---
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 3 ++-
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 6 +++---
2 files changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index f165b547a3078..87583dc2526bb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -18,7 +18,8 @@ include "mlir/IR/EnumAttr.td"
def XeVM_Dialect : Dialect {
let name = "xevm";
let cppNamespace = "::mlir::xevm";
- let summary = "The XeVM dialect that extends LLVM dialect and models Intel GPU's hardware features.";
+ let summary = "The XeVM dialect that extends LLVM dialect and models Intel "
+ "GPU's hardware features.";
let description = [{
The XeVM dialect is an extension to the LLVM dialect that models hardware
features of Intel GPUs. The dialect is designed to work with the Xe
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index edb8dce84321c..c3f363563b38e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -65,8 +65,7 @@ LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
<< expectedSize << " bits";
if (op.getTranspose() && op.getPackRegister())
- return op.emitOpError(
- "transpose and pack_register are mutually exclusive");
+ return op.emitOpError("transpose and pack_register are mutually exclusive");
if (!op.getTranspose() && !op.getPackRegister()) {
uint32_t tileHeight = op.getTileHeight();
@@ -329,7 +328,8 @@ XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
StringRef triple, StringRef chip, DictionaryAttr flags,
ArrayAttr linkFiles) {
if (O < 0 || O > 3) {
- return emitError() << "The optimization level must be a number between 0 and 3.";
+ return emitError()
+ << "The optimization level must be a number between 0 and 3.";
}
if (triple.empty()) {
return emitError() << "The target triple cannot be empty.";
>From 042fb9acf614377dec69a02bf0b5e1212c675360 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 24 Jun 2025 13:40:25 -0700
Subject: [PATCH 8/8] xevm.prefetch: use LLVM pointer address space
---
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 5 ++---
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 9 +++++----
mlir/test/Dialect/LLVMIR/invalid.mlir | 6 +++---
mlir/test/Dialect/LLVMIR/xevm.mlir | 3 +--
4 files changed, 11 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 87583dc2526bb..ca055670a9527 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -381,13 +381,12 @@ def XeVM_MemfenceOp
def XeVM_PrefetchOp
: XeVM_Op<"prefetch">,
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
- XeVM_AddrSpaceAttr:$addrspace,
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
let summary = "Prefetch data into a cache subsystem.";
let description = [{
Work-item issues a prefetch from global memory to cache:
- $ptr - memory pointer.
- $addrspace - address space of a pointer, must be generic or global.
+ $ptr - LLVM pointer with address space. Address space must be 1 (global)
+ or 4 (generic).
$cache_control - specify caching options
}];
let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index c3f363563b38e..d10fa5cdbc2f5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -316,10 +316,11 @@ LogicalResult MMAOp::verify() {
}
LogicalResult PrefetchOp::verify() {
- auto addrSpace = getAddrspace();
- if (addrSpace != AddrSpace::GLOBAL && addrSpace != AddrSpace::GENERIC) {
- return emitOpError("address space must be global or generic");
- }
+ auto ptrTy = mlir::dyn_cast<LLVM::LLVMPointerType>(getOperand().getType());
+ auto addrSpace = ptrTy.getAddressSpace();
+ if (addrSpace != 1 && addrSpace != 4)
+ return emitOpError(
+ "LLVM pointer type address space must be 1 (global) or 4 (generic)");
return success();
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 3a17926f4778c..174f925fea317 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1893,9 +1893,9 @@ llvm.func @inlineAsmMustTail(%arg0: i32, %arg1 : !llvm.ptr) {
// -----
-llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr<1>) {
- // expected-error at +1 {{op address space must be global or generic}}
- xevm.prefetch %arg0 <{addrspace = #xevm.addr_space<private>, cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
+ // expected-error at +1 {{LLVM pointer type address space must be 1 (global) or 4 (generic)}}
+ xevm.prefetch %arg0 <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr)
llvm.return
}
diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir
index 3ea6768345f6e..bf10bd45d58a0 100644
--- a/mlir/test/Dialect/LLVMIR/xevm.mlir
+++ b/mlir/test/Dialect/LLVMIR/xevm.mlir
@@ -63,10 +63,9 @@ func.func @memfence() {
// CHECK: func.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>)
func.func @prefetch(%ptr: !llvm.ptr<1>) {
// CHECK: xevm.prefetch %[[ARG0]]
- // CHECK-DAG: addrspace = #xevm.addr_space<global>
// CHECK-DAG: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
// CHECK: (!llvm.ptr<1>)
- xevm.prefetch %ptr <{addrspace = #xevm.addr_space<global>, cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
return
}
More information about the Mlir-commits
mailing list