[Mlir-commits] [mlir] [mlir][xegpu] Relax rank restriction of TensorDescType (PR #145916)

Chao Chen llvmlistbot at llvm.org
Thu Jun 26 11:03:56 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/145916

From d85e4ff4cb1a8687a64945a7e1c506dc3948771c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 01:25:06 +0000
Subject: [PATCH 1/5] remove 1D and 2D checks

---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       | 13 +++++-----
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 17 +++++--------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 25 ++++++-------------
 mlir/test/Dialect/XeGPU/invalid.mlir          | 14 +++--------
 4 files changed, 22 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84314875c2ae5..d5a61174fbfe3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -17,12 +17,12 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64,
 def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
-def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
-def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
-def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
-def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
-def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_OffsetType: FixedVectorOfRankAndType<[1, 2], [Index]>;
+def XeGPU_MaskType: FixedVectorOfRankAndType<[1, 2], [I1]>;
+def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
+def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>;
 
 // common base class for types in XeGPU dialect
 class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
@@ -118,7 +118,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   ];
 
   let extraClassDeclaration = [{
-    using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
     using mlir::ShapedType::Trait<TensorDescType>::getRank;
     using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 649e0d453015f..5f159dd223aba 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -310,13 +310,14 @@ LogicalResult TensorDescType::verify(
     llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
     mlir::Attribute encoding, mlir::Attribute layout) {
   size_t rank = shape.size();
-  if (rank != 1 && rank != 2)
-    return emitError() << "expected 1D or 2D tensor";
+
+  if (rank == 0)
+    return emitError() << "expected non-zero rank tensor";
 
   auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
   if (blockAttr) {
     MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
-    if (rank == 2 && memorySpaceAttr &&
+    if (rank > 1 && memorySpaceAttr &&
         memorySpaceAttr.getValue() == MemorySpace::SLM)
       return emitError() << "SLM is not supported for 2D block tensor";
   }
@@ -329,16 +330,10 @@ LogicalResult TensorDescType::verify(
           : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
-    // Expected tensor ranks for scattered data:
-    //   - 1D tensor for fully non-contiguous elements (chunk size == 1)
-    //   - 2D tensor for scattered blocks (chunk size > 1)
     unsigned chunkSize = scatterAttr.getChunkSize().getInt();
-    if (rank == 1 && chunkSize != 1)
-      return emitError() << "expected non-contiguous elements for 1D tensor";
-    if (rank == 2 && chunkSize < 2)
-      return emitError() << "expected chunk blocks for 2D tensor";
     // If chunk size > 1, the second dimension of the tensor shape must be
-    // equal to chunk size and it must be a multiple of the packing factor.
+    // equal to chunk size and it must be a multiple of the
+    // chunkAlignmentFactor.
     if (chunkSize > 1) {
       if (shape.back() != chunkSize)
         return emitError() << "expected tensor shape[1] to match chunk size";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2793c7a35bc97..d106f1c53c96b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -203,11 +203,9 @@ LogicalResult CreateNdDescOp::verify() {
         "is a memref) should match with each other.");
 
   // check result TensorDesc rank
-  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
-
-  if (invalidRank)
+  if (getType().getRank() > rank)
     return emitOpError(
-        "Expecting the TensorDesc rank is up to 2 and not greater than the "
+        "Expecting the TensorDesc rank is not greater than the "
         "ranks of shape, strides, offsets or the memref source.");
 
   if (invalidElemTy)
@@ -247,9 +245,6 @@ LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
 
-  if (tdescTy.getRank() > 2)
-    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
   if (tdescTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
@@ -316,15 +311,13 @@ LogicalResult LoadNdOp::verify() {
   }
 
   auto array_len = tdescTy.getArrayLength();
-  if (array_len > 1) {
+  if (array_len > 1)
     tdescShape.insert(tdescShape.begin(), array_len);
-  }
 
-  if (tdescShape != valueShape) {
+  if (tdescShape != valueShape)
     return emitOpError() << "Result shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "
                          << tdescTy;
-  }
 
   return success();
 }
@@ -336,9 +329,6 @@ LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
 
-  if (dstTy.getRank() > 2)
-    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
   if (dstTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
@@ -370,22 +360,21 @@ LogicalResult StoreNdOp::verify() {
       return emitOpError()
              << "TensorDesc doesn't need LayoutAttr for SIMT code";
 
-    if (tdescElems % valueElems) {
+    if (tdescElems % valueElems)
       return emitOpError()
              << "Value shape " << makeString(getShapeOf(valTy))
              << " is not a valid distribution for tensor descriptor " << dstTy;
-    }
+
     return success();
   }
 
   // SIMD code should have the same shape as the tensor descriptor.
   auto tdescShape = getShapeOf(dstTy);
   auto valueShape = getShapeOf(valTy);
-  if (tdescShape != valueShape) {
+  if (tdescShape != valueShape)
     return emitOpError() << "Value shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "
                          << dstTy;
-  }
 
   return success();
 }
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index a2778cd94d963..3969ae64b6dbf 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 // -----
-func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
-  // expected-error at +1 {{Expecting the TensorDesc rank is up to 2 and not greater than the ranks of shape, strides, offsets or the memref source}}
+func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) {
+  // expected-error at +1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}}
   %1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
   return
 }
@@ -406,18 +406,10 @@ func.func @atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi
   return
 }
 
-// -----
-func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
-  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error at +1 {{expected 1D or 2D tensor}}
-      !xegpu.tensor_desc<16x2x2xf32>
-  return
-}
-
 // -----
 func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error at +1 {{expected 1D or 2D tensor}}
+      // expected-error at +1 {{expected non-zero rank tensor}}
       !xegpu.tensor_desc<f32>
   return
 }
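
For readers following the series: a minimal IR sketch of what the relaxed
verifier in this patch is meant to accept and reject (shapes and names are
illustrative, not taken from the tests above):

  // Now accepted: the TensorDesc rank may exceed 2, as long as it does
  // not exceed the rank of the source memref.
  %0 = xegpu.create_nd_tdesc %src[0, 0, 0]
      : memref<8x24x32xf32> -> !xegpu.tensor_desc<8x16x16xf32>

  // Still rejected: rank-0 descriptors ("expected non-zero rank tensor").
  %1 = xegpu.create_nd_tdesc %src2[0, 0]
      : memref<24x32xf32> -> !xegpu.tensor_desc<f32>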

From 265dc278fdf525e826bd01a8393320fcec25d423 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 15:26:29 +0000
Subject: [PATCH 2/5] cleanup

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 15 ++++++++---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 25 +++++++++++--------
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  6 ++---
 mlir/test/Dialect/XeGPU/invalid.mlir          | 21 +++++-----------
 mlir/test/Dialect/XeGPU/ops.mlir              |  4 +--
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  8 ++----
 7 files changed, 39 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 84c1dc1373ee5..1b02565acfd8c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -43,8 +43,8 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
 
   let parameters = (ins
     OptionalParameter<"MemorySpaceAttr">: $memory_space,
-    OptionalParameter<"IntegerAttr", "1">: $array_length,
-    OptionalParameter<"BoolAttr", "true">: $boundary_check
+    OptionalParameter<"IntegerAttr">: $array_length,
+    OptionalParameter<"BoolAttr">: $boundary_check
   );
 
   let builders = [
@@ -77,9 +77,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
       "MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
       "Data memory location"
     >: $memory_space,
-    DefaultValuedParameter<
+    OptionalParameter<
       "IntegerAttr",
-      "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
       "Number of contiguous elements"
     >: $chunk_size
   );
@@ -91,6 +90,14 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
     )>
   ];
 
+  let extraClassDeclaration = [{
+    int64_t getChunkSizeOrDefault() {
+      if (auto attr = getChunkSize())
+        return attr.getInt();
+      return 1;
+    }
+  }];
+
   let genVerifyDecl = 1;
  }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d5a61174fbfe3..2b99e80677098 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -185,7 +185,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
       assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
       if (scatter_attr)
-        return scatter_attr.getChunkSize().getInt();
+        return scatter_attr.getChunkSizeOrDefault();
       return 1;
     }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 5f159dd223aba..57458513805b0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -128,11 +128,14 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
 LogicalResult ScatterTensorDescAttr::verify(
     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
     MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
-  int64_t chunkSize = chunk_size.getInt();
-  SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
-                                              16, 32, 64, 128, 256};
-  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
-    return emitError() << "invalid chunk size";
+
+  if (chunk_size) {
+    int64_t chunkSize = chunk_size.getInt();
+    SmallVector<int64_t> supportedChunkSizes = {2,  3,  4,   8,  16,
+                                                32, 64, 128, 256};
+    if (!llvm::is_contained(supportedChunkSizes, chunkSize))
+      return emitError() << "invalid chunk size";
+  }
 
   return success();
 }
@@ -319,7 +322,7 @@ LogicalResult TensorDescType::verify(
     MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
     if (rank > 1 && memorySpaceAttr &&
         memorySpaceAttr.getValue() == MemorySpace::SLM)
-      return emitError() << "SLM is not supported for 2D block tensor";
+      return emitError() << "SLM is only supported for 1D block tensor";
   }
 
   // for gather and scatter ops, Low-precision types are packed in 32-bit units.
@@ -330,16 +333,18 @@ LogicalResult TensorDescType::verify(
           : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
-    unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+    int64_t chunkSize = scatterAttr.getChunkSizeOrDefault();
+    if (rank == 1 && chunkSize != 1)
+      return emitError() << "expected non-contiguous elements for 1D tensor";
+
     // If chunk size > 1, the second dimension of the tensor shape must be
     // equal to chunk size and it must be a multiple of the
     // chunkAlignmentFactor.
     if (chunkSize > 1) {
       if (shape.back() != chunkSize)
-        return emitError() << "expected tensor shape[1] to match chunk size";
+        return emitError() << "expected last dim of tensor to match chunk size";
       if (shape.back() % chunkAlignmentFactor != 0)
-        return emitError() << "expected tensor shape[1] to be a multiple of "
-                              "chunk alignment factor "
+        return emitError() << "expected last dim of tensor to be a multiple of "
                            << chunkAlignmentFactor;
     }
   }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 3950e8f70d1ca..c6c4e3aaa41ed 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -303,9 +303,7 @@ void XeGPUBlockingPass::runOnOperation() {
       // If the encoding is a ScatterTensorDescAttr, we need to
       // potentially adjust the chunk size based on the inst_data.
       if (tdescTy.isScattered()) {
-        auto scatterAttr =
-            llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
-        int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+        int64_t chunkSize = tdescTy.getChunkSize();
 
         if (chunkSize > 1) {
           int64_t blockedChunkSize = chunkSize;
@@ -315,7 +313,7 @@ void XeGPUBlockingPass::runOnOperation() {
 
           // To create a new attribute with a different chunk_size:
           auto newEncoding = xegpu::ScatterTensorDescAttr::get(
-              ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
+              ctx, tdescTy.getMemorySpace(), blockedChunkSize);
 
           encoding = newEncoding;
         }
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 3969ae64b6dbf..8eb629660559c 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -17,7 +17,7 @@ func.func @create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
 
 // -----
 func.func @create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
-  // expected-error at +1 {{SLM is not supported for 2D block tensor}}
+  // expected-error at +1 {{SLM is only supported for 1D block tensor}}
   %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
   return
 }
@@ -199,15 +199,6 @@ func.func @create_tdesc_vc_1(%src: ui64) {
   return
 }
 
-// -----
-func.func @create_tdesc_vc_2(%src: ui64) {
-  %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
-  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
-  // expected-error at +1 {{expected chunk blocks for 2D tensor}}
-          -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
-  return
-}
-
 // -----
 func.func @create_tdesc_vc_3(%src: memref<?xf32>) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -230,7 +221,7 @@ func.func @create_tdesc_vc_4(%src: memref<?xf32>) {
 func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
-  // expected-error at +1 {{expected tensor shape[1] to match chunk size}}
+  // expected-error at +1 {{expected last dim of tensor to match chunk size}}
           -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>>
   return
 }
@@ -239,7 +230,7 @@ func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
 func.func @create_tdesc_vc_6(%src: memref<?xf16>) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : memref<?xf16>, vector<4xindex>
-  // expected-error at +1 {{tensor shape[1] to be a multiple of chunk alignment factor 2}}
+  // expected-error at +1 {{last dim of tensor to be a multiple of 2}}
           -> !xegpu.tensor_desc<4x3xf16, #xegpu.scatter_tdesc_attr<chunk_size = 3>>
   return
 }
@@ -478,7 +469,7 @@ func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<1
   %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
       // expected-error at +1 {{work item data mapping must match the number of contiguous elements}}
       !xegpu.tensor_desc<16xf32,
-        #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+        #xegpu.scatter_tdesc_attr<>,
          #xegpu.layout<lane_layout = [8], lane_data = [2]>>
   return
 }
@@ -496,9 +487,9 @@ func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vecto
 // -----
 func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) {
   %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
-      // expected-error at +1 {{expected chunk blocks for 2D tensor}}
+      // expected-error at +1 {{expected last dim of tensor to match chunk size}}
       !xegpu.tensor_desc<16x2xf32,
-        #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+        #xegpu.scatter_tdesc_attr<chunk_size = 4>,
          #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2]>>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index aff8f63adc05b..5c73810575cb0 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -291,8 +291,8 @@ gpu.func @create_tdesc_1(%src: memref<?xf32, 3>) {
 gpu.func @create_tdesc_2(%src: memref<?xf32>) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>
-  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>  -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
+  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>  -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
   gpu.return
 }
 
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c84eb74198544..335f89f1826aa 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -107,10 +107,7 @@ struct TestXeGPUUnrollingPatterns
             // If the encoding is a ScatterTensorDescAttr, we need to
             // potentially adjust the chunk size based on the inst_data.
             if (tdescTy.isScattered()) {
-              auto scatterAttr =
-                  llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(
-                      encoding);
-              int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+              int64_t chunkSize = tdescTy.getChunkSize();
 
               if (chunkSize > 1) {
                 int64_t blockedChunkSize = chunkSize;
@@ -120,8 +117,7 @@ struct TestXeGPUUnrollingPatterns
 
                 // To create a new attribute with a different chunk_size:
                 auto newEncoding = xegpu::ScatterTensorDescAttr::get(
-                    ctx, scatterAttr.getMemorySpace().getValue(),
-                    blockedChunkSize);
+                    ctx, tdescTy.getMemorySpace(), blockedChunkSize);
 
                 encoding = newEncoding;
               }
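
A sketch of how the now-optional chunk_size reads at the IR level (my
reading of the hunks above; the f16 shapes are illustrative):

  // chunk_size omitted: getChunkSizeOrDefault() returns 1.
  %0 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
      -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>

  // chunk_size = 8: the last dim must equal the chunk size and be a
  // multiple of the packing factor (2 for f16, since 32 / 16 = 2).
  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
      -> !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>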

From 45f7214458a1cd8800c86399235988e7d7729e8b Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 17:29:12 +0000
Subject: [PATCH 3/5] cleanup

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 18 ++++----
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  6 +--
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 33 +++----------
 mlir/test/Dialect/XeGPU/invalid.mlir          | 46 ++-----------------
 4 files changed, 22 insertions(+), 81 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 1b02565acfd8c..bcd5724835783 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -67,8 +67,11 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
         TensorDesc is located, `Global` device memory or `Shared` local memory.
         It is default to `Global`.
 
-    2.  `chunk_size`: indicates number of contiguous elements accessed for each
-        offset, default is 1. It is used with `scattered` attr only.
+    2. `chunk_size`: Specifies the number of contiguous elements accessed per offset.
+      The default value is 1. While XeGPU supports a range of chunk sizes, hardware
+      may only allow specific values (e.g., 1, 2, 3, 4, 8, 16, 32, 64, 128, 256).
+      Therefore, XeGPU will legalize the chunk size as needed prior to lowering to
+      hardware instructions.
   }];
 
   let parameters = (ins
@@ -77,8 +80,9 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
       "MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
       "Data memory location"
     >: $memory_space,
-    OptionalParameter<
+    DefaultValuedParameter<
       "IntegerAttr",
+      "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
       "Number of contiguous elements"
     >: $chunk_size
   );
@@ -91,14 +95,10 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
   ];
 
   let extraClassDeclaration = [{
-    int64_t getChunkSizeOrDefault() {
-      if (auto attr = getChunkSize())
-        return attr.getInt();
-      return 1;
+    int64_t getChunkSizeAsInt() {
+      return getChunkSize().getInt();
     }
   }];
-
-  let genVerifyDecl = 1;
  }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 2b99e80677098..ccbf4ecdf05ec 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -183,10 +183,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     int getChunkSize() {
       auto attr = getEncoding();
       auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
-      assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
-      if (scatter_attr)
-        return scatter_attr.getChunkSizeOrDefault();
-      return 1;
+      assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
+      return scatter_attr.getChunkSizeAsInt();
     }
 
     /// Helper to drop all layout information from the TensorDesc type.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 57458513805b0..32a4bf883829f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -125,21 +125,6 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
   return Base::get(context, scopeAttr, chunkSizeAttr);
 }
 
-LogicalResult ScatterTensorDescAttr::verify(
-    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
-
-  if (chunk_size) {
-    int64_t chunkSize = chunk_size.getInt();
-    SmallVector<int64_t> supportedChunkSizes = {2,  3,  4,   8,  16,
-                                                32, 64, 128, 256};
-    if (!llvm::is_contained(supportedChunkSizes, chunkSize))
-      return emitError() << "invalid chunk size";
-  }
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // XeGPU_LayoutAttr
 //===----------------------------------------------------------------------===//
@@ -333,7 +318,7 @@ LogicalResult TensorDescType::verify(
           : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
-    int64_t chunkSize = scatterAttr.getChunkSizeOrDefault();
+    int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
     if (rank == 1 && chunkSize != 1)
       return emitError() << "expected non-contiguous elements for 1D tensor";
 
@@ -357,17 +342,13 @@ LogicalResult TensorDescType::verify(
     auto laneData = layoutAttr.getLaneData();
     if (scatterAttr && laneData) {
       // Validate subgroup mapping rules for scattered tensors.
-      // A work-item's slice of the tensor with shape [sg_size] or
-      // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
-      // respectively, the mapping should reflect that. This is because each
-      // work item access data in 32 bit granularity.
-
-      if (rank > 1 && laneData[0] != 1)
+      // if chunkSize > 1, the last dimension of the tensor should
+      // be distributed in the units divisible by chunkAlignmentFactor.
+      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
+      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
         return emitError()
-               << "cannot map over non-contiguous scattered row elements";
-      if (laneData[rank - 1] != chunkAlignmentFactor)
-        return emitError() << "work item data mapping must match the number of "
-                              "contiguous elements";
+               << "expected last dim of lane_data to be a multiple of: "
+               << chunkAlignmentFactor;
     }
 
     if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 8eb629660559c..a6f7d0992d7e7 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -210,15 +210,6 @@ func.func @create_tdesc_vc_3(%src: memref<?xf32>) {
 
 // -----
 func.func @create_tdesc_vc_4(%src: memref<?xf32>) {
-  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
-  // expected-error at +1 {{invalid chunk size}}
-          -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 5>>
-  return
-}
-
-// -----
-func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
   // expected-error at +1 {{expected last dim of tensor to match chunk size}}
@@ -227,7 +218,7 @@ func.func @create_tdesc_vc_5(%src: memref<?xf32>) {
 }
 
 // -----
-func.func @create_tdesc_vc_6(%src: memref<?xf16>) {
+func.func @create_tdesc_vc_5(%src: memref<?xf16>) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : memref<?xf16>, vector<4xindex>
   // expected-error at +1 {{last dim of tensor to be a multiple of 2}}
@@ -258,23 +249,15 @@ func.func @prefetch_vc_2(%src: ui64) {
 func.func @create_tdesc_layout_1(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   // expected-error at +1 {{expected layout rank to match tensor rank}}
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>,   #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
+  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
   return
 }
 
 // -----
 func.func @create_tdesc_layout_2(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error at +1 {{cannot map over non-contiguous scattered row elements}}
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,   #xegpu.layout<lane_layout = [1, 4], lane_data = [2, 1]>>
-  return
-}
-
-// -----
-func.func @create_tdesc_layout_3(%src: ui64) {
-  %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error at +1 {{work item data mapping must match the number of contiguous elements}}
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>,   #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 2]>>
+  // expected-error at +1 {{expected last dim of lane_data to be a multiple of: 2}}
+  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x4xf16, #xegpu.scatter_tdesc_attr<chunk_size = 4>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
   return
 }
 
@@ -453,27 +436,6 @@ func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
   return
 }
 
-// -----
-func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
-  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
-      // expected-error at +1 {{cannot map over non-contiguous scattered row elements}}
-      !xegpu.tensor_desc<4x2xf32,
-        #xegpu.scatter_tdesc_attr<chunk_size = 2>,
-         #xegpu.layout<lane_layout = [1, 1], lane_data = [2, 1]>>
-  return
-}
-
-// -----
-func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
-  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
-      // expected-error at +1 {{work item data mapping must match the number of contiguous elements}}
-      !xegpu.tensor_desc<16xf32,
-        #xegpu.scatter_tdesc_attr<>,
-         #xegpu.layout<lane_layout = [8], lane_data = [2]>>
-  return
-}
-
 // -----
 func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vector<16xindex>) {
   %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
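
To illustrate the new lane_data rule: when chunk_size > 1, the last dim of
lane_data must be a multiple of the packing factor. A sketch of a layout
that should pass under that rule (shapes are illustrative):

  // f16 packs 2 elements per 32-bit unit, so lane_data = [1, 2] satisfies
  // laneData[rank - 1] % chunkAlignmentFactor == 0, and shape [16, 8] is
  // evenly distributable over lane_layout = [16, 1] with lane_data = [1, 2].
  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
      -> !xegpu.tensor_desc<16x8xf16,
           #xegpu.scatter_tdesc_attr<chunk_size = 8>,
           #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>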

From bf24cc82e6e0242f3c9660afc8c5bf68e8271474 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 17:41:56 +0000
Subject: [PATCH 4/5] relax rank restriction on MaskType and OffsetType

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index ccbf4ecdf05ec..bd30335ddc344 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -19,8 +19,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
 def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
 def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
-def XeGPU_OffsetType: FixedVectorOfRankAndType<[1, 2], [Index]>;
-def XeGPU_MaskType: FixedVectorOfRankAndType<[1, 2], [I1]>;
+def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
+def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
 def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>;
 

From d8ce71ef3a3b0cdc54634e17799c1b1001ce58ba Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 26 Jun 2025 18:03:36 +0000
Subject: [PATCH 5/5] add verifier for updateOffsetOp

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  2 ++
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 24 +++++++++++++++++--
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index daab65ec893b8..b6f047d132c87 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -757,6 +757,8 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
   let assemblyFormat = [{
     $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
   }];
+
+  let hasVerifier = 1;
 }
 
 def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]> {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index d106f1c53c96b..3fbc4eff6ad67 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -87,9 +87,12 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
     return emitError()
            << "Value should have the same element type as TensorDesc.";
 
-  if (tdescShape[0] != maskShape[0])
+  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
     return emitError()
-           << "dim-0 of the Mask and TensorDesc should be the same.";
+           << "Mask should match TensorDesc except the chunk size dim.";
 
   // a valid shape for SIMT case
   if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
@@ -552,6 +555,23 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, tensorDesc, ofrs);
 }
 
+LogicalResult UpdateOffsetOp::verify() {
+  auto tdescTy = getTensorDescType();
+  if (!tdescTy.isScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  auto expectedOffsetShape = getShapeOf(tdescTy);
+  auto offsetShape = getShapeOf(getOffsetsType());
+  if (tdescTy.getChunkSize() > 1)
+    expectedOffsetShape.pop_back();
+
+  if (expectedOffsetShape != offsetShape)
+    return emitOpError(
+        "Offsets should match TensorDesc except the chunk size dim.");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
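
A sketch of what the new verifier enforces (names like %delta are
illustrative): the offsets operand must match the TensorDesc shape with the
chunk-size dim dropped.

  // tdesc shape [16, 8] with chunk_size = 8 -> offsets must be vector<16xindex>.
  %2 = xegpu.update_offset %1, %delta
      : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>,
        vector<16xindex>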


