[Mlir-commits] [mlir] b9b2661 - [MLIR][Dialect] Add XeVM dialect (#144811)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 4 05:51:25 PDT 2025
Author: Sang Ik Lee
Date: 2025-07-04T13:51:21+01:00
New Revision: b9b2661f72ac5f9d4f23d9bb83131aa3d46020b9
URL: https://github.com/llvm/llvm-project/commit/b9b2661f72ac5f9d4f23d9bb83131aa3d46020b9
DIFF: https://github.com/llvm/llvm-project/commit/b9b2661f72ac5f9d4f23d9bb83131aa3d46020b9.diff
LOG: [MLIR][Dialect] Add XeVM dialect (#144811)
XeVM is a new dialect that is designed to expose Intel GPU hardware
features in a future-proof way.
RFC is here:
https://discourse.llvm.org/t/mlir-rfc-dialect-xevm-proposal-for-new-xevm-dialect/86955
In short, XeVM is to Intel GPUs what NVVM and ROCDL are to NVIDIA and AMD GPUs.
The RFC includes the background and the challenges that XeVM is designed to
solve.
It also lists the plan for upstreaming at the end.
This PR is the first of a series and it covers dialect definition and op
tests only.
Co-authored-by: Artem Kroviakov <artem.kroviakov at intel.com>
Added:
mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
mlir/test/Dialect/LLVMIR/xevm.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/test/Dialect/LLVMIR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 9c5bbae1022f7..cfad07e57021f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -87,3 +87,13 @@ mlir_tablegen(VCIXConversions.inc -gen-llvmir-conversions)
mlir_tablegen(VCIXOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=vcix)
mlir_tablegen(VCIXOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=vcix)
add_public_tablegen_target(MLIRVCIXConversionsIncGen)
+
+add_mlir_dialect(XeVMOps xevm)
+add_mlir_doc(XeVMOps XeVMDialect Dialects/ -gen-dialect-doc -dialect=xevm)
+set(LLVM_TARGET_DEFINITIONS XeVMOps.td)
+mlir_tablegen(XeVMConversions.inc -gen-llvmir-conversions)
+mlir_tablegen(XeVMOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(XeVMOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(XeVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=xevm)
+mlir_tablegen(XeVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xevm)
+add_public_tablegen_target(MLIRXeVMConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h
new file mode 100644
index 0000000000000..85bc19ad59794
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h
@@ -0,0 +1,28 @@
+//===-- XeVMDialect.h - MLIR XeVM target definitions ------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOps.h.inc"
+
+#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.h.inc"
+
+#endif /* MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_ */
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
new file mode 100644
index 0000000000000..b5e81d595d74c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -0,0 +1,561 @@
+//===-- XeVMOps.td - XeVM dialect definition ---------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef XEVMIR_OPS
+#define XEVMIR_OPS
+
+include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+
+def XeVM_Dialect : Dialect {
+ let name = "xevm";
+ let cppNamespace = "::mlir::xevm";
+ let summary = "The XeVM dialect that extends LLVM dialect and models Intel "
+ "GPU's hardware features.";
+ let description = [{
+ The XeVM dialect is an extension to the LLVM dialect that models hardware
+ features of Intel GPUs. The dialect is designed to work with the Xe
+ architecture for Intel GPUs, supporting advanced operations like 2D block
+ loads, stores, prefetch and matrix multiply-add (MMA) operations.
+ }];
+ let dependentDialects = ["LLVM::LLVMDialect"];
+
+ let extraClassDeclaration = [{
+ /// Get the name for the attribute used to specify cache control
+ /// decorations.
+ static constexpr ::llvm::StringRef getCacheControlsAttrName() {
+ return ::llvm::StringLiteral("xevm.DecorationCacheControl");
+ }
+ }];
+
+ let useDefaultAttributePrinterParser = 1;
+}
+
+class XeVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+ : AttrDef<XeVM_Dialect, attrName, traits> {
+ let mnemonic = attrMnemonic;
+}
+
+class XeVM_Op<string mnemonic, list<Trait> traits = []>
+ : LLVM_OpBase<XeVM_Dialect, mnemonic, traits> {
+
+ code extraBaseClassDeclaration = [{
+ void printProperties(::mlir::MLIRContext *ctx,
+ ::mlir::OpAsmPrinter &p, const Properties &prop,
+ ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
+ Attribute propAttr = getPropertiesAsAttr(ctx, prop);
+ if (propAttr)
+ p << "<" << propAttr << ">";
+ }
+
+ static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ if (mlir::succeeded(parser.parseOptionalLess())) {
+ if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
+ return failure();
+ }
+ return success();
+ }
+
+ }];
+}
+
+def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
+
+//===----------------------------------------------------------------------===//
+// XeVM Load Cache Control
+// L1, L2, L3 - cache levels
+// uc - uncached
+// c - cached
+// s - streaming
+// ir - invalidated after read
+//===----------------------------------------------------------------------===//
+
+def LoadCacheControl_L1uc_L2uc_L3uc
+ : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
+def LoadCacheControl_L1uc_L2uc_L3c
+ : I32EnumAttrCase<"L1UC_L2UC_L3C", 2, "L1uc_L2uc_L3c">;
+def LoadCacheControl_L1uc_L2c_L3uc
+ : I32EnumAttrCase<"L1UC_L2C_L3UC", 3, "L1uc_L2c_L3uc">;
+def LoadCacheControl_L1uc_L2c_L3c
+ : I32EnumAttrCase<"L1UC_L2C_L3C", 4, "L1uc_L2c_L3c">;
+def LoadCacheControl_L1c_L2uc_L3uc
+ : I32EnumAttrCase<"L1C_L2UC_L3UC", 5, "L1c_L2uc_L3uc">;
+def LoadCacheControl_L1c_L2uc_L3c
+ : I32EnumAttrCase<"L1C_L2UC_L3C", 6, "L1c_L2uc_L3c">;
+def LoadCacheControl_L1c_L2c_L3uc
+ : I32EnumAttrCase<"L1C_L2C_L3UC", 7, "L1c_L2c_L3uc">;
+def LoadCacheControl_L1c_L2c_L3c
+ : I32EnumAttrCase<"L1C_L2C_L3C", 8, "L1c_L2c_L3c">;
+def LoadCacheControl_L1s_L2uc_L3uc
+ : I32EnumAttrCase<"L1S_L2UC_L3UC", 9, "L1s_L2uc_L3uc">;
+def LoadCacheControl_L1s_L2uc_L3c
+ : I32EnumAttrCase<"L1S_L2UC_L3C", 10, "L1s_L2uc_L3c">;
+def LoadCacheControl_L1s_L2c_L3uc
+ : I32EnumAttrCase<"L1S_L2C_L3UC", 11, "L1s_L2c_L3uc">;
+def LoadCacheControl_L1s_L2c_L3c
+ : I32EnumAttrCase<"L1S_L2C_L3C", 12, "L1s_L2c_L3c">;
+def LoadCacheControlInvalidateRead
+ : I32EnumAttrCase<"INVALIDATE_READ", 13, "ir">;
+
+def XeVM_LoadCacheControl
+ : I32EnumAttr<
+ "LoadCacheControl", "XeVM load ops cache control",
+ [LoadCacheControl_L1uc_L2uc_L3uc, LoadCacheControl_L1uc_L2uc_L3c,
+ LoadCacheControl_L1uc_L2c_L3uc, LoadCacheControl_L1uc_L2c_L3c,
+ LoadCacheControl_L1c_L2uc_L3uc, LoadCacheControl_L1c_L2uc_L3c,
+ LoadCacheControl_L1c_L2c_L3uc, LoadCacheControl_L1c_L2c_L3c,
+ LoadCacheControl_L1s_L2uc_L3uc, LoadCacheControl_L1s_L2uc_L3c,
+ LoadCacheControl_L1s_L2c_L3uc, LoadCacheControl_L1s_L2c_L3c,
+ LoadCacheControlInvalidateRead]> {
+ let cppNamespace = "::mlir::xevm";
+ let genSpecializedAttr = 0;
+}
+
+def XeVM_LoadCacheControlAttr
+ : EnumAttr<XeVM_Dialect, XeVM_LoadCacheControl, "load_cache_control"> {
+ let summary = [{Describe the cache settings for load operators}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
+//===----------------------------------------------------------------------===//
+// XeVM Store Cache Control
+// L1, L2, L3 - cache levels
+// uc - uncached
+// wb - write-back
+// wt - write-through
+// s - streaming
+//===----------------------------------------------------------------------===//
+
+def StoreCacheControl_L1uc_L2uc_L3uc
+ : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
+def StoreCacheControl_L1uc_L2uc_L3wb
+ : I32EnumAttrCase<"L1UC_L2UC_L3WB", 2, "L1uc_L2uc_L3wb">;
+def StoreCacheControl_L1uc_L2wb_L3uc
+ : I32EnumAttrCase<"L1UC_L2WB_L3UC", 3, "L1uc_L2wb_L3uc">;
+def StoreCacheControl_L1uc_L2wb_L3wb
+ : I32EnumAttrCase<"L1UC_L2WB_L3WB", 4, "L1uc_L2wb_L3wb">;
+def StoreCacheControl_L1wt_L2uc_L3uc
+ : I32EnumAttrCase<"L1WT_L2UC_L3UC", 5, "L1wt_L2uc_L3uc">;
+def StoreCacheControl_L1wt_L2uc_L3wb
+ : I32EnumAttrCase<"L1WT_L2UC_L3WB", 6, "L1wt_L2uc_L3wb">;
+def StoreCacheControl_L1wt_L2wb_L3uc
+ : I32EnumAttrCase<"L1WT_L2WB_L3UC", 7, "L1wt_L2wb_L3uc">;
+def StoreCacheControl_L1wt_L2wb_L3wb
+ : I32EnumAttrCase<"L1WT_L2WB_L3WB", 8, "L1wt_L2wb_L3wb">;
+def StoreCacheControl_L1s_L2uc_L3uc
+ : I32EnumAttrCase<"L1S_L2UC_L3UC", 9, "L1s_L2uc_L3uc">;
+def StoreCacheControl_L1s_L2uc_L3wb
+ : I32EnumAttrCase<"L1S_L2UC_L3WB", 10, "L1s_L2uc_L3wb">;
+def StoreCacheControl_L1s_L2wb_L3uc
+ : I32EnumAttrCase<"L1S_L2WB_L3UC", 11, "L1s_L2wb_L3uc">;
+def StoreCacheControl_L1s_L2wb_L3wb
+ : I32EnumAttrCase<"L1S_L2WB_L3WB", 12, "L1s_L2wb_L3wb">;
+def StoreCacheControl_L1wb_L2uc_L3uc
+ : I32EnumAttrCase<"L1WB_L2UC_L3UC", 13, "L1wb_L2uc_L3uc">;
+def StoreCacheControl_L1wb_L2wb_L3uc
+ : I32EnumAttrCase<"L1WB_L2WB_L3UC", 14, "L1wb_L2wb_L3uc">;
+def StoreCacheControl_L1wb_L2uc_L3wb
+ : I32EnumAttrCase<"L1WB_L2UC_L3WB", 15, "L1wb_L2uc_L3wb">;
+
+def XeVM_StoreCacheControl
+ : I32EnumAttr<
+ "StoreCacheControl", "XeVM store ops cache control",
+ [StoreCacheControl_L1uc_L2uc_L3uc, StoreCacheControl_L1uc_L2uc_L3wb,
+ StoreCacheControl_L1uc_L2wb_L3uc, StoreCacheControl_L1uc_L2wb_L3wb,
+ StoreCacheControl_L1wt_L2uc_L3uc, StoreCacheControl_L1wt_L2uc_L3wb,
+ StoreCacheControl_L1wt_L2wb_L3uc, StoreCacheControl_L1wt_L2wb_L3wb,
+ StoreCacheControl_L1s_L2uc_L3uc, StoreCacheControl_L1s_L2uc_L3wb,
+ StoreCacheControl_L1s_L2wb_L3uc, StoreCacheControl_L1s_L2wb_L3wb,
+ StoreCacheControl_L1wb_L2uc_L3uc, StoreCacheControl_L1wb_L2wb_L3uc,
+ StoreCacheControl_L1wb_L2uc_L3wb]> {
+ let cppNamespace = "::mlir::xevm";
+ let genSpecializedAttr = 0;
+}
+
+def XeVM_StoreCacheControlAttr
+ : EnumAttr<XeVM_Dialect, XeVM_StoreCacheControl, "store_cache_control"> {
+ let summary = [{Describe the cache settings for store operators}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def XeVM_BlockLoad2dOp
+ : XeVM_Op<"blockload2d">,
+ Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>,
+ Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
+ I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
+ I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
+ I32Attr:$v_blocks, I1Attr:$transpose, I1Attr:$pack_register,
+ OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
+
+ let summary = "2D block load";
+
+ let description = [{
+ The `xevm.blockload2d` operation loads a two dimensional matrix tile
+ from a base matrix residing in global memory. The parameters are:
+ $ptr - the base address of the base matrix containing the tile to load
+ $base_width - the width of the base matrix in number of bytes.
+ $base_height - the number of rows in the base matrix
+ $base_pitch - the physical stride between the first columns of the current
+ row and the subsequent row in number of bytes.
+ $x, $y, $tile_width, $tile_height - the starting offsets and shape of
+ the tile to load in number of elements.
+ $elem_size_in_bits - the size in bits of the matrix element type
+ - 32 for f32, tf32
+ - 16 for f16, int16, bf16
+ - 8 for int8
+ $v_blocks - number of consecutive tiles in innermost dimension direction to load
+ $transpose - transpose the tile in registers (useful for 32 bit element type)
+ $pack_register - pack element types narrower than register bit width.
+ [M, N] => [M/factor, N, factor] where factor is register_size_in_bits / elem_size_in_bits
+ $cache_control - an enumerator that sets the cache behaviour
+
+ Notes:
+ - the $transpose and $pack_register parameters are mutually exclusive
+ - transposing the loaded tile is used for the A matrix in the backward path, or for the B matrix operand
+ (D = C + A * B), where A has row-major layout and B should have column-major layout in memory.
+ - if the tile loaded contains out of bound elements of the matrix, they are filled with 0.
+
+ Example:
+ ```mlir
+ %base_width_a = arith.constant 32 : i32
+ %base_height_a = arith.constant 8 : i32
+ %base_pitch_a = arith.constant 32 : i32
+ %x = arith.constant 0 : i32
+ %y = arith.constant 0 : i32
+ %loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32,
+ v_blocks=1 : i32, transpose=false, pack_register=false,
+ cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ operands prop-dict attr-dict `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ }];
+
+ let hasVerifier = 1;
+}
+
+def XeVM_BlockStore2dOp
+ : XeVM_Op<"blockstore2d">,
+ Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr, I32:$base_width,
+ I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
+ I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
+ FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$stored_val,
+ OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
+
+ let summary = "2D block store";
+
+ let description = [{
+ The `xevm.blockstore2d` operation stores a two dimensional tile into a
+ larger matrix residing in global memory. The parameters are:
+ $ptr - the base address of the target matrix where to store the tile
+ $base_width - the width of the base matrix in number of bytes.
+ $base_height - the number of rows in the base matrix
+ $base_pitch - the physical stride between the first columns of the current
+ row and the subsequent row in number of bytes.
+ $x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to store
+ in number of elements.
+ $elem_size_in_bits - the size in bits of the matrix element
+ - 32 for f32, tf32
+ - 16 for f16, int16, bf16
+ - 8 for int8
+ $cache_control - an enumerator that sets the cache behaviour
+ $stored_val - the tile to store
+
+ Example:
+ ```mlir
+ %base_width_c = arith.constant 64 : i32
+ %base_height_c = arith.constant 8 : i32
+ %base_pitch_c = arith.constant 64 : i32
+ %x = arith.constant 0 : i32
+ %y = arith.constant 0 : i32
+ xevm.blockstore2d %dst, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %src
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32,
+ cache_control=#xevm.store_cache_control<L1uc_L2uc_L3uc>}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ ```
+ }];
+
+ let assemblyFormat = [{
+ operands prop-dict attr-dict `:` `(` type(operands) `)`
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ /// Default value for v_blocks is 1.
+ constexpr uint32_t getVBlocks() {
+ return 1;
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+def MemScopeLane : I32EnumAttrCase<"LANE", 0, "lane">;
+def MemScopeSubgroup : I32EnumAttrCase<"SUBGROUP", 1, "subgroup">;
+def MemScopeWorkgroup : I32EnumAttrCase<"WORKGROUP", 2, "workgroup">;
+def MemScopeCluster : I32EnumAttrCase<"CLUSTER", 3, "cluster">;
+def MemScopeDevice : I32EnumAttrCase<"DEVICE", 4, "device">;
+def MemScopeSystem : I32EnumAttrCase<"SYSTEM", 5, "system">;
+
+def XeVM_MemScope
+ : I32EnumAttr<"MemScope", "XeVM memory scope",
+ [MemScopeLane, MemScopeSubgroup, MemScopeWorkgroup,
+ MemScopeCluster, MemScopeDevice, MemScopeSystem]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xevm";
+}
+def XeVM_MemScopeAttr : EnumAttr<XeVM_Dialect, XeVM_MemScope, "mem_scope"> {
+ let summary = [{Describe memory scopes}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def AddrSpacePrivate : I32EnumAttrCase<"PRIVATE", 0, "private">;
+def AddrSpaceGlobal : I32EnumAttrCase<"GLOBAL", 1, "global">;
+def AddrSpaceConstant : I32EnumAttrCase<"CONSTANT", 2, "constant">;
+def AddrSpaceShared : I32EnumAttrCase<"SHARED", 3, "shared">;
+def AddrSpaceGeneric : I32EnumAttrCase<"GENERIC", 4, "generic">;
+
+def XeVM_AddrSpace
+ : I32EnumAttr<"AddrSpace", "Address spaces",
+ [AddrSpacePrivate, AddrSpaceGlobal, AddrSpaceConstant,
+ AddrSpaceShared, AddrSpaceGeneric]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xevm"; // fully qualified, like the other XeVM enums
+}
+def XeVM_AddrSpaceAttr : EnumAttr<XeVM_Dialect, XeVM_AddrSpace, "addr_space"> {
+ let summary = [{Describe address spaces}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def XeVM_MemfenceOp
+ : XeVM_Op<"memfence">,
+ Arguments<(ins XeVM_MemScopeAttr:$scope,
+ DefaultValuedAttr<XeVM_AddrSpaceAttr,
+ "mlir::xevm::AddrSpace::GENERIC">:$addrspace)> {
+ let summary = "Work-item's memory fence.";
+ let description = [{
+ This operation ensures that all prior memory accesses of this
+ work-item to `addrspace` are visible to all other work-items in `scope`.
+ Parameters description:
+ $scope - specify the memory scope at which all other work-items should observe
+ memory operations prior to the fence.
+ $addrspace - specify the address space of work-item's memory accesses
+ to be affected by the fence.
+ }];
+ let assemblyFormat = [{prop-dict attr-dict}];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ }];
+}
+
+def XeVM_PrefetchOp
+ : XeVM_Op<"prefetch">,
+ Arguments<(ins Arg<AnyTypeOf<[LLVM_PointerInAddressSpace<1>,
+ LLVM_PointerInAddressSpace<4>]>>:$ptr,
+ OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
+ let summary = "Prefetch data into a cache subsystem.";
+ let description = [{
+ Work-item issues a prefetch from global memory to cache:
+ $ptr - LLVM pointer with address space. Address space must be 1 (global)
+ or 4 (generic)
+ $cache_control - specify caching options
+ }];
+ let assemblyFormat = [{
+ operands prop-dict attr-dict `:` `(` type(operands) `)`
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ }];
+}
+
+def XeVM_BlockPrefetch2dOp
+ : XeVM_Op<"blockprefetch2d">,
+ Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
+ I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
+ I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
+ I32Attr:$v_blocks,
+ OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
+
+ let summary = "2D block prefetch";
+
+ let description = [{
+ The `xevm.blockprefetch2d` operation prefetches a two dimensional tile
+ from a larger base matrix residing in global memory. The parameters are:
+ $ptr - the base address of the base matrix containing the tile to prefetch
+ $base_width - the width of the base matrix in number of bytes.
+ $base_height - the number of rows in the base matrix
+ $base_pitch - the physical stride between the first columns of the current
+ row and the subsequent row in number of bytes.
+ $x, $y, $tile_width, $tile_height - the starting offsets and shape of tile
+ to prefetch in number of elements.
+ $elem_size_in_bits - the size in bits of the matrix element
+ - 32 for f32, tf32
+ - 16 for f16, int16, bf16
+ - 8 for int8, int4, int2
+ $v_blocks - number of tiles in innermost dimension direction to prefetch
+ $cache_control - an enumerator that sets the cache behaviour
+
+ Example:
+ ```mlir
+ xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y
+ <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32,
+ v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+ ```
+ }];
+
+ let assemblyFormat = [{
+ operands prop-dict attr-dict `:` `(` type(operands) `)`
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ }];
+
+ let hasVerifier = 1;
+}
+
+def XeVM_MatrixElemType
+ : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
+
+/// Enum attribute of the different element types.
+def XeVM_ET_BF16 : I32EnumAttrCase<"BF16", 8, "bf16">;
+def XeVM_ET_F16 : I32EnumAttrCase<"F16", 9, "f16">;
+def XeVM_ET_S8 : I32EnumAttrCase<"S8", 10, "s8">;
+def XeVM_ET_U8 : I32EnumAttrCase<"U8", 11, "u8">;
+def XeVM_ET_S4 : I32EnumAttrCase<"S4", 12, "s4">;
+def XeVM_ET_U4 : I32EnumAttrCase<"U4", 13, "u4">;
+def XeVM_ET_TF32 : I32EnumAttrCase<"TF32", 14, "tf32">;
+def XeVM_ET_F32 : I32EnumAttrCase<"F32", 15, "f32">;
+def XeVM_ET_S32 : I32EnumAttrCase<"S32", 16, "s32">;
+
+def XeVM_ElemTypeAttr : I32EnumAttr<"ElemType", "XeVM element type",
+ [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8,
+ XeVM_ET_U8, XeVM_ET_S4, XeVM_ET_U4,
+ XeVM_ET_TF32, XeVM_ET_F32, XeVM_ET_S32]> {
+ let cppNamespace = "::mlir::xevm";
+}
+
+def XeVM_MMAShapeAttr : XeVM_Attr<"MMAShape", "mma_shape"> {
+ let description = [{
+ MMA operation is represented as D=AxB+C, where
+ A has the shape MxK.
+ B has the shape KxN.
+ D and C have the shape MxN.
+ This attribute encodes the shape of all matrices that participate in MMA.
+ }];
+ let parameters = (ins "int":$m, "int":$n, "int":$k);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeVM_MMATypesAttr : XeVM_Attr<"MMATypes", "mma_types"> {
+ let parameters = (ins "xevm::ElemType":$d, "xevm::ElemType":$a,
+ "xevm::ElemType":$b, OptionalParameter<"xevm::ElemType">:$c);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeVM_MMAOp
+ : XeVM_Op<"mma">,
+ Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
+ Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
+ FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
+ Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
+ XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
+
+ let summary = "Subgroup matrix multiply-add";
+
+ let description = [{
+ The `xevm.mma` is a cooperative operation where all threads/lanes in
+ a subgroup participate and carry out matrix multiplication plus accumulation:
+
+ D = C + A x B
+
+ where the A, B, C input matrices and the result D have shapes:
+ D : MxN
+ C : MxN
+ A : MxK
+ B : KxN
+
+ Parameters:
+ `a` - vector of matrix A elements.
+ `b` - vector of matrix B elements.
+ `c` - (optional) vector of matrix C elements.
+ `shape` - the shape of the matrices, specified as `M`, `N`, and `K` values.
+ `types` - the data types of the matrices, specified as `D`, `A`, `B`, and optionally `C`.
+
+ Example:
+ ```mlir
+ %d = xevm.mma %a, %b, %c { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> }
+ : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ $a `,` $b (`,` $c^)? ` `
+ `{`
+ `shape` `=` $shape `,`
+ `types` `=` $types
+ `}` attr-dict `:` functional-type(operands, results)
+ }];
+
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// XeVM target attribute.
+//===----------------------------------------------------------------------===//
+
+def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
+ let description = [{
+ GPU target attribute for controlling compilation of Intel GPU targets. All
+ parameters decay into default values if not present.
+
+ Examples:
+
+ 1. Target with default values.
+ ```
+ gpu.module @mymodule [#xevm.target] attributes {...} {
+ ...
+ }
+ ```
+ }];
+ let parameters =
+ (ins DefaultValuedParameter<"int", "2",
+ "Optimization level to apply.">:$O,
+ StringRefParameter<"Target triple.",
+ "\"spirv64-unknown-unknown\"">:$triple,
+ StringRefParameter<"Target chip.", "\"bmg\"">:$chip,
+ OptionalParameter<"::mlir::DictionaryAttr",
+ "Target specific flags.">:$flags,
+ OptionalParameter<"::mlir::ArrayAttr",
+ "Files to link to the LLVM module.">:$linkFiles);
+ let assemblyFormat = [{
+ (`<` struct($O, $triple, $chip, $flags, $linkFiles)^ `>`)?
+ }];
+ let builders = [AttrBuilder<
+ (ins CArg<"int", "2">:$optLevel,
+ CArg<"::llvm::StringRef", "\"spirv64-unknown-unknown\"">:$triple,
+ CArg<"::llvm::StringRef", "\"bmg\"">:$chip,
+ CArg<"::mlir::DictionaryAttr", "nullptr">:$targetFlags,
+ CArg<"::mlir::ArrayAttr", "nullptr">:$linkFiles),
+ [{
+ return Base::get($_ctxt, optLevel, triple, chip, targetFlags, linkFiles);
+ }]>];
+ let skipDefaultBuilders = 1;
+ let genVerifyDecl = 1;
+}
+
+#endif // XEVMIR_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 261b0e00bdf86..c6fcf1a0d510b 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -46,6 +46,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
@@ -152,7 +153,8 @@ inline void registerAllDialects(DialectRegistry ®istry) {
ub::UBDialect,
vector::VectorDialect,
x86vector::X86VectorDialect,
- xegpu::XeGPUDialect>();
+ xegpu::XeGPUDialect,
+ xevm::XeVMDialect>();
// clang-format on
// Register all external models.
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index d83fd3800eb91..67081ca61e6e5 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -110,3 +110,25 @@ add_mlir_dialect_library(MLIRVCIXDialect
MLIRLLVMDialect
MLIRSideEffectInterfaces
)
+
+add_mlir_dialect_library(MLIRXeVMDialect
+ IR/XeVMDialect.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+ DEPENDS
+ MLIRGPUCompilationAttrInterfacesIncGen
+ MLIRXeVMOpsIncGen
+ MLIRXeVMConversionsIncGen
+ intrinsics_gen
+
+ LINK_COMPONENTS
+ AsmParser
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRSideEffectInterfaces
+)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
new file mode 100644
index 0000000000000..9e497829ba723
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -0,0 +1,366 @@
+//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::xevm;
+
+#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
+#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
+
+// Subgroup size assumed by the 2D block op verifiers.
+static constexpr uint32_t subgroupSize = 16;
+
+/// Verifies the constraints shared by all 2D block ops (load, store and
+/// prefetch):
+///  - when both are constants, the base pitch must not be smaller than the
+///    base width;
+///  - elem_size_in_bits must be 8, 16 or 32;
+///  - tile_height must be a power of two no larger than 32;
+///  - v_blocks must be a power of two no larger than 8.
+template <typename Op>
+static LogicalResult verifyMatrixInput(Op op) {
+  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
+                                BlockPrefetch2dOp>::value,
+                "Unexpected template parameter");
+
+  std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
+  std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
+  if (pitch && width && *pitch < *width)
+    return op->emitOpError(
+        "4th operand (base pitch) should be >= 2nd operand (base width)");
+
+  uint32_t elemSize = op.getElemSizeInBits();
+  if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
+    return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
+
+  uint32_t tileHeight = op.getTileHeight();
+  if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
+    return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
+    return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
+
+  return success();
+}
+
+/// Verifies the hardware restrictions of a 2D block load.
+///
+/// The result vector must be exactly
+/// elem_size_in_bits * tile_height * tile_width * v_blocks / subgroupSize bits
+/// wide. `transpose` and `pack_register` are mutually exclusive. Each of the
+/// three load modes (plain, transposed, register-packed) accepts its own
+/// ranges of element size, tile shape and v_blocks.
+static LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
+  VectorType resTy = op.getRes().getType();
+  if (!resTy.getElementType().isIntOrFloat())
+    return op.emitOpError()
+           << "expecting result element type to be int or float";
+  // Compute sizes in 64 bits. The element count is int64_t and the attributes
+  // are 32-bit, so 32-bit arithmetic could wrap around and make a mismatched
+  // result type appear to have the expected size.
+  uint64_t resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+  uint64_t resSize =
+      static_cast<uint64_t>(resTy.getNumElements()) * resElemTySize;
+  uint64_t expectedSize = static_cast<uint64_t>(op.getElemSizeInBits()) *
+                          op.getTileHeight() * op.getTileWidth() *
+                          op.getVBlocks() / subgroupSize;
+  if (resSize != expectedSize)
+    return op.emitOpError() << "result size of " << resSize
+                            << " bits does not match the expected size of "
+                            << expectedSize << " bits";
+
+  if (op.getTranspose() && op.getPackRegister())
+    return op.emitOpError("transpose and pack_register are mutually exclusive");
+
+  // Plain load: neither transpose nor pack_register.
+  if (!op.getTranspose() && !op.getPackRegister()) {
+    uint32_t tileHeight = op.getTileHeight();
+    if (tileHeight < 1 || tileHeight > 32)
+      return op.emitOpError("expecting tile_height to be between 1 and 32");
+
+    uint32_t tileWidth = op.getTileWidth();
+    uint32_t vBlocks = op.getVBlocks();
+    switch (op.getElemSizeInBits()) {
+    case 8:
+      if (tileWidth < 4 || tileWidth > 64)
+        return op.emitOpError("expecting tile_width to be between 4 and 64");
+      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+        return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+      if (tileWidth * vBlocks > 64)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 64 for 8 bit elements");
+      break;
+    case 16:
+      if (tileWidth < 2 || tileWidth > 32)
+        return op.emitOpError("expecting tile_width to be between 2 and 32");
+      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+        return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+      if (tileWidth * vBlocks > 32)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 32 for 16 bit elements");
+      break;
+    case 32:
+      if (tileWidth < 1 || tileWidth > 16)
+        return op.emitOpError("expecting tile_width to be between 1 and 16");
+      if (vBlocks != 1 && vBlocks != 2)
+        return op.emitOpError("expecting v_blocks to be 1 or 2");
+      if (tileWidth * vBlocks > 16)
+        return op.emitOpError(
+            "tile_width * v_blocks should be less than or equal "
+            "to 16 for 32 bit elements");
+      break;
+    case 64:
+      if (tileWidth < 1 || tileWidth > 8)
+        return op.emitOpError("expecting tile_width to be between 1 and 8");
+      if (vBlocks != 1)
+        return op.emitOpError("expecting v_blocks to be 1");
+      break;
+    default:
+      return op.emitOpError(
+          "expecting elem_size_in_bits to be 8, 16, 32, or 64");
+    }
+
+    return success();
+  }
+
+  // Transposed load: only 32- and 64-bit elements, a single block.
+  if (op.getTranspose()) {
+    assert(!op.getPackRegister() && "Expecting pack_register should be false");
+
+    uint32_t vBlocks = op.getVBlocks();
+    if (vBlocks != 1)
+      return op.emitOpError("expecting v_blocks to be 1");
+
+    uint32_t tileHeight = op.getTileHeight();
+    uint32_t tileWidth = op.getTileWidth();
+    switch (op.getElemSizeInBits()) {
+    case 32:
+      if (tileHeight < 1 || tileHeight > 32)
+        return op.emitOpError("expecting tile_height to be between 1 and 32");
+      if (tileWidth < 1 || tileWidth > 8)
+        return op.emitOpError("expecting tile_width to be between 1 and 8");
+      break;
+    case 64:
+      if (tileHeight != 8)
+        return op.emitOpError(
+            "expecting tile_height to be 8 for 64 bit elements");
+      if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
+        return op.emitOpError("expecting tile_width to be 1, 2, or 4");
+      break;
+    default:
+      return op.emitOpError("transpose is only supported for 32 and 64 bit "
+                            "elements");
+    }
+
+    return success();
+  }
+
+  // Register-packed load: only 8- and 16-bit elements.
+  assert(op.getPackRegister() && !op.getTranspose() &&
+         "Expecting pack_register should be true and transpose should be "
+         "false");
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
+    return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
+
+  uint32_t tileHeight = op.getTileHeight();
+  uint32_t tileWidth = op.getTileWidth();
+  switch (op.getElemSizeInBits()) {
+  case 8:
+    if (tileHeight < 4 || tileHeight > 32)
+      return op.emitOpError("expecting tile_height to be between 4 and 32");
+    if (tileWidth < 4 || tileWidth > 16)
+      return op.emitOpError("expecting tile_width to be between 4 and 16");
+    break;
+  case 16:
+    if (tileHeight < 2 || tileHeight > 32)
+      return op.emitOpError("expecting tile_height to be between 2 and 32");
+    if (tileWidth < 2 || tileWidth > 16)
+      return op.emitOpError("expecting tile_width to be between 2 and 16");
+    if (tileWidth * vBlocks > 32)
+      return op.emitOpError(
+          "tile_width * v_blocks should be less than or equal "
+          "to 32 for 16 bit elements");
+    break;
+  default:
+    return op.emitOpError("pack_register is only supported for 8 and 16 bit "
+                          "elements");
+  }
+
+  return success();
+}
+
+/// Verifies the hardware restrictions of a 2D block store:
+///  - tile_height must be in [1, 8];
+///  - tile_width must be in the range allowed for the element size;
+///  - v_blocks must be 1.
+static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
+  uint32_t tileHeight = op.getTileHeight();
+  if (tileHeight < 1 || tileHeight > 8)
+    return op.emitOpError("expecting tile_height to be between 1 and 8");
+
+  uint32_t tileWidth = op.getTileWidth();
+  switch (op.getElemSizeInBits()) {
+  case 8:
+    if (tileWidth < 4 || tileWidth > 64)
+      return op.emitOpError("expecting tile_width to be between 4 and 64");
+    break;
+  case 16:
+    if (tileWidth < 2 || tileWidth > 32)
+      return op.emitOpError("expecting tile_width to be between 2 and 32");
+    break;
+  case 32:
+    if (tileWidth < 1 || tileWidth > 16)
+      return op.emitOpError("expecting tile_width to be between 1 and 16");
+    break;
+  case 64:
+    if (tileWidth < 1 || tileWidth > 8)
+      return op.emitOpError("expecting tile_width to be between 1 and 8");
+    break;
+  default:
+    return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
+  }
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks != 1)
+    return op.emitOpError("expecting v_blocks to be 1");
+  return success();
+}
+
+/// Verifies a 2D block load.
+///
+/// The load-specific hardware restrictions run first, followed by the checks
+/// shared with the other 2D block ops. Finally:
+///  - the result element type must be 32 bits wide for 32-bit elements or
+///    when `pack_register` is set;
+///  - with `pack_register`, tile_width must equal the subgroup size (16).
+LogicalResult BlockLoad2dOp::verify() {
+  if (verify2DBlockLoadRestriction(*this).failed())
+    return failure();
+
+  if (verifyMatrixInput(*this).failed())
+    return failure();
+
+  VectorType resTy = getRes().getType();
+  // Guard the bit-width query below, which is only valid for int/float types.
+  if (!resTy.getElementType().isIntOrFloat())
+    return emitOpError() << "expecting result element type to be int or float";
+  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+  if ((getElemSizeInBits() == 32 || getPackRegister()) && resElemTySize != 32)
+    return emitOpError() << "expecting result element type to be 32 bits";
+
+  if (getPackRegister() && getTileWidth() != 16)
+    return emitOpError(
+        "tile_width when pack_register is true should be equal "
+        "to subgroup size (16 elements)");
+
+  return success();
+}
+
+/// Verifies a 2D block store.
+///
+/// After the store-specific hardware restrictions and the checks shared by all
+/// 2D block ops, tile_width must be 16 for 16- and 32-bit elements, or 16 or
+/// 32 for 8-bit elements.
+LogicalResult BlockStore2dOp::verify() {
+  if (failed(verify2DBlockStoreRestriction(*this)) ||
+      failed(verifyMatrixInput(*this)))
+    return failure();
+
+  const uint32_t width = getTileWidth();
+  const uint32_t elemBits = getElemSizeInBits();
+  // verifyMatrixInput has already limited elem_size_in_bits to 8, 16 or 32.
+  if (elemBits == 8) {
+    if (width == 16 || width == 32)
+      return success();
+    return emitOpError("tile_width for 8 bit elements should be equal to "
+                       "16 or 32");
+  }
+  if (elemBits == 16) {
+    if (width == 16)
+      return success();
+    return emitOpError("tile_width for 16 bit elements should be equal "
+                       "to 16");
+  }
+  if (elemBits == 32) {
+    if (width == 16)
+      return success();
+    return emitOpError("tile_width for 32 bit elements should be equal "
+                       "to 16");
+  }
+  llvm_unreachable("unexpected element size");
+}
+
+/// Verifies a 2D block prefetch.
+///
+/// After the checks shared by all 2D block ops, tile_width must be 16 or 32
+/// for 8-bit elements, 16 for 16-bit elements, and 8 or 16 for 32-bit
+/// elements.
+LogicalResult BlockPrefetch2dOp::verify() {
+  if (failed(verifyMatrixInput(*this)))
+    return failure();
+
+  const uint32_t width = getTileWidth();
+  const uint32_t elemBits = getElemSizeInBits();
+  // verifyMatrixInput has already limited elem_size_in_bits to 8, 16 or 32.
+  if (elemBits == 8) {
+    if (width == 16 || width == 32)
+      return success();
+    return emitOpError("tile_width for 8 bit elements should be equal to "
+                       "16 or 32");
+  }
+  if (elemBits == 16) {
+    if (width == 16)
+      return success();
+    return emitOpError("tile_width for 16 bit elements should be equal "
+                       "to 16");
+  }
+  if (elemBits == 32) {
+    if (width == 8 || width == 16)
+      return success();
+    return emitOpError(
+        "tile_width for 32 bit elements should be equal to 8 or 16");
+  }
+  llvm_unreachable("unexpected element size");
+}
+
+/// Verifies an MMA op: when the optional C operand is present, its type must
+/// be identical to the result type.
+LogicalResult MMAOp::verify() {
+  Value c = getC();
+  if (c && c.getType() != getResult().getType())
+    return emitOpError("type of C operand must match result type");
+  return success();
+}
+
+/// Verifies the parameters of a `#xevm.target` attribute:
+///  - the optimization level `O` must be in [0, 3];
+///  - the target triple and chip must be non-empty;
+///  - every `linkFiles` entry, if given, must be a string holding a non-empty
+///    path to an existing file.
+/// `flags` is not constrained here.
+LogicalResult
+XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
+                       StringRef triple, StringRef chip, DictionaryAttr flags,
+                       ArrayAttr linkFiles) {
+  if (O < 0 || O > 3) {
+    return emitError()
+           << "The optimization level must be a number between 0 and 3.";
+  }
+  if (triple.empty()) {
+    return emitError() << "The target triple cannot be empty.";
+  }
+  if (chip.empty()) {
+    return emitError() << "The target chip cannot be empty.";
+  }
+  if (linkFiles) {
+    for (Attribute fileAttr : linkFiles) {
+      // A non-string entry cannot name a file to link, so reject it rather
+      // than silently skipping it.
+      auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr);
+      if (!fileStrAttr) {
+        return emitError() << "All the elements in linkFiles must be strings.";
+      }
+      StringRef filePath = fileStrAttr.getValue();
+      if (filePath.empty()) {
+        return emitError() << "File paths in linkFiles cannot be empty.";
+      }
+      if (!llvm::sys::fs::exists(filePath)) {
+        return emitError() << "File '" << filePath << "' does not exist.";
+      }
+    }
+  }
+  return success();
+}
+
+/// Registers the TableGen-generated XeVM operations and attributes with the
+/// dialect. It also declares gpu::TargetAttrInterface on XeVMTargetAttr as a
+/// promised interface: no implementation is attached here, and one is
+/// expected to be registered separately (presumably by a target extension).
+void XeVMDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
+ >();
+
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
+ >();
+ declarePromisedInterface<mlir::gpu::TargetAttrInterface,
+ mlir::xevm::XeVMTargetAttr>();
+}
+
+// TableGen-generated definitions of the XeVM op and attribute classes.
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 251ca716c7a7a..bd1106e304c60 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1875,7 +1875,7 @@ llvm.mlir.global @bad_struct_array_init_elements() : !llvm.array<1x!llvm.struct<
llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>>
}
-// ----
+// -----
llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2 x f64> {
// expected-error@below {{'llvm.mlir.constant' op for array with an array attribute must have a struct element type}}
@@ -1883,10 +1883,51 @@ llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2
llvm.return %0 : !llvm.array<2 x f64>
}
-// ----
+// -----
llvm.func @inlineAsmMustTail(%arg0: i32, %arg1 : !llvm.ptr) {
  // expected-error@+1 {{op tail call kind 'musttail' is not supported}}
  %8 = llvm.inline_asm tail_call_kind = <musttail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
  llvm.return
}
+
+// -----
+
+llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
+  // xevm.prefetch only accepts pointers in address space 1 or 4; a pointer in
+  // the default address space 0 must be rejected.
+  // expected-error@+1 {{op operand #0 must be LLVM pointer in address space 1 or LLVM pointer in address space 4}}
+  xevm.prefetch %arg0 <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr)
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_mma(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
+  // C is vector<4xf32> but the result is vector<8xf32>; the types must match.
+  // expected-error@+1 {{op type of C operand must match result type}}
+  %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>} : (vector<8xi16>, vector<8xi32>, vector<4xf32>) -> vector<8xf32>
+  llvm.return %c_result : vector<8xf32>
+}
+
+// -----
+
+llvm.func @invalid_xevm_matrix_1(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+  // Block stores of 64-bit elements allow tile_width only in [1, 8].
+  // expected-error@+1 {{op expecting tile_width to be between 1 and 8}}
+  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=64 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_matrix_2(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+  // 18 is not a supported element size for block stores.
+  // expected-error@+1 {{op expecting elem_size_in_bits to be 8, 16, 32, or 64}}
+  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=18 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_matrix_3(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+  // Expected result size is 16 bits * 8 rows * 26 columns * 1 block / 16 =
+  // 208 bits, but vector<8xi16> holds only 128 bits.
+  // expected-error@+1 {{op result size of 128 bits does not match the expected size of 208 bits}}
+  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=26 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+  llvm.return %loaded_a : vector<8xi16>
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir
new file mode 100644
index 0000000000000..3dd5f872f898c
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/xevm.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// Each split section round-trips one XeVM op, or the XeVM target attribute,
+// through the parser and printer.
+
+// CHECK-LABEL: func.func @blockload2d(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>,
+// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32,
+ %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+ // CHECK: %[[VAR0:.*]] = xevm.blockload2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
+ // CHECK-DAG: elem_size_in_bits = 16 : i32
+ // CHECK-DAG: tile_width = 16 : i32
+ // CHECK-DAG: tile_height = 8 : i32
+ // CHECK-DAG: v_blocks = 1 : i32
+ // CHECK-DAG: transpose = false
+ // CHECK-DAG: pack_register = false
+ // CHECK-DAG: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
+ // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ // Result size: 16 bits * 8 rows * 16 columns * 1 block / 16 = 128 bits,
+ // which is vector<8xi16>.
+ %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32,
+ transpose=false, pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ return %loaded_a : vector<8xi16>
+}
+
+// -----
+// Round trip of xevm.blockstore2d with 32-bit elements and a 16-wide tile.
+// CHECK-LABEL: func.func @blockstore2d(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>,
+// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32,
+// CHECK-SAME: %[[ARG6:.*]]: vector<8xi32>)
+func.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32,
+ %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+ // CHECK: xevm.blockstore2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]]
+ // CHECK-DAG: elem_size_in_bits = 32 : i32
+ // CHECK-DAG: tile_width = 16 : i32
+ // CHECK-DAG: tile_height = 8 : i32
+ // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
+ <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}>
+ : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ return
+}
+
+// -----
+// Round trip of xevm.blockprefetch2d with 8-bit elements and a 32-wide tile.
+// CHECK-LABEL: func.func @blockprefetch2d(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>,
+// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+func.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32,
+ %base_pitch: i32, %x: i32, %y: i32) {
+ // CHECK: xevm.blockprefetch2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
+ // CHECK-DAG: elem_size_in_bits = 8 : i32
+ // CHECK-DAG: tile_width = 32 : i32
+ // CHECK-DAG: tile_height = 8 : i32
+ // CHECK-DAG: v_blocks = 1 : i32
+ // CHECK-DAG: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
+ // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+ xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32,
+ tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32,
+ cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+ return
+}
+
+// -----
+// Round trip of xevm.mma whose C operand has the same type as the result.
+// The result value is captured rather than hard-coded so the test does not
+// depend on SSA numbering, and the returned value is checked.
+// CHECK-LABEL: func.func @mma(
+// CHECK-SAME: %[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>)
+func.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
+  // CHECK: %[[RES:.*]] = xevm.mma %[[ARG1]], %[[ARG2]], %[[ARG0]] {shape = <m = 8, n = 16, k = 16>,
+  // CHECK-SAME: types = <d = f32, a = f16, b = f16, c = f32>} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+  // CHECK: return %[[RES]] : vector<8xf32>
+  %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted { shape=<m=8, n=16, k=16>,
+    types=<d=f32, a=f16, b=f16, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+  return %c_result : vector<8xf32>
+}
+
+// -----
+// Round trip of xevm.memfence with explicit address-space and scope attributes.
+// CHECK-LABEL: func.func @memfence()
+func.func @memfence() {
+ // CHECK: xevm.memfence
+ // CHECK-DAG: addrspace = #xevm.addr_space<global>
+ // CHECK-DAG: scope = #xevm.mem_scope<workgroup>
+ xevm.memfence <{addrspace=#xevm.addr_space<global>, scope=#xevm.mem_scope<workgroup>}>
+ return
+}
+
+// -----
+// Round trip of xevm.prefetch on an address-space-1 pointer.
+// CHECK-LABEL: func.func @prefetch(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>)
+func.func @prefetch(%ptr: !llvm.ptr<1>) {
+ // CHECK: xevm.prefetch %[[ARG0]]
+ // CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ return
+}
+
+// -----
+// A gpu.module carrying the XeVM target attribute. Only O and chip are
+// spelled out; the remaining parameters take their defaults.
+// CHECK-LABEL: @xevm_module [#xevm.target<O = 3, chip = "pvc">] {
+gpu.module @xevm_module [#xevm.target<O = 3, chip = "pvc">]{
+}
More information about the Mlir-commits
mailing list