[Mlir-commits] [mlir] [mlir][xegpu] Relax rank restriction of TensorDescType (PR #145916)
Chao Chen
llvmlistbot at llvm.org
Thu Jun 26 08:28:02 PDT 2025
https://github.com/chencha3 created https://github.com/llvm/llvm-project/pull/145916
None
>From d85e4ff4cb1a8687a64945a7e1c506dc3948771c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 01:25:06 +0000
Subject: [PATCH 1/2] remove 1D and 2D checks
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 13 +++++-----
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 17 +++++--------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 25 ++++++-------------
mlir/test/Dialect/XeGPU/invalid.mlir | 14 +++--------
4 files changed, 22 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84314875c2ae5..d5a61174fbfe3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -17,12 +17,12 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64,
def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
-def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
-def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
-def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
-def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
-def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_OffsetType: FixedVectorOfRankAndType<[1, 2], [Index]>;
+def XeGPU_MaskType: FixedVectorOfRankAndType<[1, 2], [I1]>;
+def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
+def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>;
// common base class for types in XeGPU dialect
class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
@@ -118,7 +118,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
];
let extraClassDeclaration = [{
- using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
using mlir::ShapedType::Trait<TensorDescType>::getRank;
using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 649e0d453015f..5f159dd223aba 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -310,13 +310,14 @@ LogicalResult TensorDescType::verify(
llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
mlir::Attribute encoding, mlir::Attribute layout) {
size_t rank = shape.size();
- if (rank != 1 && rank != 2)
- return emitError() << "expected 1D or 2D tensor";
+
+ if (rank == 0)
+ return emitError() << "expected non-zero rank tensor";
auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
if (blockAttr) {
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
- if (rank == 2 && memorySpaceAttr &&
+ if (rank > 1 && memorySpaceAttr &&
memorySpaceAttr.getValue() == MemorySpace::SLM)
return emitError() << "SLM is not supported for 2D block tensor";
}
@@ -329,16 +330,10 @@ LogicalResult TensorDescType::verify(
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
- // Expected tensor ranks for scattered data:
- // - 1D tensor for fully non-contiguous elements (chunk size == 1)
- // - 2D tensor for scattered blocks (chunk size > 1)
unsigned chunkSize = scatterAttr.getChunkSize().getInt();
- if (rank == 1 && chunkSize != 1)
- return emitError() << "expected non-contiguous elements for 1D tensor";
- if (rank == 2 && chunkSize < 2)
- return emitError() << "expected chunk blocks for 2D tensor";
// If chunk size > 1, the second dimension of the tensor shape must be
- // equal to chunk size and it must be a multiple of the packing factor.
+ // equal to chunk size and it must be a multiple of the
+ // chunkAlignmentFactor.
if (chunkSize > 1) {
if (shape.back() != chunkSize)
return emitError() << "expected tensor shape[1] to match chunk size";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2793c7a35bc97..d106f1c53c96b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -203,11 +203,9 @@ LogicalResult CreateNdDescOp::verify() {
"is a memref) should match with each other.");
// check result TensorDesc rank
- invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
-
- if (invalidRank)
+ if (getType().getRank() > rank)
return emitOpError(
- "Expecting the TensorDesc rank is up to 2 and not greater than the "
+ "Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides, offsets or the memref source.");
if (invalidElemTy)
@@ -247,9 +245,6 @@ LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
- if (tdescTy.getRank() > 2)
- return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
if (tdescTy.isScattered())
return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -316,15 +311,13 @@ LogicalResult LoadNdOp::verify() {
}
auto array_len = tdescTy.getArrayLength();
- if (array_len > 1) {
+ if (array_len > 1)
tdescShape.insert(tdescShape.begin(), array_len);
- }
- if (tdescShape != valueShape) {
+ if (tdescShape != valueShape)
return emitOpError() << "Result shape " << makeString(valueShape)
<< " is not consistent with tensor descriptor "
<< tdescTy;
- }
return success();
}
@@ -336,9 +329,6 @@ LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
- if (dstTy.getRank() > 2)
- return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
if (dstTy.isScattered())
return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -370,22 +360,21 @@ LogicalResult StoreNdOp::verify() {
return emitOpError()
<< "TensorDesc doesn't need LayoutAttr for SIMT code";
- if (tdescElems % valueElems) {
+ if (tdescElems % valueElems)
return emitOpError()
<< "Value shape " << makeString(getShapeOf(valTy))
<< " is not a valid distribution for tensor descriptor " << dstTy;
- }
+
return success();
}
// SIMD code should have the same shape as the tensor descriptor.
auto tdescShape = getShapeOf(dstTy);
auto valueShape = getShapeOf(valTy);
- if (tdescShape != valueShape) {
+ if (tdescShape != valueShape)
return emitOpError() << "Value shape " << makeString(valueShape)
<< " is not consistent with tensor descriptor "
<< dstTy;
- }
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index a2778cd94d963..3969ae64b6dbf 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// -----
-func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
- // expected-error@+1 {{Expecting the TensorDesc rank is up to 2 and not greater than the ranks of shape, strides, offsets or the memref source}}
+func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) {
+ // expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}}
%1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
return
}
@@ -406,18 +406,10 @@ func.func @atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi
return
}
-// -----
-func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
- %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error@+1 {{expected 1D or 2D tensor}}
- !xegpu.tensor_desc<16x2x2xf32>
- return
-}
-
// -----
func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error@+1 {{expected 1D or 2D tensor}}
+ // expected-error@+1 {{expected non-zero rank tensor}}
!xegpu.tensor_desc<f32>
return
}
>From 265dc278fdf525e826bd01a8393320fcec25d423 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 15:26:29 +0000
Subject: [PATCH 2/2] cleanup
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 15 ++++++++---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 25 +++++++++++--------
.../XeGPU/Transforms/XeGPUBlocking.cpp | 6 ++---
mlir/test/Dialect/XeGPU/invalid.mlir | 21 +++++-----------
mlir/test/Dialect/XeGPU/ops.mlir | 4 +--
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 8 ++----
7 files changed, 39 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 84c1dc1373ee5..1b02565acfd8c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -43,8 +43,8 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
let parameters = (ins
OptionalParameter<"MemorySpaceAttr">: $memory_space,
- OptionalParameter<"IntegerAttr", "1">: $array_length,
- OptionalParameter<"BoolAttr", "true">: $boundary_check
+ OptionalParameter<"IntegerAttr">: $array_length,
+ OptionalParameter<"BoolAttr">: $boundary_check
);
let builders = [
@@ -77,9 +77,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
"MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
"Data memory location"
>: $memory_space,
- DefaultValuedParameter<
+ OptionalParameter<
"IntegerAttr",
- "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
"Number of contiguous elements"
>: $chunk_size
);
@@ -91,6 +90,14 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
)>
];
+ let extraClassDeclaration = [{
+ int64_t getChunkSizeOrDefault() {
+ if (auto attr = getChunkSize())
+ return attr.getInt();
+ return 1;
+ }
+ }];
+
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d5a61174fbfe3..2b99e80677098 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -185,7 +185,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
if (scatter_attr)
- return scatter_attr.getChunkSize().getInt();
+ return scatter_attr.getChunkSizeOrDefault();
return 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 5f159dd223aba..57458513805b0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -128,11 +128,14 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
LogicalResult ScatterTensorDescAttr::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
- int64_t chunkSize = chunk_size.getInt();
- SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
- 16, 32, 64, 128, 256};
- if (!llvm::is_contained(supportedChunkSizes, chunkSize))
- return emitError() << "invalid chunk size";
+
+ if (chunk_size) {
+ int64_t chunkSize = chunk_size.getInt();
+ SmallVector<int64_t> supportedChunkSizes = {2, 3, 4, 8, 16,
+ 32, 64, 128, 256};
+ if (!llvm::is_contained(supportedChunkSizes, chunkSize))
+ return emitError() << "invalid chunk size";
+ }
return success();
}
@@ -319,7 +322,7 @@ LogicalResult TensorDescType::verify(
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
if (rank > 1 && memorySpaceAttr &&
memorySpaceAttr.getValue() == MemorySpace::SLM)
- return emitError() << "SLM is not supported for 2D block tensor";
+ return emitError() << "SLM is only supported for 1D block tensor";
}
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
@@ -330,16 +333,18 @@ LogicalResult TensorDescType::verify(
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
- unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+ int64_t chunkSize = scatterAttr.getChunkSizeOrDefault();
+ if (rank == 1 && chunkSize != 1)
+ return emitError() << "expected non-contiguous elements for 1D tensor";
+
// If chunk size > 1, the second dimension of the tensor shape must be
// equal to chunk size and it must be a multiple of the
// chunkAlignmentFactor.
if (chunkSize > 1) {
if (shape.back() != chunkSize)
- return emitError() << "expected tensor shape[1] to match chunk size";
+ return emitError() << "expected last dim of tensor to match chunk size";
if (shape.back() % chunkAlignmentFactor != 0)
- return emitError() << "expected tensor shape[1] to be a multiple of "
- "chunk alignment factor "
+ return emitError() << "expected last dim of tensor to be a multiple of "
<< chunkAlignmentFactor;
}
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 3950e8f70d1ca..c6c4e3aaa41ed 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -303,9 +303,7 @@ void XeGPUBlockingPass::runOnOperation() {
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
- auto scatterAttr =
- llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
- int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+ int64_t chunkSize = tdescTy.getChunkSize();
if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
@@ -315,7 +313,7 @@ void XeGPUBlockingPass::runOnOperation() {
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
- ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
+ ctx, tdescTy.getMemorySpace(), blockedChunkSize);
encoding = newEncoding;
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 3969ae64b6dbf..8eb629660559c 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -17,7 +17,7 @@ func.func @create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
// -----
func.func @create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
- // expected-error@+1 {{SLM is not supported for 2D block tensor}}
+ // expected-error@+1 {{SLM is only supported for 1D block tensor}}
%1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
return
}
@@ -199,15 +199,6 @@ func.func @create_tdesc_vc_1(%src: ui64) {
return
}
-// -----
-func.func @create_tdesc_vc_2(%src: ui64) {
- %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
- %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
- // expected-error@+1 {{expected chunk blocks for 2D tensor}}
- -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
- return
-}
-
// -----
func.func @create_tdesc_vc_3(%src: memref<?xf32>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -230,7 +221,7 @@ func.func @create_tdesc_vc_4(%src: memref<?xf32>) {
func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
- // expected-error@+1 {{expected tensor shape[1] to match chunk size}}
+ // expected-error@+1 {{expected last dim of tensor to match chunk size}}
-> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>>
return
}
@@ -239,7 +230,7 @@ func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
func.func @create_tdesc_vc_6(%src: memref<?xf16>) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : memref<?xf16>, vector<4xindex>
- // expected-error@+1 {{tensor shape[1] to be a multiple of chunk alignment factor 2}}
+ // expected-error@+1 {{last dim of tensor to be a multiple of 2}}
-> !xegpu.tensor_desc<4x3xf16, #xegpu.scatter_tdesc_attr<chunk_size = 3>>
return
}
@@ -478,7 +469,7 @@ func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<1
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
// expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
!xegpu.tensor_desc<16xf32,
- #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+ #xegpu.scatter_tdesc_attr<>,
#xegpu.layout<lane_layout = [8], lane_data = [2]>>
return
}
@@ -496,9 +487,9 @@ func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vecto
// -----
func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) {
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
- // expected-error@+1 {{expected chunk blocks for 2D tensor}}
+ // expected-error@+1 {{expected last dim of tensor to match chunk size}}
!xegpu.tensor_desc<16x2xf32,
- #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+ #xegpu.scatter_tdesc_attr<chunk_size = 4>,
#xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2]>>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index aff8f63adc05b..5c73810575cb0 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -291,8 +291,8 @@ gpu.func @create_tdesc_1(%src: memref<?xf32, 3>) {
gpu.func @create_tdesc_2(%src: memref<?xf32>) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>
- %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+ //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+ %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
gpu.return
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c84eb74198544..335f89f1826aa 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -107,10 +107,7 @@ struct TestXeGPUUnrollingPatterns
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
- auto scatterAttr =
- llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(
- encoding);
- int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+ int64_t chunkSize = tdescTy.getChunkSize();
if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
@@ -120,8 +117,7 @@ struct TestXeGPUUnrollingPatterns
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
- ctx, scatterAttr.getMemorySpace().getValue(),
- blockedChunkSize);
+ ctx, tdescTy.getMemorySpace(), blockedChunkSize);
encoding = newEncoding;
}
More information about the Mlir-commits
mailing list