[Mlir-commits] [mlir] 1bf31c3 - [MLIR][XeGPU] Update XeGPU create_tdesc, update_offset, load, store and prefetch. (#154653)
llvmlistbot at llvm.org
Fri Aug 22 14:09:51 PDT 2025
Author: Sang Ik Lee
Date: 2025-08-22T14:09:47-07:00
New Revision: 1bf31c3cd01e946103eddf08b5b52a1c6fad96a4
URL: https://github.com/llvm/llvm-project/commit/1bf31c3cd01e946103eddf08b5b52a1c6fad96a4
DIFF: https://github.com/llvm/llvm-project/commit/1bf31c3cd01e946103eddf08b5b52a1c6fad96a4.diff
LOG: [MLIR][XeGPU] Update XeGPU create_tdesc, update_offset, load, store and prefetch. (#154653)
This PR tightens up some loose ends in several XeGPU op definitions.
Changes are backward compatible except for
- Enforcing the previously implicit assumption that load/store/prefetch offsets
are required if source/dest is not a scatter tensor descriptor.
- Likewise, enforcing that offsets are not allowed if source/dest is a scatter
tensor descriptor.
- Additionally, allowing i64, i32 and ui32 as source/dest for
load/store/prefetch. This matches the behavior of the tensor descriptor, which
allows i64, i32 and ui32 base addresses in addition to ui64.
- Explicitly stating that create_tdesc and update_offset ops are not valid
in SIMT mode. create_tdesc and update_offset ops remain available in
subgroup-level (non-SIMT) mode.
- Adding an offset_align_byte attribute to the prefetch op, to be used with an
integer pointer source to enable address calculation with offsets.
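As a rough illustration of the prefetch rules listed above, the sketch below models them as a standalone C++ function. `SourceKind`, `PrefetchModel` and `checkPrefetch` are hypothetical names, not MLIR or XeGPU APIs; the error strings are copied from the diagnostics added to `PrefetchOp::verify()` in the diff, and the other checks (scattered TensorDesc, cache hints) are omitted.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical model of the source operand kinds accepted by xegpu.prefetch
// after this change; this is not the MLIR/XeGPU API.
enum class SourceKind { ScatterTensorDesc, MemRef1D, IntegerPointer };

// Hypothetical summary of one xegpu.prefetch op.
struct PrefetchModel {
  SourceKind source;
  bool hasOffsets;
  bool hasOffsetAlignByte;
};

// Returns the diagnostic PrefetchOp::verify() would emit for these rules,
// or std::nullopt if the op passes them. Other verifier checks are omitted.
std::optional<std::string> checkPrefetch(const PrefetchModel &op) {
  const bool isTensorDesc = op.source == SourceKind::ScatterTensorDesc;
  // Offsets are required unless the source is a scatter tensor descriptor...
  if (!isTensorDesc && !op.hasOffsets)
    return "Expects offsets.";
  // ...and rejected when it is one, since create_tdesc already supplied them.
  if (isTensorDesc && op.hasOffsets)
    return "offsets not allowed.";
  // offset_align_byte is mandatory for integer pointers and forbidden otherwise.
  const bool isInteger = op.source == SourceKind::IntegerPointer;
  if (isInteger && !op.hasOffsetAlignByte)
    return "offset_align_byte is required with integer source.";
  if (op.hasOffsetAlignByte && !isInteger)
    return "offset_align_byte only allowed with integer source.";
  return std::nullopt;
}
```

Load and store follow the same offsets rules; offset_align_byte is only an attribute of prefetch.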
New test cases are added for the new enforced checks.
Other minor implementation change:
The XeGPU scatter tensor descriptor only allows a 1D base memref. This was
previously checked in the op verify() method and is now enforced in the TableGen (ODS) definition.
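The source/dest constraint can likewise be sketched in plain C++, mirroring `XeGPU_PointerType` and `XeGPU_GatherScatterBaseAddrType` from XeGPUTypes.td. `Signedness`, `OperandType`, `isPointerType` and `isGatherScatterBaseAddr` are hypothetical helpers, not MLIR APIs; the memref element-type check against `XeGPU_ScalarType` and the TensorDesc alternative of `XeGPU_GatherScatterSourceType` are left out.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, simplified operand-type description; not mlir::Type.
enum class Signedness { Signless, Signed, Unsigned };

struct OperandType {
  enum class Kind { MemRef, Integer };
  Kind kind;
  std::int64_t rank; // memref rank (unused for integers)
  unsigned width;    // integer bit width (unused for memrefs)
  Signedness sign;   // integer signedness (unused for memrefs)
};

// Mirrors XeGPU_PointerType: AnyTypeOf<[UI64, UI32, I64, I32]>.
// MLIR's i32/i64 are signless, so si32/si64 are not accepted.
bool isPointerType(const OperandType &t) {
  if (t.kind != OperandType::Kind::Integer)
    return false;
  if (t.width != 32 && t.width != 64)
    return false;
  return t.sign == Signedness::Signless || t.sign == Signedness::Unsigned;
}

// Mirrors XeGPU_GatherScatterBaseAddrType: a rank-1 memref or a pointer type.
// The memref element-type check (XeGPU_ScalarType) is omitted here.
bool isGatherScatterBaseAddr(const OperandType &t) {
  if (t.kind == OperandType::Kind::MemRef)
    return t.rank == 1;
  return isPointerType(t);
}
```

Because the rank check now lives in the type constraint, a source such as `memref<4x4xf32>` is rejected by the ODS-generated operand check, which is why the updated invalid.mlir tests expect the `operand #0 must be ...` message instead of the old verifier text.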
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
mlir/test/Dialect/XeGPU/invalid.mlir
mlir/test/Dialect/XeGPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index ab471a1f33ef9..f0b325cd1d593 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -70,28 +70,32 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
future). Elements in the subview continuous in each dimension. It encodes the
following important information for supporting Intel hardware features:
- * source: an object representing (starting address/pointer of) a memory region.
+ Arguments:
+ - `source`: an object representing (starting address/pointer of) a memory region.
It can be either a memref object, or simply a pointer represented by uint64_t type.
For the case of dynamic memrefs or pointer, the shape and layout information of the
memory region should be explicitly passed via `shape` and `strides` parameters.
- * offsets: index values represents offsets from the "source" at the each dimension
+ - `offsets`: index values represents offsets from the "source" at the each dimension
at which the subview of the target memory will be created. It is encoded via
"offsets" and "const_offsets", such that it can accept various forms, such as,
operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
- * shape: the shape information of the memory region pointed by the "source". It is
+ - `shape`: the shape information of the memory region pointed by the "source". It is
typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
But if "source" is simply a pointer represented as uint64_t type, or a memref
type without shape information e.g., memref<?x?xf16>, the shape information has
to be explicitly passed via the "shape" and "const_shape" arguments.
- * strides: the strides of the memory region pointed by the "source". Similar to shape,
+ - `strides`: the strides of the memory region pointed by the "source". Similar to shape,
it is typically encoded via the MemRefType of the source too. But if "source" is
simply a pointer represented as uint64_t type, or a memref type without shape
information e.g., memref<?x?xf16>, the strides information has to be explicitly
passed via the "strides" and "const_strides" argument.
+ Results:
+ - `res`: nd tensor descriptor
+
Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
```mlir
%0 = memref.alloc() : memref<1024x1024xf32>
@@ -560,12 +564,17 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
(scattered) subviews, allowing each work-item in a subgroup specifying their own offset.
It accepts the following parameters:
- * source: a 1D memref or pointer (uint64_t) represents the flattened memory object.
- * offsets: a vector containing offsets of each access point. Its size
+ Arguments:
+ - `source`: a 1D memref or pointer (i64, i32, ui64, ui32) represents the flattened
+ memory object.
+ - `offsets`: a vector containing offsets of each access point. Its size
is fixed to the hardware supportted subgroup size, e.g., 16 on PVC,
implying each element in the vector corresponds to a work-item (SIMT lane)
in the subgroup.
+ Results:
+ - `res`: scattered tensor descriptor
+
The first dimension of the result TensorDesc corresponds to work-items, so it should
match the dimension of offsets. It may also has a second dimension corresponding to
the chunk_size if the chunk size is larger than 1.
@@ -596,8 +605,8 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
```
}];
- let arguments = (ins XeGPU_BaseAddrType: $source,
- XeGPU_OffsetType: $offsets);
+ let arguments = (ins XeGPU_GatherScatterBaseAddrType:$source,
+ XeGPU_OffsetType:$offsets);
let results = (outs XeGPU_TensorDesc:$TensorDesc);
let builders = [
@@ -655,6 +664,18 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
As compared to prefetch_nd, which works on non-scattered TensorDesc,
it works on scattered TensorDesc instead.
+ Arguments:
+ - `source`: represents the memory region to be loaded from, which can be either a
+ tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+ In case of tensor_desc, offsets come from the producer create_tdesc op.
+ tensor_desc cannot be used in SIMT mode.
+ - `offsets`: represents offsets from source. required if `source` in not a TensorDescType.
+ offsets is a vector of `index` type and vector length is either the subgroup size
+ or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
+ - `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache.
+ - `offset_align_byte`: required if `source` is a pointer. If `source` is not a pointer,
+ it is not allowed. Represents the alignment in bytes of each offset in offsets.
+
Example 1:
```mlir
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
@@ -666,7 +687,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
Example 2:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
- The source operand could be a raw pointer (uint64_t).
+ The source operand could be a raw pointer (ui64, ui32, i64, i32).
Please refer to create_tdesc for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
@@ -677,13 +698,33 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
: memref<1024xf32>, vector<4xindex>
```
+ Example 3 (SIMT mode):
+ SIMT mode only accepts the offsets variant.
+ ```mlir
+ xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>}
+ : memref<256xf32>, vector<1xindex>
+ ```
+
+ Example 4 (SIMT mode):
+ SIMT mode only accepts the offsets variant.
+ ```mlir
+ xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>,
+ offset_align_byte = 2}
+ : i64, vector<1xindex>
+ ```
+
}];
- let arguments = (ins XeGPU_GatherScatterSourceType: $source,
- Optional<XeGPU_OffsetType>: $offsets,
- OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
+ let arguments = (ins XeGPU_GatherScatterSourceType:$source,
+ Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+ OptionalAttr<I64Attr>:$offset_align_byte);
let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getSourceType() {
@@ -731,8 +772,26 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
The mask operand masks out memory access so that it is safe to pass out-of-boundary
addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
- In SIMT mode, the result vector represents the data to be loaded by each work-item.
- Each work-item recieves a `chunk_size` number of elements.
+ In SIMT mode, the result is a 1D vector that represents the data to be loaded by
+ each work-item. If size is not 1, size should be equal to the chunk size,
+
+ Arguments:
+ - `source`: represents the memory region to be loaded from, which can be either a
+ tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+ In case of tensor_desc, offsets come from the producer create_tdesc op.
+ tensor_desc cannot be used in SIMT mode.
+ - `offsets`: represents offsets from source. required if `source` in not a TensorDescType.
+ offsets is a vector of `index` type and vector length is either the subgroup size
+ or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
+ - `mask`: is a vector of `i1` type, which is used to mask out the memory access.
+ mask is a vector of size equal to the subgroup size, or 1 in SIMT mode.
+ scalar mask is also valid for SIMT mode.
+ - `chunk_size`: (optional) represents contiguous number of elements to load from per work item.
+ - `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache.
+
+ Results:
+ - `res`: represents loaded data
+
Example 1:
```mlir
@@ -752,19 +811,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
vector<16xi1> -> vector<16x8xf32>
```
- Example 3 (SIMT mode):
- ```mlir
- %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
- l2_hint = #xegpu.cache_hint<uncached>,
- l3_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
- vector<16xi1> -> vector<8xf32>
- ```
-
- Example 4:
+ Example 3:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
- The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
+ The source operand could be a raw pointer (ui64, ui32, i64, i32). Please refer to create_tdesc
for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
@@ -776,16 +826,25 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
```
+ Example 4 (SIMT mode):
+ SIMT mode only accepts the offsets variant. chunk_size can be inferred from result
+ type. In this example, chunk_size is 8.
+ ```mlir
+ %2 = xegpu.load %1[%2], %0 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>,
+ l3_hint = #xegpu.cache_hint<uncached>}>
+ : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
+ ```
+
}];
- let arguments = (ins XeGPU_GatherScatterSourceType: $source,
- Optional<XeGPU_OffsetType>: $offsets,
- XeGPU_MaskType: $mask,
- OptionalAttr<I64Attr>: $chunk_size,
- OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
- let results = (outs XeGPU_ValueType: $value);
+ let arguments = (ins XeGPU_GatherScatterSourceType:$source,
+ Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
+ AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+ let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -838,15 +897,31 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store data to scattered memory locations.";
- let description = [{ It (aka. store) stores data to scattered memory locations. The value is
+ let description =
+ [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
a 2D vector instead. For the later case, dim-1 of the value correspods to the simd lanes
and the dim-0 of the value corresponds to the chunk size stored per lane. So `store_scatter`
has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
introduced on purpose, making sure users are aware of this implicit transformation.
- In SIMT mode, the input vector represents the data to be stored by each work-item.
- Each work-item stores a `chunk_size` number of elements.
+ In SIMT mode, the result is a 1D vector that represents the data to be stored by
+ each work-item. If size is not 1, size should be equal to the chunk size.
+
+ Arguments:
+ - `value`: represents the data to be stored.
+ - `dest`: represents the memory region to be stored to, which can be either a
+ tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+ In case of tensor_desc, offsets come from the producer create_tdesc op.
+ tensor_desc cannot be used in SIMT mode.
+ - `offsets`: represents offsets from dest. required if `source` in not a TensorDescType.
+ offsets is a vector of `index` type and vector length is either the subgroup size
+ or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
+ - `mask`: is a vector of `i1` type, which is used to mask out the memory access.
+ mask is a vector of size equal to the subgroup size, or 1 in SIMT mode.
+ scalar mask is also valid for SIMT mode.
+ - `chunk_size`: (optional) represents contiguous number of elements to store to per work item.
+ - `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache.
Example 1:
```mlir
@@ -864,15 +939,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
: vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
```
- Example 3 (SIMT mode):
- ```mlir
- xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
- l2_hint = #xegpu.cache_hint<write_back>,
- l3_hint = #xegpu.cache_hint<write_through>}>
- : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
- ```
-
- Example 4:
+ Example 3:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
The dest operand could be a raw pointer (uint64_t).
@@ -888,19 +955,27 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
: memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
```
+ Example 4 (SIMT mode):
+ SIMT mode only accepts the offsets variant. chunk_size can be inferred from value
+ type. In this example, chunk_size is 8.
+ ```mlir
+ xegpu.store %0, %1[%2], %3 <{l1_hint = #xegpu.cache_hint<uncached>,
+ l2_hint = #xegpu.cache_hint<write_back>,
+ l3_hint = #xegpu.cache_hint<write_through>}>
+ : vector<8xf32>, memref<256xf32>, vector<1xindex>, vector<1xi1>
+ ```
+
}];
- let arguments = (ins
- XeGPU_ValueType: $value,
- XeGPU_GatherScatterSourceType: $dest,
- Optional<XeGPU_OffsetType>: $offsets,
- XeGPU_MaskType: $mask,
- OptionalAttr<I64Attr>: $chunk_size,
- OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
- OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
+ let arguments = (ins AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value,
+ XeGPU_GatherScatterSourceType:$dest,
+ Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
+ AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
+ OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
- let extraClassDeclaration = extraBaseClassDeclaration # [{
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
Type getDestType() {
return getDest().getType();
}
@@ -916,6 +991,11 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
return dyn_cast<xegpu::TensorDescType>(getDestType());
}
+ mlir::Type getElementType() {
+ auto type = getValue().getType();
+ return getElementTypeOrSelf(type);
+ }
+
VectorType getValueType() {
return llvm::dyn_cast<VectorType>(getValue().getType());
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index f8b371db498e8..84902b2039643 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -16,13 +16,17 @@ include "mlir/IR/BuiltinTypes.td"
def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
-def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
+def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
+def XeGPU_BaseAddrType
+ : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
def XeGPU_VectorType: VectorOfRankAndType<[1,2,3,4,5,6], [XeGPU_ScalarType]>;
+def XeGPU_GatherScatterBaseAddrType
+ : AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1]>, XeGPU_PointerType]>;
// common base class for types in XeGPU dialect
class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
@@ -189,7 +193,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}
-def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_GatherScatterSourceType
+ : AnyTypeOf<[XeGPU_TensorDesc, XeGPU_GatherScatterBaseAddrType]>;
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index c8d180b973f05..7036996a68e0d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -58,13 +58,6 @@ static SmallVector<int64_t> getShapeOf(Type type) {
return shape;
}
-static int64_t getRankOf(Value val) {
- auto type = val.getType();
- if (auto ty = llvm::dyn_cast<ShapedType>(type))
- return ty.getRank();
- return 0;
-}
-
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
if (!attr)
return true;
@@ -89,13 +82,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
if (!tdescTy.isScattered())
return emitError() << "Expects a scattered TensorDesc.";
- if (!valueTy)
- return emitError() << "Expecting a vector type result.";
+ auto chunkSize = tdescTy.getChunkSizeAsInt();
+ if (!valueTy) {
+ if (chunkSize > 1)
+ return emitError() << "Expecting chunk size == 1 for scalar result";
+ if (dyn_cast<VectorType>(maskTy))
+ return emitError() << "Expecting a vector type result.";
+ return success();
+ }
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
auto tdescShape = getShapeOf(tdescTy);
- auto chunkSize = tdescTy.getChunkSizeAsInt();
if (valueTy.getElementType() != tdescTy.getElementType())
return emitError()
@@ -124,22 +122,37 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
}
static LogicalResult
-isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
- int64_t chunkSize,
+isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
+ VectorType valueTy, int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {
- if (!valueTy)
- return emitError() << "Expecting a vector type result.";
+ auto maskVecTy = dyn_cast<VectorType>(maskTy);
+ auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
+ if (!valueTy) {
+ if (chunkSize > 1)
+ return emitError() << "Expecting chunk size == 1 for scalar result";
+ if (maskVecTy || offsetsVecTy)
+ return emitError() << "Expecting scalar mask and offsets.";
+ else if (maskVecTy && offsetsVecTy)
+ return emitError() << "Expecting a vector type result.";
+ return success();
+ }
+ auto valueSize = valueTy.getNumElements();
+ // SIMT mode with scalar mask and offsets.
+ if (!maskVecTy && !offsetsVecTy) {
+ if (valueSize != chunkSize)
+ return emitError() << "value elements must match chunk size "
+ << chunkSize;
+ return success();
+ }
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
- auto maskVecTy = dyn_cast<VectorType>(maskTy);
if (!maskVecTy)
return emitError() << "Expecting a vector type mask.";
int64_t maskSize = maskVecTy.getNumElements();
- auto valueSize = valueTy.getNumElements();
if (chunkSize > 1) {
if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
return emitError() << "value elements must match chunk size "
@@ -149,8 +162,9 @@ isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
return emitError()
<< "Mask should match value except the chunk size dim.";
}
-
llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+ if (maskSize == 1)
+ return success();
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
@@ -685,10 +699,6 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state,
LogicalResult CreateDescOp::verify() {
auto tdescTy = getTensorDescType();
- if (getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
-
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
@@ -723,13 +733,15 @@ LogicalResult CreateDescOp::verify() {
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
+
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.");
- if (!tdescTy && getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
-
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -739,6 +751,13 @@ LogicalResult PrefetchOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+ auto srcTy = getSourceType();
+ if (srcTy.isInteger() && !getOffsetAlignByteAttr())
+ return emitOpError("offset_align_byte is required with integer source.");
+
+ if (getOffsetAlignByteAttr() && !srcTy.isInteger())
+ return emitOpError("offset_align_byte only allowed with integer source.");
+
return success();
}
@@ -746,7 +765,8 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
- build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+ build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
+ IntegerAttr{});
}
//===----------------------------------------------------------------------===//
@@ -757,13 +777,15 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
+
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.");
- if (!tdescTy && getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
-
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -780,10 +802,11 @@ LogicalResult LoadGatherOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
- return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ auto offsetsTy = getOffsets().getType();
+ return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
@@ -804,13 +827,15 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (!tdescTy && !getOffsets())
+ return emitOpError("Expects offsets.");
+
+ if (tdescTy && getOffsets())
+ return emitOpError("offsets not allowed.");
+
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.");
- if (!tdescTy && getRankOf(getDest()) > 1)
- return emitOpError(
- "Expecting the dest is a 1D memref or pointer (uint64_t).");
-
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -828,10 +853,11 @@ LogicalResult StoreScatterOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
- return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ auto offsetsTy = getOffsets().getType();
+ return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 93a5a055b08c6..228ef69d9a478 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -387,11 +387,44 @@ func.func @load_gather_vc_3(%src: ui64) {
// -----
func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
- // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ // expected-error@+1 {{op operand #0 must be TensorDesc describing regions of interested data}}
xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
return
}
+// -----
+func.func @prefetch_offset_wi_2(%src: memref<16xf32>) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %1 = xegpu.create_tdesc %src, %offsets : memref<16xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>>
+ // expected-error@+1 {{offsets not allowed}}
+ xegpu.prefetch %1[%offsets]: !xegpu.tensor_desc<1x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>>, vector<1xindex>
+ return
+}
+
+// -----
+func.func @prefetch_offset_wi_3(%src: memref<16xf32>) {
+ // expected-error@+1 {{Expects offsets}}
+ xegpu.prefetch %src: memref<16xf32>
+ return
+}
+
+// -----
+func.func @prefetch_offset_wi_4(%src: memref<16xf32>) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{offset_align_byte only allowed with integer source.}}
+ xegpu.prefetch %src[%offsets] <{offset_align_byte = 4}>: memref<16xf32>, vector<1xindex>
+ return
+}
+
+// -----
+func.func @prefetch_offset_wi_5(%src: i64) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{offset_align_byte is required with integer source.}}
+ xegpu.prefetch %src[%offsets] : i64, vector<1xindex>
+ return
+}
+
// -----
func.func @load_gather_offset_sg(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -428,12 +461,50 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
%val = arith.constant dense<2.9>: vector<4xf16>
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
- // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
+ // expected-error@+1 {{op operand #1 must be TensorDesc describing regions of interested data}}
xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
return
}
+// -----
+func.func @store_scatter_offset_wi_3(%src: memref<16xf16>) {
+ %val = arith.constant dense<2.9>: vector<1xf16>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{Expects offsets}}
+ xegpu.store %val, %src, %mask
+ : vector<1xf16>, memref<16xf16>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @store_scatter_offset_wi_4(%src: !xegpu.tensor_desc<1x1xf32, #xegpu.scatter_tdesc_attr<>>) {
+ %val = arith.constant dense<2.9>: vector<1xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{offsets not allowed}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<1xf16>, !xegpu.tensor_desc<1x1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_4(%src: !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<>>) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{offsets not allowed}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1> -> vector<2xf16>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_3(%src: ui64) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{Expects offsets}}
+ %2 = xegpu.load %src, %mask <{chunk_size = 2}> : ui64, vector<1xi1> -> vector<2xf16>
+ return
+}
+
// -----
func.func @load_gather_offset_wi_2(%src: ui64) {
%mask = arith.constant dense<1>: vector<1xi1>
@@ -447,7 +518,7 @@ func.func @load_gather_offset_wi_2(%src: ui64) {
func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
%mask = arith.constant dense<1>: vector<1xi1>
%offsets = arith.constant dense<[0]> : vector<1xindex>
- // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ // expected-error@+1 {{op operand #0 must be TensorDesc describing regions of interested data}}
%2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 35342eca1354c..bb379024a34d7 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -508,6 +508,34 @@ gpu.func @simt_load_3(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @simt_load_4(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: vector<1xindex>, %[[arg2:.*]]: vector<1xi1>) {
+gpu.func @simt_load_4(%arg0: memref<256xf16>, %arg1: vector<1xindex>, %arg2: vector<1xi1>) {
+ // CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+ %0 = xegpu.load %arg0[%arg1], %arg2 <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_5(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: vector<1xindex>, %[[arg2:.*]]: vector<1xi1>) {
+gpu.func @simt_load_5(%arg0: memref<256xf16>, %arg1: vector<1xindex>, %arg2: vector<1xi1>) {
+ // CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+ %0 = xegpu.load %arg0[%arg1], %arg2 : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_6(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: index, %[[arg2:.*]]: i1) {
+gpu.func @simt_load_6(%arg0: memref<256xf16>, %arg1: index, %arg2: i1) {
+ // CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] <{chunk_size = 8 : i64}> : memref<256xf16>, index, i1 -> vector<8xf16>
+ %0 = xegpu.load %arg0[%arg1], %arg2 <{chunk_size = 8 : i64}> : memref<256xf16>, index, i1 -> vector<8xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_7(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: index, %[[arg2:.*]]: i1) {
+gpu.func @simt_load_7(%arg0: memref<256xf16>, %arg1: index, %arg2: i1) {
+ // CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] : memref<256xf16>, index, i1 -> f16
+ %0 = xegpu.load %arg0[%arg1], %arg2 : memref<256xf16>, index, i1 -> f16
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_load_4(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_load_4(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex>
@@ -621,6 +649,34 @@ gpu.func @simt_store_3(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @simt_store_4(%[[arg0:.*]]: vector<8xf16>, %[[arg1:.*]]: memref<256xf16>, %[[arg2:.*]]: vector<1xindex>, %[[arg3:.*]]: vector<1xi1>) {
+gpu.func @simt_store_4(%arg0: vector<8xf16>, %arg1: memref<256xf16>, %arg2: vector<1xindex>, %arg3: vector<1xi1>) {
+ // CHECK: xegpu.store %[[arg0]], %[[arg1]][%[[arg2]]], %[[arg3]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+ xegpu.store %arg0, %arg1[%arg2], %arg3 <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_store_5(%[[arg0:.*]]: vector<8xf16>, %[[arg1:.*]]: memref<256xf16>, %[[arg2:.*]]: index, %[[arg3:.*]]: i1) {
+gpu.func @simt_store_5(%arg0: vector<8xf16>, %arg1: memref<256xf16>, %arg2: index, %arg3: i1) {
+ // CHECK: xegpu.store %[[arg0]], %[[arg1]][%[[arg2]]], %[[arg3]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, index, i1
+ xegpu.store %arg0, %arg1[%arg2], %arg3 <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, index, i1
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_store_6(%[[arg0:.*]]: vector<1xf16>, %[[arg1:.*]]: memref<256xf16>, %[[arg2:.*]]: vector<1xindex>, %[[arg3:.*]]: vector<1xi1>) {
+gpu.func @simt_store_6(%arg0: vector<1xf16>, %arg1: memref<256xf16>, %arg2: vector<1xindex>, %arg3: vector<1xi1>) {
+ // CHECK: xegpu.store %[[arg0]], %[[arg1]][%[[arg2]]], %[[arg3]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+ xegpu.store %arg0, %arg1[%arg2], %arg3 : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_store_7(%[[arg0:.*]]: f16, %[[arg1:.*]]: memref<256xf16>, %[[arg2:.*]]: index, %[[arg3:.*]]: i1) {
+gpu.func @simt_store_7(%arg0: f16, %arg1: memref<256xf16>, %arg2: index, %arg3: i1) {
+ // CHECK: xegpu.store %[[arg0]], %[[arg1]][%[[arg2]]], %[[arg3]] : f16, memref<256xf16>, index, i1
+ xegpu.store %arg0, %arg1[%arg2], %arg3 : f16, memref<256xf16>, index, i1
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_store_4(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_store_4(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex>
@@ -662,8 +718,8 @@ gpu.func @prefetch(%src: ui64) {
gpu.func @prefetch_offset(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
- xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+ // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, offset_align_byte = 2 : i64}> : ui64, vector<4xindex>
+ xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, offset_align_byte = 2}>: ui64, vector<4xindex>
gpu.return
}