[Mlir-commits] [mlir] [mlir][xegpu] Relax rank restriction of TensorDescType (PR #145916)
Chao Chen
llvmlistbot at llvm.org
Tue Jul 1 08:21:11 PDT 2025
https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/145916
From d85e4ff4cb1a8687a64945a7e1c506dc3948771c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 01:25:06 +0000
Subject: [PATCH 01/14] remove 1D and 2D checks
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 13 +++++-----
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 17 +++++--------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 25 ++++++-------------
mlir/test/Dialect/XeGPU/invalid.mlir | 14 +++--------
4 files changed, 22 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84314875c2ae5..d5a61174fbfe3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -17,12 +17,12 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64,
def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
-def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
-def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
-def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
-def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
-def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_OffsetType: FixedVectorOfRankAndType<[1, 2], [Index]>;
+def XeGPU_MaskType: FixedVectorOfRankAndType<[1, 2], [I1]>;
+def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
+def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>;
// common base class for types in XeGPU dialect
class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
@@ -118,7 +118,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
];
let extraClassDeclaration = [{
- using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
using mlir::ShapedType::Trait<TensorDescType>::getRank;
using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 649e0d453015f..5f159dd223aba 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -310,13 +310,14 @@ LogicalResult TensorDescType::verify(
llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
mlir::Attribute encoding, mlir::Attribute layout) {
size_t rank = shape.size();
- if (rank != 1 && rank != 2)
- return emitError() << "expected 1D or 2D tensor";
+
+ if (rank == 0)
+ return emitError() << "expected non-zero rank tensor";
auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
if (blockAttr) {
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
- if (rank == 2 && memorySpaceAttr &&
+ if (rank > 1 && memorySpaceAttr &&
memorySpaceAttr.getValue() == MemorySpace::SLM)
return emitError() << "SLM is not supported for 2D block tensor";
}
@@ -329,16 +330,10 @@ LogicalResult TensorDescType::verify(
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
- // Expected tensor ranks for scattered data:
- // - 1D tensor for fully non-contiguous elements (chunk size == 1)
- // - 2D tensor for scattered blocks (chunk size > 1)
unsigned chunkSize = scatterAttr.getChunkSize().getInt();
- if (rank == 1 && chunkSize != 1)
- return emitError() << "expected non-contiguous elements for 1D tensor";
- if (rank == 2 && chunkSize < 2)
- return emitError() << "expected chunk blocks for 2D tensor";
// If chunk size > 1, the second dimension of the tensor shape must be
- // equal to chunk size and it must be a multiple of the packing factor.
+ // equal to chunk size and it must be a multiple of the
+ // chunkAlignmentFactor.
if (chunkSize > 1) {
if (shape.back() != chunkSize)
return emitError() << "expected tensor shape[1] to match chunk size";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2793c7a35bc97..d106f1c53c96b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -203,11 +203,9 @@ LogicalResult CreateNdDescOp::verify() {
"is a memref) should match with each other.");
// check result TensorDesc rank
- invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
-
- if (invalidRank)
+ if (getType().getRank() > rank)
return emitOpError(
- "Expecting the TensorDesc rank is up to 2 and not greater than the "
+ "Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides, offsets or the memref source.");
if (invalidElemTy)
@@ -247,9 +245,6 @@ LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
- if (tdescTy.getRank() > 2)
- return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
if (tdescTy.isScattered())
return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -316,15 +311,13 @@ LogicalResult LoadNdOp::verify() {
}
auto array_len = tdescTy.getArrayLength();
- if (array_len > 1) {
+ if (array_len > 1)
tdescShape.insert(tdescShape.begin(), array_len);
- }
- if (tdescShape != valueShape) {
+ if (tdescShape != valueShape)
return emitOpError() << "Result shape " << makeString(valueShape)
<< " is not consistent with tensor descriptor "
<< tdescTy;
- }
return success();
}
@@ -336,9 +329,6 @@ LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
- if (dstTy.getRank() > 2)
- return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
if (dstTy.isScattered())
return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -370,22 +360,21 @@ LogicalResult StoreNdOp::verify() {
return emitOpError()
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
- if (tdescElems % valueElems) {
+ if (tdescElems % valueElems)
return emitOpError()
<< "Value shape " << makeString(getShapeOf(valTy))
<< " is not a valid distribution for tensor descriptor " << dstTy;
- }
+
return success();
}
// SIMD code should have the same shape as the tensor descriptor.
auto tdescShape = getShapeOf(dstTy);
auto valueShape = getShapeOf(valTy);
- if (tdescShape != valueShape) {
+ if (tdescShape != valueShape)
return emitOpError() << "Value shape " << makeString(valueShape)
<< " is not consistent with tensor descriptor "
<< dstTy;
- }
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index a2778cd94d963..3969ae64b6dbf 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// -----
-func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
- // expected-error@+1 {{Expecting the TensorDesc rank is up to 2 and not greater than the ranks of shape, strides, offsets or the memref source}}
+func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) {
+ // expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}}
%1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
return
}
@@ -406,18 +406,10 @@ func.func @atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi
return
}
-// -----
-func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
- %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error@+1 {{expected 1D or 2D tensor}}
- !xegpu.tensor_desc<16x2x2xf32>
- return
-}
-
// -----
func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error@+1 {{expected 1D or 2D tensor}}
+ // expected-error@+1 {{expected non-zero rank tensor}}
!xegpu.tensor_desc<f32>
return
}
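To illustrate the relaxed rule from this patch: the verifier now only requires a non-zero rank, so higher-rank descriptors verify. A sketch (mirroring the create_nd_tdesc_7 test added later in this series; the module name is a placeholder):

  gpu.module @example {
    gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
      // 5D source and 5D tensor_desc: the rank only needs to be non-zero and
      // no greater than the rank of the source.
      %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf32> -> !xegpu.tensor_desc<8x8x8x24x32xf32>
      gpu.return
    }
  }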
From 265dc278fdf525e826bd01a8393320fcec25d423 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 15:26:29 +0000
Subject: [PATCH 02/14] cleanup
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 15 ++++++++---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 25 +++++++++++--------
.../XeGPU/Transforms/XeGPUBlocking.cpp | 6 ++---
mlir/test/Dialect/XeGPU/invalid.mlir | 21 +++++-----------
mlir/test/Dialect/XeGPU/ops.mlir | 4 +--
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 8 ++----
7 files changed, 39 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 84c1dc1373ee5..1b02565acfd8c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -43,8 +43,8 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
let parameters = (ins
OptionalParameter<"MemorySpaceAttr">: $memory_space,
- OptionalParameter<"IntegerAttr", "1">: $array_length,
- OptionalParameter<"BoolAttr", "true">: $boundary_check
+ OptionalParameter<"IntegerAttr">: $array_length,
+ OptionalParameter<"BoolAttr">: $boundary_check
);
let builders = [
@@ -77,9 +77,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
"MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
"Data memory location"
>: $memory_space,
- DefaultValuedParameter<
+ OptionalParameter<
"IntegerAttr",
- "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
"Number of contiguous elements"
>: $chunk_size
);
@@ -91,6 +90,14 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
)>
];
+ let extraClassDeclaration = [{
+ int64_t getChunkSizeOrDefault() {
+ if (auto attr = getChunkSize())
+ return attr.getInt();
+ return 1;
+ }
+ }];
+
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d5a61174fbfe3..2b99e80677098 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -185,7 +185,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
if (scatter_attr)
- return scatter_attr.getChunkSize().getInt();
+ return scatter_attr.getChunkSizeOrDefault();
return 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 5f159dd223aba..57458513805b0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -128,11 +128,14 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
LogicalResult ScatterTensorDescAttr::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
- int64_t chunkSize = chunk_size.getInt();
- SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
- 16, 32, 64, 128, 256};
- if (!llvm::is_contained(supportedChunkSizes, chunkSize))
- return emitError() << "invalid chunk size";
+
+ if (chunk_size) {
+ int64_t chunkSize = chunk_size.getInt();
+ SmallVector<int64_t> supportedChunkSizes = {2, 3, 4, 8, 16,
+ 32, 64, 128, 256};
+ if (!llvm::is_contained(supportedChunkSizes, chunkSize))
+ return emitError() << "invalid chunk size";
+ }
return success();
}
@@ -319,7 +322,7 @@ LogicalResult TensorDescType::verify(
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
if (rank > 1 && memorySpaceAttr &&
memorySpaceAttr.getValue() == MemorySpace::SLM)
- return emitError() << "SLM is not supported for 2D block tensor";
+ return emitError() << "SLM is only supported for 1D block tensor";
}
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
@@ -330,16 +333,18 @@ LogicalResult TensorDescType::verify(
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
- unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+ int64_t chunkSize = scatterAttr.getChunkSizeOrDefault();
+ if (rank == 1 && chunkSize != 1)
+ return emitError() << "expected non-contiguous elements for 1D tensor";
+
// If chunk size > 1, the second dimension of the tensor shape must be
// equal to chunk size and it must be a multiple of the
// chunkAlignmentFactor.
if (chunkSize > 1) {
if (shape.back() != chunkSize)
- return emitError() << "expected tensor shape[1] to match chunk size";
+ return emitError() << "expected last dim of tensor to match chunk size";
if (shape.back() % chunkAlignmentFactor != 0)
- return emitError() << "expected tensor shape[1] to be a multiple of "
- "chunk alignment factor "
+ return emitError() << "expected last dim of tensor to be a multiple of "
<< chunkAlignmentFactor;
}
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 3950e8f70d1ca..c6c4e3aaa41ed 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -303,9 +303,7 @@ void XeGPUBlockingPass::runOnOperation() {
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
- auto scatterAttr =
- llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
- int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+ int64_t chunkSize = tdescTy.getChunkSize();
if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
@@ -315,7 +313,7 @@ void XeGPUBlockingPass::runOnOperation() {
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
- ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
+ ctx, tdescTy.getMemorySpace(), blockedChunkSize);
encoding = newEncoding;
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 3969ae64b6dbf..8eb629660559c 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -17,7 +17,7 @@ func.func @create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
// -----
func.func @create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
- // expected-error@+1 {{SLM is not supported for 2D block tensor}}
+ // expected-error@+1 {{SLM is only supported for 1D block tensor}}
%1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
return
}
@@ -199,15 +199,6 @@ func.func @create_tdesc_vc_1(%src: ui64) {
return
}
-// -----
-func.func @create_tdesc_vc_2(%src: ui64) {
- %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
- %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
- // expected-error@+1 {{expected chunk blocks for 2D tensor}}
- -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
- return
-}
-
// -----
func.func @create_tdesc_vc_3(%src: memref<?xf32>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -230,7 +221,7 @@ func.func @create_tdesc_vc_4(%src: memref<?xf32>) {
func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
- // expected-error@+1 {{expected tensor shape[1] to match chunk size}}
+ // expected-error@+1 {{expected last dim of tensor to match chunk size}}
-> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>>
return
}
@@ -239,7 +230,7 @@ func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
func.func @create_tdesc_vc_6(%src: memref<?xf16>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : memref<?xf16>, vector<4xindex>
- // expected-error@+1 {{tensor shape[1] to be a multiple of chunk alignment factor 2}}
+ // expected-error@+1 {{last dim of tensor to be a multiple of 2}}
-> !xegpu.tensor_desc<4x3xf16, #xegpu.scatter_tdesc_attr<chunk_size = 3>>
return
}
@@ -478,7 +469,7 @@ func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<1
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
// expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
!xegpu.tensor_desc<16xf32,
- #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+ #xegpu.scatter_tdesc_attr<>,
#xegpu.layout<lane_layout = [8], lane_data = [2]>>
return
}
@@ -496,9 +487,9 @@ func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vecto
// -----
func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) {
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
- // expected-error@+1 {{expected chunk blocks for 2D tensor}}
+ // expected-error@+1 {{expected last dim of tensor to match chunk size}}
!xegpu.tensor_desc<16x2xf32,
- #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+ #xegpu.scatter_tdesc_attr<chunk_size = 4>,
#xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2]>>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index aff8f63adc05b..5c73810575cb0 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -291,8 +291,8 @@ gpu.func @create_tdesc_1(%src: memref<?xf32, 3>) {
gpu.func @create_tdesc_2(%src: memref<?xf32>) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>
- %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+ %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
gpu.return
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c84eb74198544..335f89f1826aa 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -107,10 +107,7 @@ struct TestXeGPUUnrollingPatterns
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
- auto scatterAttr =
- llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(
- encoding);
- int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+ int64_t chunkSize = tdescTy.getChunkSize();
if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
@@ -120,8 +117,7 @@ struct TestXeGPUUnrollingPatterns
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
- ctx, scatterAttr.getMemorySpace().getValue(),
- blockedChunkSize);
+ ctx, tdescTy.getMemorySpace(), blockedChunkSize);
encoding = newEncoding;
}
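With chunk_size now an optional parameter that defaults to 1, a plain gather/scatter descriptor can simply omit it. A minimal sketch, based on the create_tdesc_2 test updated above (the module name is a placeholder):

  gpu.module @example {
    gpu.func @create_tdesc_2(%src: memref<?xf32>) {
      %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
      // chunk_size is omitted and defaults to 1, so the descriptor keeps a
      // 1D shape that matches the number of offsets.
      %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
      gpu.return
    }
  }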
From 45f7214458a1cd8800c86399235988e7d7729e8b Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 17:29:12 +0000
Subject: [PATCH 03/14] cleanup
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 18 ++++----
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 6 +--
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 33 +++----------
mlir/test/Dialect/XeGPU/invalid.mlir | 46 ++-----------------
4 files changed, 22 insertions(+), 81 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 1b02565acfd8c..bcd5724835783 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -67,8 +67,11 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
TensorDesc is located, `Global` device memory or `Shared` local memory.
It is default to `Global`.
- 2. `chunk_size`: indicates number of contiguous elements accessed for each
- offset, default is 1. It is used with `scattered` attr only.
+ 2. `chunk_size`: Specifies the number of contiguous elements accessed per offset.
+ The default value is 1. While XeGPU supports a range of chunk sizes, hardware
+ may only allow specific values (e.g., 1, 2, 3, 4, 8, 16, 32, 64, 128, 256).
+ Therefore, XeGPU will legalize the chunk size as needed prior to lowering to
+ hardware instructions.
}];
let parameters = (ins
@@ -77,8 +80,9 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
"MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
"Data memory location"
>: $memory_space,
- OptionalParameter<
+ DefaultValuedParameter<
"IntegerAttr",
+ "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
"Number of contiguous elements"
>: $chunk_size
);
@@ -91,14 +95,10 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
];
let extraClassDeclaration = [{
- int64_t getChunkSizeOrDefault() {
- if (auto attr = getChunkSize())
- return attr.getInt();
- return 1;
+ int64_t getChunkSizeAsInt() {
+ return getChunkSize().getInt();
}
}];
-
- let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 2b99e80677098..ccbf4ecdf05ec 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -183,10 +183,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
int getChunkSize() {
auto attr = getEncoding();
auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
- assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
- if (scatter_attr)
- return scatter_attr.getChunkSizeOrDefault();
- return 1;
+ assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
+ return scatter_attr.getChunkSizeAsInt();
}
/// Helper to drop all layout information from the TensorDesc type.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 57458513805b0..32a4bf883829f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -125,21 +125,6 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
return Base::get(context, scopeAttr, chunkSizeAttr);
}
-LogicalResult ScatterTensorDescAttr::verify(
- llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
- MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
-
- if (chunk_size) {
- int64_t chunkSize = chunk_size.getInt();
- SmallVector<int64_t> supportedChunkSizes = {2, 3, 4, 8, 16,
- 32, 64, 128, 256};
- if (!llvm::is_contained(supportedChunkSizes, chunkSize))
- return emitError() << "invalid chunk size";
- }
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// XeGPU_LayoutAttr
//===----------------------------------------------------------------------===//
@@ -333,7 +318,7 @@ LogicalResult TensorDescType::verify(
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
- int64_t chunkSize = scatterAttr.getChunkSizeOrDefault();
+ int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
if (rank == 1 && chunkSize != 1)
return emitError() << "expected non-contiguous elements for 1D tensor";
@@ -357,17 +342,13 @@ LogicalResult TensorDescType::verify(
auto laneData = layoutAttr.getLaneData();
if (scatterAttr && laneData) {
// Validate subgroup mapping rules for scattered tensors.
- // A work-item's slice of the tensor with shape [sg_size] or
- // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
- // respectively, the mapping should reflect that. This is because each
- // work item access data in 32 bit granularity.
-
- if (rank > 1 && laneData[0] != 1)
+ // If chunkSize > 1, the last dimension of the tensor should
+ // be distributed in units divisible by chunkAlignmentFactor.
+ int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
+ if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
return emitError()
- << "cannot map over non-contiguous scattered row elements";
- if (laneData[rank - 1] != chunkAlignmentFactor)
- return emitError() << "work item data mapping must match the number of "
- "contiguous elements";
+ << "expected last dim of lane_data to be a multiple of: "
+ << chunkAlignmentFactor;
}
if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 8eb629660559c..a6f7d0992d7e7 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -210,15 +210,6 @@ func.func @create_tdesc_vc_3(%src: memref<?xf32>) {
// -----
func.func @create_tdesc_vc_4(%src: memref<?xf32>) {
- %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
- // expected-error@+1 {{invalid chunk size}}
- -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 5>>
- return
-}
-
-// -----
-func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
// expected-error@+1 {{expected last dim of tensor to match chunk size}}
@@ -227,7 +218,7 @@ func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
}
// -----
-func.func @create_tdesc_vc_6(%src: memref<?xf16>) {
+func.func @create_tdesc_vc_5(%src: memref<?xf16>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : memref<?xf16>, vector<4xindex>
// expected-error@+1 {{last dim of tensor to be a multiple of 2}}
@@ -258,23 +249,15 @@ func.func @prefetch_vc_2(%src: ui64) {
func.func @create_tdesc_layout_1(%src: ui64) {
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
// expected-error@+1 {{expected layout rank to match tensor rank}}
- %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
+ %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
return
}
// -----
func.func @create_tdesc_layout_2(%src: ui64) {
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- // expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
- %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.layout<lane_layout = [1, 4], lane_data = [2, 1]>>
- return
-}
-
-// -----
-func.func @create_tdesc_layout_3(%src: ui64) {
- %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- // expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
- %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>
+ // expected-error@+1 {{expected last dim of lane_data to be a multiple of: 2}}
+ %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x4xf16, #xegpu.scatter_tdesc_attr<chunk_size = 4>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
return
}
@@ -453,27 +436,6 @@ func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
return
}
-// -----
-func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
- %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
- // expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
- !xegpu.tensor_desc<4x2xf32,
- #xegpu.scatter_tdesc_attr<chunk_size = 2>,
- #xegpu.layout<lane_layout = [1, 1], lane_data = [2, 1]>>
- return
-}
-
-// -----
-func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
- %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
- // expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
- !xegpu.tensor_desc<16xf32,
- #xegpu.scatter_tdesc_attr<>,
- #xegpu.layout<lane_layout = [8], lane_data = [2]>>
- return
-}
-
// -----
func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vector<16xindex>) {
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
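A sketch of the new lane_data rule: for f16 the packing factor (chunkAlignmentFactor) is 32 / 16 = 2, so with chunk_size = 4 the last dim of lane_data must be a multiple of 2. The updated create_tdesc_layout_2 test uses lane_data = [1, 1] to trigger the diagnostic; a variant like the following (illustrative only, not part of this PR) is expected to verify:

  func.func @create_tdesc_layout_valid(%src: ui64) {
    %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
    // last dim of lane_data is 2, a multiple of the 32-bit packing factor for f16
    %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x4xf16, #xegpu.scatter_tdesc_attr<chunk_size = 4>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>
    return
  }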
>From bf24cc82e6e0242f3c9660afc8c5bf68e8271474 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 17:41:56 +0000
Subject: [PATCH 04/14] relax rank restriction on MaskType and OffsetType
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccbf4ecdf05ec..bd30335ddc344 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -19,8 +19,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
-def XeGPU_OffsetType: FixedVectorOfRankAndType<[1, 2], [Index]>;
-def XeGPU_MaskType: FixedVectorOfRankAndType<[1, 2], [I1]>;
+def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
+def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>;
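With OffsetType and MaskType relaxed to any fixed vector of non-zero rank, gather/scatter ops can take 2D offsets and masks. A sketch mirroring the subgroup_load_4 test added later in this series (names are test labels only):

  gpu.module @example {
    gpu.func @subgroup_load_4(%src: ui64) {
      %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex>
      %1 = arith.constant dense<1> : vector<2x4xi1>
      // 2D offsets yield a 3D descriptor when chunk_size > 1; the mask matches
      // the descriptor shape with the trailing chunk dim dropped.
      %2 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
      %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<2x4xi1> -> vector<2x4x8xf16>
      gpu.return
    }
  }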
From d8ce71ef3a3b0cdc54634e17799c1b1001ce58ba Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 18:03:36 +0000
Subject: [PATCH 05/14] add verifier for updateOffsetOp
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 2 ++
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 24 +++++++++++++++++--
2 files changed, 24 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index daab65ec893b8..b6f047d132c87 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -757,6 +757,8 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
let assemblyFormat = [{
$TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
}];
+
+ let hasVerifier = 1;
}
def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]> {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index d106f1c53c96b..3fbc4eff6ad67 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -87,9 +87,12 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return emitError()
<< "Value should have the same element type as TensorDesc.";
- if (tdescShape[0] != maskShape[0])
+ llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
+ if (chunkSize > 1)
+ expectedMaskShape.pop_back();
+ if (expectedMaskShape != maskShape)
return emitError()
- << "dim-0 of the Mask and TensorDesc should be the same.";
+ << "Mask should match TensorDesc except the chunk size dim.";
// a valid shape for SIMT case
if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
@@ -552,6 +555,23 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, tensorDesc, ofrs);
}
+LogicalResult UpdateOffsetOp::verify() {
+ auto tdescTy = getTensorDescType();
+ if (!tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ auto expectedOffsetShape = getShapeOf(tdescTy);
+ auto offsetShape = getShapeOf(getOffsetsType());
+ if (tdescTy.getChunkSize() > 1)
+ expectedOffsetShape.pop_back();
+
+ if (expectedOffsetShape != offsetShape)
+ return emitOpError(
+ "Offsets should match TensorDesc except the chunk size dim.");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
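A sketch of what the new verifier enforces: the offsets operand must match the descriptor shape, with the trailing chunk dimension dropped when chunk_size > 1. Adapted from the blocking tests added later in this series (the function name is a placeholder and the layout attribute is omitted for brevity):

  gpu.module @example {
    gpu.func @update_offset_chunked(%src: ui64) {
      %cst = arith.constant dense<0> : vector<4x8xindex>
      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<4x8xindex> -> !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>>
      // the descriptor shape is 4x8x4 with chunk_size = 4, so the offsets must be vector<4x8xindex>
      %delta = arith.constant dense<32> : vector<4x8xindex>
      %new = xegpu.update_offset %tdesc, %delta : !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>>, vector<4x8xindex>
      gpu.return
    }
  }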
From 75c86c3a1b54bc50b1ef1aa19e3b5a92ba97d004 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 18:50:15 +0000
Subject: [PATCH 06/14] add unit tests
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 19 +-----
mlir/test/Dialect/XeGPU/ops.mlir | 80 ++++++++++++++++++++++++++
2 files changed, 81 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 3fbc4eff6ad67..3f6f596449429 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -442,24 +442,7 @@ LogicalResult CreateDescOp::verify() {
// check total size
auto chunkSize = tdescTy.getChunkSize();
- auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
- auto bitsPerLane = elemBits * chunkSize;
- if (chunkSize > 1 && bitsPerLane % 32) {
- // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
- // For 32-bit data, the hardware can support larger larger chunk size. So
- // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
- // But this requires the total size is 32 bit aligned to make the
- // optimization work.
- return emitOpError(
- "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
- }
-
- auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
- if (elemBits * tdescTy.getNumElements() > lscConstraints)
- return emitOpError("total access size (simd_lanes * chunk_size * "
- "sizeof(elemTy)) is upto 512 bytes.");
-
- SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
+ SmallVector<int64_t> shape(getOffsetsType().getShape());
if (chunkSize != 1)
shape.push_back(chunkSize);
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 5c73810575cb0..252c6eeaaf6ec 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -54,6 +54,13 @@ gpu.func @create_nd_tdesc_6(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @create_nd_tdesc_7(%[[arg0:.*]]: memref<8x24x32x48x64xf32>) {
+gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf32> -> !xegpu.tensor_desc<8x8x8x24x32xf32>
+ %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf32> -> !xegpu.tensor_desc<8x8x8x24x32xf32>
+ gpu.return
+}
+
// CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
@@ -64,6 +71,14 @@ gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
gpu.return
}
+// CHECK: gpu.func @prefetch_nd_2(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
+gpu.func @prefetch_nd_2(%src: memref<8x24x32x48x64xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+ // CHECK: xegpu.prefetch_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
+ xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
+ gpu.return
+}
// CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
@@ -213,6 +228,15 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: func @subgroup_load_nd_9(%[[arg0:.*]]: memref<4x8x16xf16>) {
+gpu.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
+ gpu.return
+}
+
// CHECK: func @subgroup_store_nd(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @subgroup_store_nd(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -257,6 +281,17 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
gpu.return
}
+// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<8x24x32xf16>) {
+gpu.func @subgroup_store_nd_3(%dst: memref<8x24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<8x24x32xf16>
+ %1 = arith.constant dense<1.0>: vector<8x24x32xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0] : memref<8x24x32xf16> -> !xegpu.tensor_desc<8x24x32xf16>
+ %2 = xegpu.create_nd_tdesc %dst[0, 0, 0] : memref<8x24x32xf16> -> !xegpu.tensor_desc<8x24x32xf16>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<8x24x32xf16>, !xegpu.tensor_desc<8x24x32xf16>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8x24x32xf16>, !xegpu.tensor_desc<8x24x32xf16>
+ gpu.return
+}
+
// CHECK: gpu.func @update_nd_tdesc(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -266,6 +301,14 @@ gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @update_nd_tdesc_2(%[[arg0:.*]]: memref<8x24x32xf32>) {
+gpu.func @update_nd_tdesc_2(%src: memref<8x24x32xf32>) {
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<8x24x32xf32> -> !xegpu.tensor_desc<2x8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<8x24x32xf32> -> !xegpu.tensor_desc<2x8x16xf32>
+ // CHECK: %[[R1:.*]] = xegpu.update_nd_offset %[[REG]], [0, 0, 16] : !xegpu.tensor_desc<2x8x16xf32>
+ %2 = xegpu.update_nd_offset %1, [0, 0, 16]: !xegpu.tensor_desc<2x8x16xf32>
+ gpu.return
+}
// CHECK: gpu.func @create_tdesc(%[[arg0:.*]]: ui64) {
gpu.func @create_tdesc(%src: ui64) {
@@ -306,6 +349,15 @@ gpu.func @create_tdesc_3(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @create_tdesc_4(%[[arg0:.*]]: ui64) {
+gpu.func @create_tdesc_4(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex>
+ %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_load(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_load(%src: ui64) {
@@ -385,6 +437,19 @@ gpu.func @simt_load_3(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @subgroup_load_4(%[[arg0:.*]]: ui64) {
+gpu.func @subgroup_load_4(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex>
+ %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<2x4xi1>
+ %1 = arith.constant dense<1>: vector<2x4xi1>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+ //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<2x4xi1> -> vector<2x4x8xf16>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<2x4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<2x4xi1> -> vector<2x4x8xf16>
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_store(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -475,6 +540,21 @@ gpu.func @simt_store_3(%src: ui64) {
gpu.return
}
+// CHECK: gpu.func @subgroup_store_4(%[[arg0:.*]]: ui64) {
+gpu.func @subgroup_store_4(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<{{.*}}> : vector<2x4xindex>
+ %0 = arith.constant dense<[[0, 8, 16, 24], [32, 40, 48, 56]]> : vector<2x4xindex>
+ //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<2x4xi1>
+ %1 = arith.constant dense<1>: vector<2x4xi1>
+ //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<2x4xf32>
+ %2 = arith.constant dense<2.9>: vector<2x4xf32>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>>
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<2x4xindex> -> !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>>
+ //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2x4xf32>, !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>>, vector<2x4xi1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2x4xf32>, !xegpu.tensor_desc<2x4xf32, #xegpu.scatter_tdesc_attr<>>, vector<2x4xi1>
+ gpu.return
+}
+
// CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
gpu.func @prefetch(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
From 741d7bb7d4465d69e898d03db1cb175fbde11a48 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 20:01:23 +0000
Subject: [PATCH 07/14] add 3D tensor desc blocking tests
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 3 -
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 91 ++++++++++++++++++-
2 files changed, 87 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 2c48a735bf956..66690f9e9a91a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -625,9 +625,6 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (tdescTy.getRank() > 2)
- return failure();
-
if (!tdescTy.isScattered())
return failure();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index ac5fe89a67f9a..91996bb28ca70 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -358,8 +358,8 @@ gpu.module @test_kernel {
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
- // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
gpu.func @test_prefetch_load_store_update(%src: ui64) {
@@ -406,8 +406,8 @@ gpu.module @test_kernel {
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
- // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
- // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<16x2xf32>
+ // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+ // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<16x2xf32>
// CHECK-COUNT-4: xegpu.store {{.*}} : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) {
@@ -446,4 +446,87 @@ gpu.module @test_kernel {
}
}
+// -----
+#l = #xegpu.layout<inst_data = [8,32,16]>
+gpu.module @test_kernel {
+ gpu.func @test_3d_block_tensor_desc(%A: memref<1024x1024x1024xf16>, %B: memref<1024x1024x1024xf16>, %C: memref<1024x1024x1024xf16>) {
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %c1024 = arith.constant 1024 : index
+ %block_id_x = gpu.block_id x
+ %m = arith.muli %block_id_x, %c32 : index
+
+ %a_tdesc = xegpu.create_nd_tdesc %A[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
+ %b_tdesc = xegpu.create_nd_tdesc %B[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
+ %c_tdesc = xegpu.create_nd_tdesc %C[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
+
+ %out:3 = scf.for %k = %c0 to %c1024 step %c32
+ iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
+ -> (!xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>) {
+ //CHECK-COUNT-16: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x32x16xf16> -> vector<8x32x16xf16>
+ %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<32x32x32xf16, #l> -> vector<32x32x32xf16>
+ %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32x32xf16, #l> -> vector<32x32x32xf16>
+
+ //CHECK-COUNT-8: arith.addf {{.*}} : vector<8x32x16xf16>
+ %c = arith.addf %a, %b {layout_result_0 = #l} : vector<32x32x32xf16>
+
+ //CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x32x16xf16>, !xegpu.tensor_desc<8x32x16xf16>
+ xegpu.store_nd %c, %arg2: vector<32x32x32xf16>, !xegpu.tensor_desc<32x32x32xf16, #l>
+
+ //CHECK-COUNT-24: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x32x16xf16>
+ %a_next_tdesc = xegpu.update_nd_offset %arg0, [0, 0, %c32] : !xegpu.tensor_desc<32x32x32xf16, #l>
+ %b_next_tdesc = xegpu.update_nd_offset %arg1, [0, 0, %c32] : !xegpu.tensor_desc<32x32x32xf16, #l>
+ %c_next_tdesc = xegpu.update_nd_offset %arg2, [0, 0, %c32] : !xegpu.tensor_desc<32x32x32xf16, #l>
+ scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
+ : !xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>
+ }
+ gpu.return
+ }
+}
+// -----
+#l = #xegpu.layout<inst_data = [2, 8, 2]>
+gpu.module @test_kernel {
+ // CHECK-LABEL: test_3d_scattered_tensor_desc
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<2x8xindex> -> !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xindex>
+ // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xi1> -> vector<2x8x2xf32>
+ // CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x8x2xf32>, !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xi1>
+
+
+ gpu.func @test_3d_scattered_tensor_desc(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ [0, 8, 16, 24, 32, 40, 48, 56],
+ [64, 72, 80, 88, 96, 104, 112, 120],
+ [128, 136, 144, 152, 160, 168, 176, 184],
+ [192, 200, 208, 216, 224, 232, 240, 248]
+ ]> : vector<4x8xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<4x8xindex> -> !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #l>
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #l>
+
+ %delta = arith.constant dense<[
+ [32, 32, 32, 32, 32, 32, 32, 32],
+ [32, 32, 32, 32, 32, 32, 32, 64],
+ [128, 128, 128, 128, 128, 128, 128, 128],
+ [128, 128, 128, 128, 128, 128, 128, 256]
+ ]> : vector<4x8xindex>
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #l>, vector<4x8xindex>
+
+ %c4 = arith.constant 4: index
+ %mask = vector.create_mask %c4, %c4: vector<4x8xi1>
+
+ %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #l>, vector<4x8xi1> -> vector<4x8x4xf32>
+
+ %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x8x4xf32>
+ xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>:
+ vector<4x8x4xf32>,
+ !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #l>,
+ vector<4x8xi1>
+ gpu.return
+ }
+}
From 16abfeab0ea11bbc19a4d08183efb6d674b573cc Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 20:06:33 +0000
Subject: [PATCH 08/14] refine test
---
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 91996bb28ca70..edfbf38d0af12 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -522,7 +522,7 @@ gpu.module @test_kernel {
%ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #l>, vector<4x8xi1> -> vector<4x8x4xf32>
- %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x8x4xf32>
+ %st_vec = arith.addf %ld_vec, %ld_vec {layout_result_0 = #l} : vector<4x8x4xf32>
xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>:
vector<4x8x4xf32>,
!xegpu.tensor_desc<4x8x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #l>,
From 3edb9ef913c9483d58cd52dab7a90474255f193b Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 20:09:54 +0000
Subject: [PATCH 09/14] fix format
---
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index edfbf38d0af12..70d2db5c41cfa 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -491,8 +491,8 @@ gpu.module @test_kernel {
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<2x8xindex> -> !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
- // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xindex>
- // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xi1> -> vector<2x8x2xf32>
+ // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xindex>
+ // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xi1> -> vector<2x8x2xf32>
// CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x8x2xf32>, !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xi1>
From bdd644ca8c19f4aa54f9e680065dd4b89200eeb3 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 27 Jun 2025 02:06:40 +0000
Subject: [PATCH 10/14] address comments
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 ++--
mlir/test/Dialect/XeGPU/invalid.mlir | 2 +-
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 25 +++++++++++++++++++++
3 files changed, 28 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 3f6f596449429..caef13b59f5c8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -543,8 +543,8 @@ LogicalResult UpdateOffsetOp::verify() {
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
- auto expectedOffsetShape = getShapeOf(tdescTy);
- auto offsetShape = getShapeOf(getOffsetsType());
+ SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
+ SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
if (tdescTy.getChunkSize() > 1)
expectedOffsetShape.pop_back();
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index a6f7d0992d7e7..77918c66b82af 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// -----
-func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) {
+func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
// expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}}
%1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
return
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 70d2db5c41cfa..018b4a0f02858 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -449,13 +449,38 @@ gpu.module @test_kernel {
// -----
#l = #xegpu.layout<inst_data = [8,32,16]>
gpu.module @test_kernel {
+ // CHECK-LABEL: test_3d_block_tensor_desc
+ // CHECK-SAME: [[arg0:%.+]]: memref<1024x1024x1024xf16>, [[arg1:%.+]]: memref<1024x1024x1024xf16>, [[arg2:%.+]]: memref<1024x1024x1024xf16>
gpu.func @test_3d_block_tensor_desc(%A: memref<1024x1024x1024xf16>, %B: memref<1024x1024x1024xf16>, %C: memref<1024x1024x1024xf16>) {
+ //CHECK: [[c24:%.*]] = arith.constant 24 : index
+ //CHECK: [[c8:%.*]] = arith.constant 8 : index
+ //CHECK: [[c16:%.*]] = arith.constant 16 : index
+ //CHECK: [[c0:%.*]] = arith.constant 0 : index
+ //CHECK: [[c32:%.*]] = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index
+
+ //CHECK: [[block_id_x:%.*]] = gpu.block_id x
+ //CHECK: [[m:%.*]] = arith.muli [[block_id_x]], [[c32]] : index
%block_id_x = gpu.block_id x
%m = arith.muli %block_id_x, %c32 : index
+ //CHECK: xegpu.create_nd_tdesc [[arg0]][[[m]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
+ //CHECK: xegpu.create_nd_tdesc [[arg0]][[[m]], [[m]], [[c16]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
+ //CHECK: [[off1:%.*]] = arith.addi [[m]], [[c8]] : index
+ //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off1]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
+ //CHECK: [[off2:%.*]] = arith.addi [[m]], [[c8]] : index
+ //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off2]], [[m]], [[c16]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
+ //CHECK: [[off3:%.*]] = arith.addi [[m]], [[c16]] : index
+ //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off3]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
+ //CHECK: [[off4:%.*]] = arith.addi [[m]], [[c16]] : index
+ //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off4]], [[m]], [[c16]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
+ //CHECK: [[off5:%.*]] = arith.addi [[m]], [[c24]] : index
+ //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off5]], [[m]], [[c0]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
+ //CHECK: [[off6:%.*]] = arith.addi [[m]], [[c24]] : index
+ //CHECK: xegpu.create_nd_tdesc [[arg0]][[[off6]], [[m]], [[c16]]] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<8x32x16xf16>
+
%a_tdesc = xegpu.create_nd_tdesc %A[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
%b_tdesc = xegpu.create_nd_tdesc %B[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
%c_tdesc = xegpu.create_nd_tdesc %C[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
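A quick sanity check of the expected unrolling above, assuming the pass tiles each dimension by the matching inst_data entry:

  // 32x32x32 source tile, inst_data = [8, 32, 16]:
  //   dim 0: 32 / 8  = 4 pieces
  //   dim 1: 32 / 32 = 1 piece
  //   dim 2: 32 / 16 = 2 pieces
  //   => 4 * 1 * 2 = 8 sub-descriptors of !xegpu.tensor_desc<8x32x16xf16>,
  //      matching the eight create_nd_tdesc ops checked above.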
>From ad0baf65da3b6cab5a84f7a44f142f4102eb06e1 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 27 Jun 2025 02:34:17 +0000
Subject: [PATCH 11/14] address comments
---
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 018b4a0f02858..7da336272555e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -514,6 +514,10 @@ gpu.module @test_kernel {
gpu.module @test_kernel {
// CHECK-LABEL: test_3d_scattered_tensor_desc
// CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK: [[cst_1:%.+]] = arith.constant dense<{{.*}}[130, 138, 146, 154, 162, 170, 178, 186], [194, 202, 210, 218, 226, 234, 242, 250]]> : vector<2x8xindex>
+ // CHECK: [[cst_2:%.+]] = arith.constant dense<{{.*}}[2, 10, 18, 26, 34, 42, 50, 58], [66, 74, 82, 90, 98, 106, 114, 122]]> : vector<2x8xindex>
+ // CHECK: [[cst_3:%.+]] = arith.constant dense<{{.*}}[0, 8, 16, 24, 32, 40, 48, 56], [64, 72, 80, 88, 96, 104, 112, 120]]> : vector<2x8xindex>
+ // CHECK: [[cst_4:%.+]] = arith.constant dense<{{.*}}[128, 136, 144, 152, 160, 168, 176, 184], [192, 200, 208, 216, 224, 232, 240, 248]]> : vector<2x8xindex>
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<2x8xindex> -> !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xindex>
@@ -522,7 +526,6 @@ gpu.module @test_kernel {
gpu.func @test_3d_scattered_tensor_desc(%src: ui64) {
-
%cst = arith.constant dense<[
[0, 8, 16, 24, 32, 40, 48, 56],
[64, 72, 80, 88, 96, 104, 112, 120],
>From db6115c61b36d9d7b0415fc7be56e9bc5760ac06 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 30 Jun 2025 19:03:55 +0000
Subject: [PATCH 12/14] address comments
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 22 +++++++++++++------
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 4 +++-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 10 +++++++++
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 6 ++---
.../XeGPU/Transforms/XeGPUBlocking.cpp | 2 +-
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 8 +++----
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 2 +-
7 files changed, 37 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index bcd5724835783..ffab875b94cbf 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -42,9 +42,18 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
}];
let parameters = (ins
- OptionalParameter<"MemorySpaceAttr">: $memory_space,
- OptionalParameter<"IntegerAttr">: $array_length,
- OptionalParameter<"BoolAttr">: $boundary_check
+ DefaultValuedParameter<
+ "MemorySpaceAttr",
+ "MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
+ "Data memory location">: $memory_space,
+ DefaultValuedParameter<
+ "IntegerAttr",
+ "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
+ "Number of continuous blocks to load">: $array_length,
+ DefaultValuedParameter<
+ "BoolAttr",
+ "BoolAttr::get($_ctxt, 1)",
+ "Checking the out of boundary access">: $boundary_check
);
let builders = [
@@ -68,10 +77,7 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
It is default to `Global`.
2. `chunk_size`: Specifies the number of contiguous elements accessed per offset.
- The default value is 1. While XeGPU supports a range of chunk sizes, hardware
- may only allow specific values (e.g., 1, 2, 3, 4, 8, 16, 32, 64, 128, 256).
- Therefore, XeGPU will legalize the chunk size as needed prior to lowering to
- hardware instructions.
+ The default value is 1.
}];
let parameters = (ins
@@ -99,6 +105,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
return getChunkSize().getInt();
}
}];
+
+ let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
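To make the new defaults concrete, a rough sketch (illustrative, not taken from the patch) of a block descriptor that overrides them versus one that relies on them:

  // Hypothetical descriptor with non-default parameters spelled out:
  !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64, boundary_check = false>>
  // Relying on the defaults (global memory, array_length = 1, boundary_check = true):
  !xegpu.tensor_desc<8x16xf16>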
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index bd30335ddc344..c8e9b3a96ec83 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -156,6 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return MemorySpace::Global;
}
+ // get the ArrayLength for blocked TensorDesc
int getArrayLength() {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
@@ -180,7 +181,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return bool(getEncodingAsScatterTensorDescAttr());
}
- int getChunkSize() {
+ // get the ChunkSize for scattered TensorDesc
+ int getChunkSizeAsInt() {
auto attr = getEncoding();
auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
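As a reminder of what the renamed accessor reports, a small scattered-descriptor sketch (values chosen for illustration): with chunk_size = 2, each of the four offsets addresses two contiguous f32 elements, and the chunk appears as the trailing dimension of the descriptor shape.

  // getChunkSizeAsInt() would return 2 for this type:
  !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>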
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 32a4bf883829f..0f9cd95cf63ca 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -125,6 +125,16 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
return Base::get(context, scopeAttr, chunkSizeAttr);
}
+LogicalResult ScatterTensorDescAttr::verify(
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+ MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
+ int64_t chunkSize = chunk_size.getInt();
+ if (chunkSize <= 0)
+ return emitError() << "invalid chunk size";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_LayoutAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index caef13b59f5c8..e053a77aea0b2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,7 +81,7 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
auto tdescShape = getShapeOf(tdescTy);
- auto chunkSize = tdescTy.getChunkSize();
+ auto chunkSize = tdescTy.getChunkSizeAsInt();
if (valueTy.getElementType() != tdescTy.getElementType())
return emitError()
@@ -441,7 +441,7 @@ LogicalResult CreateDescOp::verify() {
<< ", TensorDesc: " << tdescMemorySpace;
// check total size
- auto chunkSize = tdescTy.getChunkSize();
+ auto chunkSize = tdescTy.getChunkSizeAsInt();
SmallVector<int64_t> shape(getOffsetsType().getShape());
if (chunkSize != 1)
shape.push_back(chunkSize);
@@ -545,7 +545,7 @@ LogicalResult UpdateOffsetOp::verify() {
SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
- if (tdescTy.getChunkSize() > 1)
+ if (tdescTy.getChunkSizeAsInt() > 1)
expectedOffsetShape.pop_back();
if (expectedOffsetShape != offsetShape)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index c6c4e3aaa41ed..ddc9e0eb908ac 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -303,7 +303,7 @@ void XeGPUBlockingPass::runOnOperation() {
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
- int64_t chunkSize = tdescTy.getChunkSize();
+ int64_t chunkSize = tdescTy.getChunkSizeAsInt();
if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 66690f9e9a91a..13d49cb0e9d82 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -413,7 +413,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
return failure();
SmallVector<int64_t> targetIndiceShape(*targetShape);
- int64_t originalChunkSize = tdescTy.getChunkSize();
+ int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
// IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
if (originalChunkSize > 1)
targetIndiceShape.pop_back();
@@ -480,7 +480,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
return failure();
SmallVector<int64_t> targetMaskShape(*targetShape);
- int64_t originalChunkSize = tdescTy.getChunkSize();
+ int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
@@ -571,7 +571,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
return failure();
SmallVector<int64_t> targetMaskShape(*targetShape);
- int64_t originalChunkSize = tdescTy.getChunkSize();
+ int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
@@ -642,7 +642,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
SmallVector<Type> convertedOffsetTypes;
SmallVector<Value> convertedOffsetVec;
SmallVector<Value> newOps;
- int64_t originalChunkSize = tdescTy.getChunkSize();
+ int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
if (originalChunkSize > 1) {
auto targetOffsetShape = ArrayRef<int64_t>(*targetShape).drop_back();
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 335f89f1826aa..f71fcf7ca297b 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -107,7 +107,7 @@ struct TestXeGPUUnrollingPatterns
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
- int64_t chunkSize = tdescTy.getChunkSize();
+ int64_t chunkSize = tdescTy.getChunkSizeAsInt();
if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
>From 8a8fa74f3cc1191e8e842480a646326c02491d84 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 30 Jun 2025 19:36:53 +0000
Subject: [PATCH 13/14] add unit test
---
mlir/test/Dialect/XeGPU/invalid.mlir | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 77918c66b82af..51e5a828377e6 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -199,6 +199,15 @@ func.func @create_tdesc_vc_1(%src: ui64) {
return
}
+// -----
+func.func @create_tdesc_vc_2(%src: memref<?xf32>) {
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
+ // expected-error@+1 {{invalid chunk size}}
+ -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 0>>
+ return
+}
+
// -----
func.func @create_tdesc_vc_3(%src: memref<?xf32>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
>From ec4158b83e2d6c3da460a06d1f0046a485b00867 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 1 Jul 2025 15:20:49 +0000
Subject: [PATCH 14/14] update unit tests
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 2 +-
mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir | 3 ---
mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir | 3 ---
mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir | 1 -
.../test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir | 1 -
5 files changed, 1 insertion(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index ffab875b94cbf..42b5b7a0d4e3f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -52,7 +52,7 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
"Number of continuous blocks to load">: $array_length,
DefaultValuedParameter<
"BoolAttr",
- "BoolAttr::get($_ctxt, 1)",
+ "BoolAttr::get($_ctxt, true)",
"Checking the out of boundary access">: $boundary_check
);
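As I read it, the test updates below follow from boundary_check now defaulting to true (with default-valued parameters elided when printing), so the descriptor type collapses to its short form. A sketch of the expected before/after spelling (illustrative):

  // Previously printed with the explicit default:
  !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = true>>
  // Now printed with the default elided:
  !xegpu.tensor_desc<8x16xf32>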
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 7cef17df79dd2..4af7061a4f8a3 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -31,7 +31,6 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = true
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
@@ -57,7 +56,6 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = true
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
@@ -76,7 +74,6 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = true
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 4f069ebc39db3..d68a02b54e967 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -33,7 +33,6 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = true
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
// -----
@@ -59,7 +58,6 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = true
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
// -----
@@ -78,7 +76,6 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = true
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 497eb86cea835..c2f760b29afc4 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -52,7 +52,6 @@ func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = true
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 91e3fb3841f6e..8de6c2283b37c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -81,7 +81,6 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME: boundary_check = true
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
// -----