[Mlir-commits] [mlir] [MLIR][XeGPU] Allow load/store/prefetch to use [memref+offset] instead of tdesc (PR #150576)
Jianhui Li
llvmlistbot at llvm.org
Wed Jul 30 10:54:07 PDT 2025
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/150576
>From 1373ffa1d836cb8401f5d24fca5c9283c2484d0e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 17 Jul 2025 23:21:14 +0000
Subject: [PATCH 01/16] add optional offsets to nd load/store/prefetch
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 70 +++++++++++++++++--
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 32 +++++++--
mlir/test/Dialect/XeGPU/ops.mlir | 50 +++++++++++++
3 files changed, 140 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 81e25f7537cb0..e9f8437d7c102 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,9 +29,22 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
void printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
- Attribute propAttr = getPropertiesAsAttr(ctx, prop);
- if (propAttr)
- p << "<" << propAttr << ">";
+
+ DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
+
+ // filter out the elidedProps from propAttr, and get the resultAttr
+ mlir::SmallVector<mlir::NamedAttribute> filteredAttrs;
+ if (propAttr) {
+ for (auto namedAttr : propAttr.getValue()) {
+ if (llvm::is_contained(elidedProps, namedAttr.getName().strref()))
+ continue;
+ filteredAttrs.push_back(namedAttr);
+ }
+ }
+
+ if (!filteredAttrs.empty()) {
+ p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
+ }
}
static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
@@ -288,6 +301,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
}];
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ Variadic<Index>: $offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -298,7 +313,18 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
}
}];
- let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+ let assemblyFormat = [{
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ prop-dict attr-dict `:` qualified(type($TensorDesc))
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value": $TensorDesc,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
+ ];
let hasVerifier = 1;
}
@@ -343,6 +369,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
}];
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ Variadic<Index>: $offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<UnitAttr>: $packed,
OptionalAttr<DenseI64ArrayAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -361,7 +389,20 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
}
}];
- let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)";
+ let assemblyFormat = [{
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+ "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
+ ];
+
let hasVerifier = 1;
}
@@ -400,6 +441,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
let arguments = (ins XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
+ Variadic<Index>: $offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -414,8 +457,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
}
}];
- let assemblyFormat = [{$value `,` $TensorDesc prop-dict attr-dict
- `:` type($value) `,` qualified(type($TensorDesc))}];
+ let assemblyFormat = [{
+ $value `,`
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
+ ];
+
+
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 78cbf884a1911..ca3c92cf4b52c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -331,16 +331,24 @@ ParseResult parseOptionalDynamicIndexList(
void printOptionalDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
- ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
- AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+ DenseI64ArrayAttr integers) {
- return printDynamicIndexList(printer, op, values, integers,
- /*scalableFlags=*/{}, valueTypes, delimiter);
-}
+ if (!integers)
+ return;
+ return printDynamicIndexList(printer, op, values, integers,
+ /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
+ }
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
+
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
+
+ return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+
+}
+
LogicalResult PrefetchNdOp::verify() {
auto tdescTy = getTensorDescType();
if (tdescTy.isScattered())
@@ -361,6 +369,13 @@ LogicalResult PrefetchNdOp::verify() {
//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
+
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, Value tensorDesc, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
+
+ return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, l3_hint);
+
+}
+
LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
@@ -448,6 +463,13 @@ LogicalResult LoadNdOp::verify() {
//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
+
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
+
+ return build(builder, state, value, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+
+}
+
LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 695437354cd7c..a1028a8e8a2f3 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -121,6 +121,15 @@ gpu.func @prefetch_nd_2(%src: memref<8x24x32x48x64xf16>) {
gpu.return
}
+// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
+gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+ // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
+ xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
+ gpu.return
+}
+
// CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -260,6 +269,15 @@ gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: func @subgroup_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+ %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+ gpu.return
+}
+
// CHECK: func @simt_load_nd_8(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
@@ -269,6 +287,16 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
gpu.return
}
+
+// CHECK: func @simt_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @simt_load_nd_offset_1(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+ %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+ gpu.return
+}
+
// CHECK: func @subgroup_store_nd(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @subgroup_store_nd(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -293,6 +321,17 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
// CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
+ %1 = arith.constant dense<1.0>: vector<32xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
+ xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
+ gpu.return
+}
+
+// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
%1 = arith.constant dense<1.0>: vector<32xf16>
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
@@ -313,6 +352,17 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
gpu.return
}
+// CHECK: func @simt_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @simt_store_nd_offset_1(%src: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+ %1 = arith.constant dense<1.0>: vector<2xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
+ xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>
+ gpu.return
+}
+
// CHECK: gpu.func @update_nd_tdesc(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
>From 30ff640a8d2c59d31effbb9828f1775032564c57 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 17 Jul 2025 23:22:18 +0000
Subject: [PATCH 02/16] git-clang-format
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 47 ++++++++++++++++----------
1 file changed, 30 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ca3c92cf4b52c..7cb105bf4292d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -329,24 +329,28 @@ ParseResult parseOptionalDynamicIndexList(
return success();
}
-void printOptionalDynamicIndexList(
- OpAsmPrinter &printer, Operation *op, OperandRange values,
- DenseI64ArrayAttr integers) {
+void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+ OperandRange values,
+ DenseI64ArrayAttr integers) {
- if (!integers)
- return;
+ if (!integers)
+ return;
- return printDynamicIndexList(printer, op, values, integers,
- /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
- }
+ return printDynamicIndexList(printer, op, values, integers,
+ /*scalableFlags=*/{}, {},
+ AsmParser::Delimiter::Square);
+}
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
-void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
-
- return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+ Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
+ l1_hint, l2_hint, l3_hint);
}
LogicalResult PrefetchNdOp::verify() {
@@ -370,10 +374,16 @@ LogicalResult PrefetchNdOp::verify() {
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
-void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, Value tensorDesc, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
-
- return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, l3_hint);
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+ Value tensorDesc, UnitAttr packed,
+ DenseI64ArrayAttr transpose,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ return build(builder, state, retType, tensorDesc, ValueRange(),
+ DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
+ l3_hint);
}
LogicalResult LoadNdOp::verify() {
@@ -464,10 +474,13 @@ LogicalResult LoadNdOp::verify() {
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
-void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
-
- return build(builder, state, value, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ return build(builder, state, value, tensorDesc, ValueRange(),
+ DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}
LogicalResult StoreNdOp::verify() {
>From efd1661b4ea93a08776213504219b96871449507 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 18 Jul 2025 01:48:18 +0000
Subject: [PATCH 03/16] add optional offsets to load_gather
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 15 ++++++++++-----
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 1 +
mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 2 +-
3 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e9f8437d7c102..31c2fb357371a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -655,7 +655,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
}
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
- AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
+ AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
]> {
let summary = "load a set of scattered data points from memory.";
@@ -698,7 +698,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
XeGPU_MaskType: $mask,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -706,8 +706,13 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
let results = (outs XeGPU_ValueType: $value);
let extraClassDeclaration = extraBaseClassDeclaration # [{
+
+ Type getSourceType() {
+ return getSource().getType();
+ }
+
xegpu::TensorDescType getTensorDescType() {
- return getTensorDesc().getType();
+ return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
mlir::Type getElementType() {
@@ -725,8 +730,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
- let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
- `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
+ let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
+ `:` qualified(type($source)) `,` type($mask) `->` type($value)}];
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 277158ac85409..ac41907655122 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -203,6 +203,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}
+def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index dc76441b27c02..44f2364d0caec 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -486,7 +486,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdescs = pack(
- op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+ op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
>From 3578c1b96fd22a1e013bc70f987ac4bfb6849c10 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 17 Jul 2025 23:21:14 +0000
Subject: [PATCH 04/16] add optional offsets to nd load/store/prefetch
---
mlir/test/Dialect/XeGPU/ops.mlir | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3ebb1b969ac74..3523e3083c168 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -130,6 +130,15 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
gpu.return
}
+// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
+gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+ // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
+ xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
+ gpu.return
+}
+
// CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -330,6 +339,17 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
gpu.return
}
+// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
+ %1 = arith.constant dense<1.0>: vector<32xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
+ xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
+ gpu.return
+}
+
// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
>From 59f7ea9bf601205496f5867ddf1445e1f51641fd Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 18 Jul 2025 01:48:18 +0000
Subject: [PATCH 05/16] add optional offsets to load_gather
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 15 ++++++++++-----
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 1 +
mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 2 +-
3 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 91d6b2a5ead9b..a5a7dab1bf55a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -655,7 +655,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
}
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
- AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
+ AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
]> {
let summary = "load a set of scattered data points from memory.";
@@ -698,7 +698,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
XeGPU_MaskType: $mask,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -706,8 +706,13 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
let results = (outs XeGPU_ValueType: $value);
let extraClassDeclaration = extraBaseClassDeclaration # [{
+
+ Type getSourceType() {
+ return getSource().getType();
+ }
+
xegpu::TensorDescType getTensorDescType() {
- return getTensorDesc().getType();
+ return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
mlir::Type getElementType() {
@@ -725,8 +730,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
- let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
- `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
+ let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
+ `:` qualified(type($source)) `,` type($mask) `->` type($value)}];
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 20916ae9ef830..334f749ace745 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}
+def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a6208b455aa35..c8f332184bd1b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdescs = pack(
- op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+ op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
>From abc84c759f0993d8aa699ba1b87fdad7c5760a69 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 21 Jul 2025 04:18:01 +0000
Subject: [PATCH 06/16] add offsets to load
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 18 ++++++++++++++++--
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +-
2 files changed, 17 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index a5a7dab1bf55a..51356e963e778 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -699,6 +699,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+ Variadic<Index>: $offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
XeGPU_MaskType: $mask,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -730,8 +732,20 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
- let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
- `:` qualified(type($source)) `,` type($mask) `->` type($value)}];
+ let assemblyFormat = [{
+ $source `,`
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $mask prop-dict
+ attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
+ }];
+
+// let builders = [
+// OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+// "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+// "xegpu::CachePolicyAttr": $l1_hint,
+// "xegpu::CachePolicyAttr": $l2_hint,
+// "xegpu::CachePolicyAttr": $l3_hint)>
+// ];
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 334f749ace745..8e575e31255a7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,7 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}
-def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
+def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
>From 80b4462f48c164309f70b1a6dbeb5805d869c998 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 24 Jul 2025 02:58:15 +0000
Subject: [PATCH 07/16] add chunk_size and use XeGPU_OffsetType
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 84 +++++++++++++++----
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 27 +++++-
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 2 +-
mlir/test/Dialect/XeGPU/ops.mlir | 33 +++-----
4 files changed, 105 insertions(+), 41 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 51356e963e778..bf036d86d14bb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -16,6 +16,7 @@ include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
// Base class for dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
@@ -638,18 +639,39 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
}];
- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+ Optional<XeGPU_OffsetType>: $offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let extraClassDeclaration = extraBaseClassDeclaration # [{
+ Type getSourceType() {
+ return getSource().getType();
+ }
+
+ Value getTensorDesc() {
+ return getSource();
+ }
+
xegpu::TensorDescType getTensorDescType() {
- return getTensorDesc().getType();
+ return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
}];
- let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+ let assemblyFormat = [{
+ $source
+ (`,` $offsets^)?
+ prop-dict
+ attr-dict `:` type($source) (`,` type($offsets)^)?
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value": $source,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
+ ];
let hasVerifier = 1;
}
@@ -702,6 +724,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
Variadic<Index>: $offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
XeGPU_MaskType: $mask,
+ OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -713,6 +736,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
return getSource().getType();
}
+ Value getTensorDesc() {
+ return getSource();
+ }
+
xegpu::TensorDescType getTensorDescType() {
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
@@ -733,25 +760,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
let assemblyFormat = [{
- $source `,`
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $source ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
$mask prop-dict
attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
}];
-// let builders = [
-// OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
-// "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
-// "xegpu::CachePolicyAttr": $l1_hint,
-// "xegpu::CachePolicyAttr": $l2_hint,
-// "xegpu::CachePolicyAttr": $l3_hint)>
-// ];
+ let builders = [
+ OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
+ ];
let hasVerifier = 1;
}
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
- AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
+ AllElementTypesMatch<["value", "dest"]>, MemoryEffects<[MemWrite]>
]> {
let summary = "store data to scattered memory locations.";
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
@@ -791,15 +817,26 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
let arguments = (ins
XeGPU_ValueType: $value,
- XeGPU_TensorDesc: $TensorDesc,
+ XeGPU_TensorDesc_or_MemRef: $dest,
+ Variadic<Index>: $offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
XeGPU_MaskType: $mask,
+ OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let extraClassDeclaration = extraBaseClassDeclaration # [{
+ Type getDestType() {
+ return getDest().getType();
+ }
+
+ Value getTensorDesc() {
+ return getDest();
+ }
+
xegpu::TensorDescType getTensorDescType() {
- return getTensorDesc().getType();
+ return dyn_cast<xegpu::TensorDescType>(getDestType());
}
VectorType getValueType() {
@@ -811,8 +848,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
}
}];
- let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
- `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
+ let assemblyFormat = [{
+ $value `,`
+ $dest ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
+ $mask
+ prop-dict
+ attr-dict `:` type($value) `,` qualified(type($dest)) `,` type($mask)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
+ ];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 704deeaa1f26b..4f3b3ed475afc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -644,7 +644,7 @@ LogicalResult CreateDescOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
- if (!tdescTy.isScattered())
+ if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
@@ -659,6 +659,13 @@ LogicalResult PrefetchOp::verify() {
return success();
}
+void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) { // Builder without offsets: forwards a null offsets operand to the generated builder.
+ build(builder, state, source, /*offsets=*/Value(), l1_hint, l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
@@ -680,6 +687,15 @@ LogicalResult LoadGatherOp::verify() {
[&]() { return emitOpError(); });
}
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source, Value mask,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) { // Builder without offsets/chunk_size: all of them are passed as empty/null.
+ build(builder, state, valueType, source, /*offsets=*/ValueRange(), /*const_offsets=*/DenseI64ArrayAttr(),
+ mask, /*chunk_size=*/IntegerAttr(), l1_hint, l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
@@ -701,6 +717,15 @@ LogicalResult StoreScatterOp::verify() {
[&]() { return emitOpError(); });
}
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value dest, Value mask,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) { // Builder without offsets/chunk_size: all of them are passed as empty/null.
+ build(builder, state, value, dest, /*offsets=*/ValueRange(), /*const_offsets=*/DenseI64ArrayAttr(), mask,
+ /*chunk_size=*/IntegerAttr(), l1_hint, l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c8f332184bd1b..a6208b455aa35 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdescs = pack(
- op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3523e3083c168..98836adaa57a3 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -130,15 +130,6 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
gpu.return
}
-// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
-gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
- %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
- // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
- xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
- gpu.return
-}
-
// CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -339,19 +330,8 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
gpu.return
}
-// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
- // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
- %1 = arith.constant dense<1.0>: vector<32xf16>
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
- %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
- // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
- xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
- gpu.return
-}
-
-// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
+// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_3(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
%1 = arith.constant dense<1.0>: vector<32xf16>
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
@@ -658,6 +638,15 @@ gpu.func @prefetch(%src: ui64) {
}
+// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
+gpu.func @prefetch_offset(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ // CHECK: xegpu.prefetch %[[arg0]], %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+ xegpu.prefetch %src, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+ gpu.return
+}
+
// CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
gpu.func @create_update_tdesc(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
>From 769bf19a3775cc8ceacbe9077371c4c712c9f493 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 25 Jul 2025 02:33:06 +0000
Subject: [PATCH 08/16] add tests
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 31 +++---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 97 +++++++++++++++++--
mlir/test/Dialect/XeGPU/invalid.mlir | 22 +++++
mlir/test/Dialect/XeGPU/ops.mlir | 40 ++++----
4 files changed, 147 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index bf036d86d14bb..c6b192a9dda31 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -676,9 +676,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
let hasVerifier = 1;
}
-def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
- AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
- ]> {
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let summary = "load a set of scattered data points from memory.";
let description = [{ It (aka. load) load data per each work-item. The output
@@ -721,8 +719,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
- Variadic<Index>: $offsets,
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+ Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -760,11 +757,15 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
let assemblyFormat = [{
- $source ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
+ $source
+ (`[` $offsets^ `]`)? `,`
$mask prop-dict
- attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
+ attr-dict `:` type(operands) `->` type($value)
}];
+
+ // functional-type(operands, results)
+ // type($source) (type($offsets)^ )? `,` type($mask) `->` type($value)
+
let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
@@ -776,9 +777,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
let hasVerifier = 1;
}
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
- AllElementTypesMatch<["value", "dest"]>, MemoryEffects<[MemWrite]>
- ]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store data to scattered memory locations.";
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -818,8 +817,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
let arguments = (ins
XeGPU_ValueType: $value,
XeGPU_TensorDesc_or_MemRef: $dest,
- Variadic<Index>: $offsets,
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+ Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -850,12 +848,13 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
let assemblyFormat = [{
$value `,`
- $dest ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
+ $dest
+ (`[` $offsets^ `]`)? `,`
$mask
prop-dict
- attr-dict `:` type($value) `,` qualified(type($dest)) `,` type($mask)
+ attr-dict `:` type(operands)
}];
+// type($value) `,` qualified(type($dest)) (type($offsets)^)? `,` type($mask)
let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 4f3b3ed475afc..7a32f1a45c762 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -110,6 +110,66 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}
+static LogicalResult
+isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
+ MemRefType memTy, int64_t chunkSize,
+ function_ref<InFlightDiagnostic()> emitError) {
+ // Verifies a gather/scatter access through a memref: value must be a vector, element types must match, and the mask shape must fit the value.
+ if (!valueTy)
+ return emitError() << "Expecting a vector type result.";
+ // NOTE(review): memShape below is computed but never used; the memref shape itself is not checked.
+ auto maskShape = getShapeOf(maskTy);
+ auto valueShape = getShapeOf(valueTy);
+ auto memShape = getShapeOf(memTy);
+ // The value must carry the memref's element type.
+ if (valueTy.getElementType() != memTy.getElementType())
+ return emitError() << "Value should have the same element type as MemRef.";
+
+ // SIMT form: a rank-1 value must hold exactly chunkSize elements; the mask shape is not checked in this case.
+ if (valueTy.getRank() == 1) {
+ if (valueTy.getNumElements() != chunkSize)
+ return emitError() << "value elements must match chunk size " << chunkSize
+ << " for SIMT code.";
+ return success();
+ }
+ // Otherwise the mask shape must equal the value shape, minus the trailing (chunk) dim when chunkSize > 1.
+ llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+ if (chunkSize > 1)
+ expectedMaskShape.pop_back();
+ if (expectedMaskShape != maskShape)
+ return emitError() << "Mask should match value except the chunk size dim.";
+
+ return success();
+}
+
+static LogicalResult
+isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy,
+ int64_t chunkSize,
+ function_ref<InFlightDiagnostic()> emitError) {
+ // Verifies a gather/scatter access through a raw address (the ui64 case): there is no element type to compare, so only shapes are checked.
+ if (!valueTy)
+ return emitError() << "Expecting a vector type result.";
+ // NOTE(review): the shape checks below duplicate isValidGatherScatterMemRefParams; consider sharing one helper.
+ auto maskShape = getShapeOf(maskTy);
+ auto valueShape = getShapeOf(valueTy);
+
+ // SIMT form: a rank-1 value must hold exactly chunkSize elements; the mask shape is not checked in this case.
+ if (valueTy.getRank() == 1) {
+ if (valueTy.getNumElements() != chunkSize)
+ return emitError() << "value elements must match chunk size " << chunkSize
+ << " for SIMT code.";
+ return success();
+ }
+ // Otherwise the mask shape must equal the value shape, minus the trailing (chunk) dim when chunkSize > 1.
+ llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+ if (chunkSize > 1)
+ expectedMaskShape.pop_back();
+ if (expectedMaskShape != maskShape)
+ return emitError() << "Mask should match value except the chunk size dim.";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -683,8 +743,18 @@ LogicalResult LoadGatherOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- [&]() { return emitOpError(); });
+ if (tdescTy)
+ return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+ [&]() { return emitOpError(); });
+ auto srcTy = getSourceType();
+ uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+ auto memTy = dyn_cast<MemRefType>(srcTy);
+
+ if (memTy)
+ return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
+ [&]() { return emitOpError(); });
+ return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+ [&]() { return emitOpError(); });
}
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -692,8 +762,8 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
- build(builder, state, valueType, source, ValueRange(), DenseI64ArrayAttr(),
- mask, IntegerAttr(), l1_hint, l2_hint, l3_hint);
+ build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
+ l1_hint, l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
@@ -713,8 +783,19 @@ LogicalResult StoreScatterOp::verify() {
if (!isWriteHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- [&]() { return emitOpError(); });
+ if (tdescTy)
+ return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+ [&]() { return emitOpError(); });
+
+ auto destTy = getDestType();
+ uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+ auto memTy = dyn_cast<MemRefType>(destTy);
+
+ if (memTy)
+ return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
+ [&]() { return emitOpError(); });
+ return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+ [&]() { return emitOpError(); });
}
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -722,8 +803,8 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
- build(builder, state, value, dest, ValueRange(), DenseI64ArrayAttr(), mask,
- IntegerAttr(), l1_hint, l2_hint, l3_hint);
+ build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
+ l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 0160bfee07bf2..af34add37f7ad 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -384,6 +384,28 @@ func.func @load_gather_vc_3(%src: ui64) {
return
}
+// -----
+func.func @load_offset(%src: ui64) {
+ %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<8xi1>
+ // expected-error at +1 {{Mask should match value except the chunk size dim}}
+ %2 = xegpu.load %src[%offsets], %mask
+ : ui64, vector<4xindex>, vector<8xi1>
+ -> vector<4x2xf32>
+ return
+}
+
+// -----
+func.func @store_offset(%src: ui64) {
+ %val = arith.constant dense<2.9>: vector<4x2xf16>
+ %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<8xi1>
+ // expected-error at +1 {{Mask should match value except the chunk size dim}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<4x2xf16>, ui64, vector<4xindex>, vector<8xi1>
+ return
+}
+
// -----
func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
%0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index f8c558d614ee6..16f5356a69f24 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -130,15 +130,6 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
gpu.return
}
-// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
-gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
- %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
- // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
- xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
- gpu.return
-}
-
// CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -339,16 +330,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
gpu.return
}
-// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_3(%dst: memref<24x32xf16>) {
- // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
- %1 = arith.constant dense<1.0>: vector<32xf16>
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
- %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
- // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
- xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
- gpu.return
-}
// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
@@ -541,6 +522,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
+ %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+ %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+ : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_store(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -646,6 +637,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
+ %val = arith.constant dense<2.9>: vector<4x2xf16>
+ %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<4xi1>
+ //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+ xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+ : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+ gpu.return
+}
+
// CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
gpu.func @prefetch(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
>From 45537469eddccf41eabf689b19f550d8616442de Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 25 Jul 2025 05:15:34 +0000
Subject: [PATCH 09/16] add invalid tests
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 7 ++--
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 6 +--
mlir/test/Dialect/XeGPU/invalid.mlir | 39 ++++++++++++++-----
mlir/test/Dialect/XeGPU/ops.mlir | 4 +-
4 files changed, 39 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c6b192a9dda31..312db1402f58f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -661,11 +661,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
let assemblyFormat = [{
$source
- (`,` $offsets^)?
+ (`[` $offsets^ `]`)?
prop-dict
- attr-dict `:` type($source) (`,` type($offsets)^)?
+ attr-dict `:` type(operands)
}];
-
+ // type($source) (type($offsets)^)?
+
let builders = [
OpBuilder<(ins "Value": $source,
"xegpu::CachePolicyAttr": $l1_hint,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c8f332184bd1b..cafbf8d5ffc5e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -484,7 +484,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+ if (!tdescTy || !tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -546,7 +546,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+ if (!tdescTy || !tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -575,7 +575,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+ if (!tdescTy || !tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index af34add37f7ad..b56de88391803 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -385,27 +385,48 @@ func.func @load_gather_vc_3(%src: ui64) {
}
// -----
-func.func @load_offset(%src: ui64) {
+func.func @load_gather_offset_sg(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%mask = arith.constant dense<1>: vector<8xi1>
// expected-error at +1 {{Mask should match value except the chunk size dim}}
%2 = xegpu.load %src[%offsets], %mask
- : ui64, vector<4xindex>, vector<8xi1>
- -> vector<4x2xf32>
+ : memref<?xf16>, vector<4xindex>, vector<8xi1>
+ -> vector<4x2xf16>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi(%src: ui64) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error at +1 {{value elements must match chunk size}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
+ return
+}
+
+// -----
+func.func @store_scatter_offset(%src: memref<?xf16>) {
+ %val = arith.constant dense<2.9>: vector<4xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error at +1 {{value elements must match chunk size}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
return
}
+
// -----
-func.func @store_offset(%src: ui64) {
+func.func @load_gather_offset_wi(%src: ui64) {
%val = arith.constant dense<2.9>: vector<4x2xf16>
- %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- %mask = arith.constant dense<1>: vector<8xi1>
- // expected-error at +1 {{Mask should match value except the chunk size dim}}
- xegpu.store %val, %src[%offsets], %mask
- : vector<4x2xf16>, ui64, vector<4xindex>, vector<8xi1>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error at +1 {{value elements must match chunk size}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
return
}
+
// -----
func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
%0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 16f5356a69f24..ea80601ef5574 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -664,8 +664,8 @@ gpu.func @prefetch(%src: ui64) {
gpu.func @prefetch_offset(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- // CHECK: xegpu.prefetch %[[arg0]], %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
- xegpu.prefetch %src, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+ // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+ xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
gpu.return
}
>From 1249794952a1e67a4f0ff54b7e4fa39d0704b42e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 25 Jul 2025 05:28:50 +0000
Subject: [PATCH 10/16] small fixes
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 6 ------
mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 2 +-
mlir/test/Dialect/XeGPU/invalid.mlir | 2 +-
3 files changed, 2 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 312db1402f58f..82edd69f63694 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -665,7 +665,6 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
prop-dict
attr-dict `:` type(operands)
}];
- // type($source) (type($offsets)^)?
let builders = [
OpBuilder<(ins "Value": $source,
@@ -763,10 +762,6 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
$mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];
-
- // functional-type(operands, results)
- // type($source) (type($offsets)^ )? `,` type($mask) `->` type($value)
-
let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
@@ -855,7 +850,6 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
prop-dict
attr-dict `:` type(operands)
}];
-// type($value) `,` qualified(type($dest)) (type($offsets)^)? `,` type($mask)
let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index cafbf8d5ffc5e..29ee864b8f34f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdescs = pack(
- op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index b56de88391803..b8e6a31d8d2f7 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -405,7 +405,7 @@ func.func @load_gather_offset_wi(%src: ui64) {
}
// -----
-func.func @store_scatter_offset(%src: memref<?xf16>) {
+func.func @store_scatter_offset_sg(%src: memref<?xf16>) {
%val = arith.constant dense<2.9>: vector<4xf16>
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
>From 5cfb24b395ef3e83f4e316c9e35b9d2f8cb72dd1 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 25 Jul 2025 21:39:07 +0000
Subject: [PATCH 11/16] address comments
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 7 +++----
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 1 -
3 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 82edd69f63694..9da015b65a6af 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -16,7 +16,6 @@ include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
-include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
// Base class for dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
@@ -639,7 +638,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
}];
- let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+ let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
Optional<XeGPU_OffsetType>: $offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -718,7 +717,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
}];
- let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+ let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
@@ -812,7 +811,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let arguments = (ins
XeGPU_ValueType: $value,
- XeGPU_TensorDesc_or_MemRef: $dest,
+ XeGPU_TensorDescOrMemRef: $dest,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 8e575e31255a7..fa59bf2e40c4d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,7 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}
-def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_TensorDescOrMemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7a32f1a45c762..3a41b298e2aae 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -120,7 +120,6 @@ isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
- auto memShape = getShapeOf(memTy);
if (valueTy.getElementType() != memTy.getElementType())
return emitError() << "Value should have the same element type as MemRef.";
>From 5940d191e865d06063a8ed97b804b90858bc78ba Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 29 Jul 2025 17:33:18 +0000
Subject: [PATCH 12/16] address feedback
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 46 ++++++++++-
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 76 ++++++++-----------
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 3 +
mlir/test/Dialect/XeGPU/invalid.mlir | 2 +-
mlir/test/Dialect/XeGPU/ops.mlir | 2 -
6 files changed, 80 insertions(+), 51 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 9da015b65a6af..c864ce0c3d9cd 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -628,17 +628,28 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
As compared to prefetch_nd, which works on non-scattered TensorDesc,
it works on scattered TensorDesc instead.
- Example:
+ Example 1:
```mlir
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```
+
+ Example 2:
+ A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+ ```mlir
+ %a = memref.alloc() : memref<1024xf32>
+ %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+ xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>}
+ : memref<1024xf32>, vector<4xindex>
+ ```
}];
- let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
+ let arguments = (ins XeGPU_GatherScatterSourceType: $source,
Optional<XeGPU_OffsetType>: $offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -706,6 +717,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```
+
Example 3 (SIMT mode):
```mlir
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -714,10 +726,22 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
vector<16xi1> -> vector<8xf32>
```
+
+ Example 4:
+ A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+ ```mlir
+ %a = memref.alloc() : memref<1024xf32>
+ %offsets = vector.step : vector<16xindex>
+ %mask = vector.constant_mask [16]: vector<16xi1>
+ %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>}
+ : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+ ```
}];
- let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
+ let arguments = (ins XeGPU_GatherScatterSourceType: $source,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
@@ -807,11 +831,25 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
l3_hint = #xegpu.cache_hint<write_through>}>
: vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
```
+
+ Example 4:
+ A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc". The dest operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+ ```mlir
+ %a = memref.alloc() : memref<1024xf32>
+ %val = arith.constant dense<0.0> : vector<16xf32>
+ %offsets = vector.step : vector<16xindex>
+ %mask = vector.constant_mask [16]: vector<16xi1>
+ xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>}
+ : vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+ ```
+
}];
let arguments = (ins
XeGPU_ValueType: $value,
- XeGPU_TensorDescOrMemRef: $dest,
+ XeGPU_GatherScatterSourceType: $dest,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index fa59bf2e40c4d..b268cabb5d266 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,7 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}
-def XeGPU_TensorDescOrMemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 3a41b298e2aae..7c8ee7408dfd1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -111,39 +111,7 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
}
static LogicalResult
-isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
- MemRefType memTy, int64_t chunkSize,
- function_ref<InFlightDiagnostic()> emitError) {
-
- if (!valueTy)
- return emitError() << "Expecting a vector type result.";
-
- auto maskShape = getShapeOf(maskTy);
- auto valueShape = getShapeOf(valueTy);
-
- if (valueTy.getElementType() != memTy.getElementType())
- return emitError() << "Value should have the same element type as MemRef.";
-
- // a valid shape for SIMT case
- if (valueTy.getRank() == 1) {
- if (valueTy.getNumElements() != chunkSize)
- return emitError() << "value elements must match chunk size " << chunkSize
- << " for SIMT code.";
- return success();
- }
-
- llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
- if (chunkSize > 1)
- expectedMaskShape.pop_back();
- if (expectedMaskShape != maskShape)
- return emitError() << "Mask should match value except the chunk size dim.";
-
- return success();
-}
-
-static LogicalResult
-isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy,
- int64_t chunkSize,
+isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {
if (!valueTy)
@@ -703,8 +671,14 @@ LogicalResult CreateDescOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
- if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (tdescTy) {
+ if (!tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+ } else {
+ if (getRankOf(getSource()) > 1)
+ return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+ }
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -733,6 +707,14 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (tdescTy) {
+ if (!tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+ } else {
+ if (getRankOf(getSource()) > 1)
+ return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+ }
+
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -749,10 +731,10 @@ LogicalResult LoadGatherOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);
- if (memTy)
- return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
- [&]() { return emitOpError(); });
- return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+ return emitError() << "Value should have the same element type as MemRef.";
+
+ return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
@@ -773,6 +755,14 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (tdescTy) {
+ if (!tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+ } else {
+ if (getRankOf(getDest()) > 1)
+ return emitOpError("Expecting the dest is a 1D memref or pointer (uint64_t).");
+ }
+
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -790,10 +780,10 @@ LogicalResult StoreScatterOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);
- if (memTy)
- return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
- [&]() { return emitOpError(); });
- return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+ return emitError() << "Value should have the same element type as MemRef.";
+
+ return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29ee864b8f34f..d52f7f2ac274a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -484,6 +484,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ // TODO: handle the unstructure source case (!tdesTy)
if (!tdescTy || !tdescTy.isScattered())
return failure();
@@ -546,6 +547,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ // TODO: handle the unstructure source case (!tdesTy)
if (!tdescTy || !tdescTy.isScattered())
return failure();
@@ -575,6 +577,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ // TODO: handle the unstructure source case (!tdesTy)
if (!tdescTy || !tdescTy.isScattered())
return failure();
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index b8e6a31d8d2f7..4cece4640634e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -405,7 +405,7 @@ func.func @load_gather_offset_wi(%src: ui64) {
}
// -----
-func.func @store_scatter_offset_sg(%src: memref<?xf16>) {
+func.func @store_scatter_offset_wi(%src: memref<?xf16>) {
%val = arith.constant dense<2.9>: vector<4xf16>
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index ea80601ef5574..6be2371d4d7b2 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -330,7 +330,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
gpu.return
}
-
// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
@@ -659,7 +658,6 @@ gpu.func @prefetch(%src: ui64) {
gpu.return
}
-
// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
gpu.func @prefetch_offset(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
>From da7142aa4eb83340885e459edd71a4c865ba2aca Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 29 Jul 2025 17:36:14 +0000
Subject: [PATCH 13/16] git-clang-format
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7c8ee7408dfd1..45a4363bd11ba 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -111,7 +111,8 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
}
static LogicalResult
-isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, int64_t chunkSize,
+isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
+ int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {
if (!valueTy)
@@ -677,7 +678,8 @@ LogicalResult PrefetchOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getSource()) > 1)
- return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
}
if (!isReadHintOrNone(getL1HintAttr()))
@@ -712,7 +714,8 @@ LogicalResult LoadGatherOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getSource()) > 1)
- return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
}
if (!isReadHintOrNone(getL1HintAttr()))
@@ -731,7 +734,7 @@ LogicalResult LoadGatherOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
@@ -760,9 +763,10 @@ LogicalResult StoreScatterOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getDest()) > 1)
- return emitOpError("Expecting the dest is a 1D memref or pointer (uint64_t).");
+ return emitOpError(
+ "Expecting the dest is a 1D memref or pointer (uint64_t).");
}
-
+
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -780,7 +784,7 @@ LogicalResult StoreScatterOp::verify() {
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);
- if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
>From 8b99ecc629be3c1dfb96c93ead49594dc30a47ef Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 29 Jul 2025 22:43:25 +0000
Subject: [PATCH 14/16] minor polish
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 3 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 42 ++++++++-----------
2 files changed, 20 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c864ce0c3d9cd..3e075de1651ab 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -863,7 +863,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
}
Value getTensorDesc() {
- return getDest();
+ assert(getTensorDescType() && "Expected dest to be a TensorDescType");
+ return getDest();
}
xegpu::TensorDescType getTensorDescType() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 45a4363bd11ba..1b114d41b6ca5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -673,14 +673,12 @@ LogicalResult CreateDescOp::verify() {
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
- if (tdescTy) {
- if (!tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
- } else {
- if (getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
- }
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (!tdescTy && getRankOf(getSource()) > 1)
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -709,14 +707,12 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
- if (tdescTy) {
- if (!tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
- } else {
- if (getRankOf(getSource()) > 1)
- return emitOpError(
- "Expecting the source is a 1D memref or pointer (uint64_t).");
- }
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (!tdescTy && getRankOf(getSource()) > 1)
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -758,14 +754,12 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
- if (tdescTy) {
- if (!tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
- } else {
- if (getRankOf(getDest()) > 1)
- return emitOpError(
- "Expecting the dest is a 1D memref or pointer (uint64_t).");
- }
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (!tdescTy && getRankOf(getDest()) > 1)
+ return emitOpError(
+ "Expecting the dest is a 1D memref or pointer (uint64_t).");
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
>From 04306ca04fa6531f2cefe607555d64ba55fe83ef Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 30 Jul 2025 17:23:15 +0000
Subject: [PATCH 15/16] address comments
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 37 ++++++++++++++-----
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 2 +-
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 6 +--
3 files changed, 31 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 3e075de1651ab..75b16a87e03c6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -637,7 +637,10 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
```
Example 2:
- A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+ A variant accepts a memref as the base pointer plus offsets, instead of a scattered TensorDesc.
+ It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
+ The source operand can also be a raw pointer (uint64_t).
+ Please refer to create_tdesc for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -660,8 +663,11 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
return getSource().getType();
}
- Value getTensorDesc() {
- return getSource();
+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
+ if (auto tdescType = getTensorDescType()) {
+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+ }
+ return TypedValue<xegpu::TensorDescType>();
}
xegpu::TensorDescType getTensorDescType() {
@@ -728,7 +734,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
```
Example 4:
- A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc". The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+ A variant accepts a memref as the base pointer plus offsets, instead of a scattered TensorDesc.
+ It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
+ The source operand can also be a raw pointer (uint64_t). Please refer to create_tdesc
+ for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
@@ -756,8 +765,11 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
return getSource().getType();
}
- Value getTensorDesc() {
- return getSource();
+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
+ if (auto tdescType = getTensorDescType()) {
+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+ }
+ return TypedValue<xegpu::TensorDescType>();
}
xegpu::TensorDescType getTensorDescType() {
@@ -833,7 +845,10 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
```
Example 4:
- A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc". The dest operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restriction of memref.
+ A variant accepts a memref as the base pointer plus offsets, instead of a scattered TensorDesc.
+ It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
+ The dest operand can also be a raw pointer (uint64_t).
+ Please refer to create_tdesc for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
@@ -862,9 +877,11 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
return getDest().getType();
}
- Value getTensorDesc() {
- assert(getTensorDescType() && "Expected dest to be a TensorDescType");
- return getDest();
+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
+ if (auto tdescType = getTensorDescType()) {
+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
+ }
+ return TypedValue<xegpu::TensorDescType>();
}
xegpu::TensorDescType getTensorDescType() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 1b114d41b6ca5..33450f3fa229e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -708,7 +708,7 @@ LogicalResult LoadGatherOp::verify() {
auto valueTy = getValueType();
if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+ return emitOpError("Expects a scattered TensorDesc.");
if (!tdescTy && getRankOf(getSource()) > 1)
return emitOpError(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index d52f7f2ac274a..9f0c074a1489d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -485,7 +485,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// TODO: handle the unstructure source case (!tdesTy)
- if (!tdescTy || !tdescTy.isScattered())
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -548,7 +548,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// TODO: handle the unstructure source case (!tdesTy)
- if (!tdescTy || !tdescTy.isScattered())
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -578,7 +578,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// TODO: handle the unstructure source case (!tdesTy)
- if (!tdescTy || !tdescTy.isScattered())
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
>From bbd6530ee6c0789af571578530dabe3e0cce7915 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 30 Jul 2025 17:53:36 +0000
Subject: [PATCH 16/16] add more invalid tests
---
mlir/test/Dialect/XeGPU/invalid.mlir | 33 ++++++++++++++++++++++++----
1 file changed, 29 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 4cece4640634e..dff3ffab39ecf 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -384,6 +384,14 @@ func.func @load_gather_vc_3(%src: ui64) {
return
}
+// -----
+func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error at +1 {{Expecting the source is a 1D memref or pointer}}
+ xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
+ return
+}
+
// -----
func.func @load_gather_offset_sg(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -405,7 +413,7 @@ func.func @load_gather_offset_wi(%src: ui64) {
}
// -----
-func.func @store_scatter_offset_wi(%src: memref<?xf16>) {
+func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
%val = arith.constant dense<2.9>: vector<4xf16>
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
@@ -415,17 +423,34 @@ func.func @store_scatter_offset_wi(%src: memref<?xf16>) {
return
}
+// -----
+func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
+ %val = arith.constant dense<2.9>: vector<4xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error at +1 {{Expecting the dest is a 1D memref or pointer}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
+ return
+}
// -----
-func.func @load_gather_offset_wi(%src: ui64) {
- %val = arith.constant dense<2.9>: vector<4x2xf16>
+func.func @load_gather_offset_wi_2(%src: ui64) {
%mask = arith.constant dense<1>: vector<1xi1>
%offsets = arith.constant dense<[0]> : vector<1xindex>
// expected-error at +1 {{value elements must match chunk size}}
- %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16>
return
}
+// -----
+func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error at +1 {{Expecting the source is a 1D memref or pointer}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
+ return
+}
// -----
func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
More information about the Mlir-commits
mailing list