[Mlir-commits] [mlir] [mlir][xegpu] Remove OffsetSizeAndStrideOpInterface from CreateNdDescOp (PR #152773)

Chao Chen llvmlistbot at llvm.org
Fri Aug 8 12:18:54 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/152773

>From 2bc35d8b56877479d4b7ae202a800cedb2da33f9 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 8 Aug 2025 18:12:30 +0000
Subject: [PATCH 1/3] remove OffsetSizeAndStrideOpInterface from CreateNdDescOp

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 173 +++++++-----------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |   7 +-
 mlir/test/Dialect/XeGPU/invalid.mlir          |  10 +-
 mlir/test/Dialect/XeGPU/ops.mlir              |  36 ++--
 4 files changed, 93 insertions(+), 133 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 75b16a87e03c6..4c18b3f47ba18 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
     void printProperties(::mlir::MLIRContext *ctx,
             ::mlir::OpAsmPrinter &p, const Properties &prop,
             ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-      
+
       DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
 
       // filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
       }
 
       if (!filteredAttrs.empty()) {
-        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">"; 
+        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
       }
     }
 
@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
 }
 
 
-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
-                        AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {
 
   let summary = "Create nd-tensor descriptor operation";
   let description = [{
@@ -181,82 +180,40 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return getType().getShape();
     }
 
-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    OperandRange getSizes() {
-      return getShape();
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
     }
 
-    ArrayRef<int64_t> getStaticOffsets(){
-      auto attr = getConstOffsetsAttr();
-
-      if (attr) 
-        return attr;
+    SmallVector<OpFoldResult> getMixedSizes() {
+      Builder b(getContext());
+      SmallVector<int64_t> statics;
 
-      int64_t rank = getMixedSizes().size();
-      
-      setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
+      /// Get the static sizes/shape, the value passed to const_shape
+      /// will override the value in memref shape.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = llvm::to_vector(memrefTy.getShape());
+      if (auto attr = getConstShapeAttr())
+        statics = llvm::to_vector(attr.asArrayRef());
 
-      attr = getConstOffsetsAttr();
-      return attr;
+      return getMixedValues(statics, getShape(), b);
     }
 
-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_shape` is filled,
-    /// it will return `const_shape`, such that mixes of `shape`
-    /// and `const_shape` will be used to represent the shape of
-    /// source operand. They overide static shape from source memref type.
-    ArrayRef<int64_t> getStaticSizes() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static  llvm::SmallVector<int64_t, 4> emptyShape;
-
-      auto attr = getConstShapeAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyShape;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticSizes");
-      return memrefType.getShape();
-    }
+    SmallVector<OpFoldResult> getMixedStrides() {
+      Builder b(getContext());
+      SmallVector<int64_t> statics;
 
-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_strides` is filled, it
-    /// will return `const_strides`, such that mixes of `strides`
-    /// and `const_strides` will be used to represent the strides of
-    /// source operand. They overide static strides from source memref type.
-    ArrayRef<int64_t> getStaticStrides() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t, 4> emptyStrides;
-
-      auto attr = getConstStridesAttr();
-      if (attr)
-        return attr;
-      
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyStrides;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticStrides");
-      auto [strides, _] = memrefType.getStridesAndOffset();
-      // reuse the storage of ConstStridesAttr since strides from
-      // memref is not persistant
-      setConstStrides(strides);
-      attr = getConstStridesAttr();
-      return attr;
-    }
+      /// Get the static strides, the value passed to const_strides
+      /// will override the value in memref.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = memrefTy.getStridesAndOffset().first;
+      if (auto attr = getConstStridesAttr())
+        statics = llvm::to_vector(attr.asArrayRef());
 
-    /// Return the expected rank of each of the`static_offsets`,
-    /// `static_shape` and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank;
-      if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
-        rank = ty.getRank();
-      } else {
-        rank = (unsigned)getMixedOffsets().size();
-      }
-      return {rank, rank, rank};
+      return getMixedValues(statics, getStrides(), b);
     }
 
     /// Return the number of leading operands before the `offsets`,
@@ -314,15 +271,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
   }];
 
   let assemblyFormat = [{
-    $TensorDesc `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` qualified(type($TensorDesc))
   }];
 
   let builders = [
-    OpBuilder<(ins "Value": $TensorDesc, 
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
@@ -370,7 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,  
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +347,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
   }];
 
   let assemblyFormat = [{
-    $TensorDesc `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
   }];
 
   let builders = [
-    OpBuilder<(ins "Type": $value, "Value": $TensorDesc, 
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
                     "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
-                    "xegpu::CachePolicyAttr": $l1_hint, 
-                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
@@ -442,7 +399,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
   let arguments = (ins XeGPU_ValueType: $value,
                        XeGPU_TensorDesc: $TensorDesc,
                        Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,  
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +415,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
   }];
 
    let assemblyFormat = [{
-    $value `,` 
-    $TensorDesc `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $value `,`
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:`  type($value) `,` qualified(type($TensorDesc))
   }];
 
   let builders = [
-    OpBuilder<(ins "Value": $value, "Value": $TensorDesc, 
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
@@ -635,12 +592,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
                              l3_hint = #xegpu.cache_hint<cached>}
         : !xegpu.tensor_desc<16xf16>
     ```
-    
+
     Example 2:
     A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
     It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
     The source operand could be a raw pointer (uint64_t).
-    Please refer to create_tdesc for the restriction of memref. 
+    Please refer to create_tdesc for the restriction of memref.
     ```mlir
       %a = memref.alloc() : memref<1024xf32>
       %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +633,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
   }];
 
   let assemblyFormat = [{
-    $source 
+    $source
     (`[` $offsets^ `]`)?
     prop-dict
-    attr-dict `:` type(operands) 
+    attr-dict `:` type(operands)
   }];
-    
+
   let builders = [
     OpBuilder<(ins "Value": $source,
-                    "xegpu::CachePolicyAttr": $l1_hint, 
-                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
@@ -723,7 +680,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
             vector<16xi1> -> vector<16x8xf32>
   ```
-  
+
   Example 3 (SIMT mode):
   ```mlir
     %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +689,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
             vector<16xi1> -> vector<8xf32>
   ```
-  
+
   Example 4:
   A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
   It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
   The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
-  for the restriction of memref. 
+  for the restriction of memref.
   ```mlir
     %a = memref.alloc() : memref<1024xf32>
     %offsets = vector.step : vector<16xindex>
@@ -794,14 +751,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   let assemblyFormat = [{
     $source
     (`[` $offsets^ `]`)? `,`
-    $mask prop-dict 
+    $mask prop-dict
     attr-dict `:` type(operands) `->` type($value)
   }];
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
-                    "xegpu::CachePolicyAttr": $l1_hint, 
-                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
    ];
 
@@ -848,7 +805,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
   It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
   The dest operand could be a raw pointer (uint64_t).
-  Please refer to create_tdesc for the restriction of memref. 
+  Please refer to create_tdesc for the restriction of memref.
   ```mlir
     %a = memref.alloc() : memref<1024xf32>
     %val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +858,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     $value `,`
     $dest
     (`[` $offsets^ `]`)? `,`
-    $mask 
-    prop-dict 
+    $mask
+    prop-dict
     attr-dict `:`  type(operands)
   }];
 
   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
-                    "xegpu::CachePolicyAttr": $l1_hint, 
-                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
    ];
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3fa229e..7ac885d2ed40f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -265,7 +265,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult CreateNdDescOp::verify() {
-  auto rank = (int64_t)getMixedOffsets().size();
+  int64_t rank = getMixedSizes().size();
   bool invalidRank = false;
   bool invalidElemTy = false;
 
@@ -280,6 +280,9 @@ LogicalResult CreateNdDescOp::verify() {
            << " Source: " << srcMemorySpace
            << ", TensorDesc: " << tdescMemorySpace;
 
+  if (int64_t offsetRank = getMixedOffsets().size())
+    invalidRank |= (offsetRank != rank);
+
   // check source type matches the rank if it is a memref.
   // It also should have the same ElementType as TensorDesc.
   auto memrefTy = dyn_cast<MemRefType>(getSourceType());
@@ -291,7 +294,7 @@ LogicalResult CreateNdDescOp::verify() {
   if (llvm::isa<IntegerType>(getSourceType())) {
     // strides and shape must present for integer source.
     if (getMixedStrides().empty() || getMixedSizes().empty())
-      return emitOpError("Expecting strides and shape to be present for "
+      return emitOpError("expecting strides and shape to be present for "
                          "integer source.");
   }
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index dff3ffab39ecf..cdf147a9fdd0e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -52,14 +52,14 @@ func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {
 
 // -----
 func.func @create_nd_tdesc_8(%src: ui64) {
-  // expected-error@+1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
+  // expected-error@+1 {{'xegpu.create_nd_tdesc' op expecting strides and shape to be present for integer source}}
   %1 = xegpu.create_nd_tdesc %src : ui64-> !xegpu.tensor_desc<128x128xf32>
   return
 }
 
 // -----
 func.func @create_nd_tdesc_9(%src: ui64) {
-  // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank}}
+  // expected-error@+1 {{expecting strides and shape to be present for integer source}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : ui64-> !xegpu.tensor_desc<128x128xf32>
   return
 }
@@ -149,7 +149,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
 }
 
 // -----
-func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {  
+func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
     // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
@@ -418,7 +418,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
   // expected-error@+1 {{value elements must match chunk size}}
-  xegpu.store %val, %src[%offsets], %mask 
+  xegpu.store %val, %src[%offsets], %mask
         : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
   return
 }
@@ -429,7 +429,7 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
   // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
-  xegpu.store %val, %src[%offsets], %mask 
+  xegpu.store %val, %src[%offsets], %mask
         : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 6be2371d4d7b2..67c00f5a9cc2f 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -62,28 +62,28 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
 }
 
 
-// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>) 
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
 gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index
   %c1 = arith.constant 1 : index
-  
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   %3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
- 
+
   gpu.return
 }
 
-// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) 
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
 gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
-  
-  %c1 = arith.constant 1 : index   
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
   %2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1]  : ui64 -> !xegpu.tensor_desc<8x16xf32>
- 
+
   gpu.return
 }
 
-// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}}) 
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
 
 gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
 
@@ -94,10 +94,10 @@ gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index,
   gpu.return
 }
 
-// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}}) 
-gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {  
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
   %c1 = arith.constant 1 : index
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16> 
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
   %2 = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides:[%w, %c1]  : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
 
   gpu.return
@@ -123,7 +123,7 @@ gpu.func @prefetch_nd_2(%src: memref<48x64xf16>) {
 
 // CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<48x64xf16>,  %arg1: index, %arg2: index) {
 gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
   %1 = xegpu.create_nd_tdesc %src : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
   // CHECK: xegpu.prefetch_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf16>
   xegpu.prefetch_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x16xf16>
@@ -271,7 +271,7 @@ gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
 
 // CHECK: func @subgroup_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>, %arg1: index, %arg2: index) {
 gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>, %x : index, %y : index) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   %1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
   %2 = xegpu.load_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
@@ -290,7 +290,7 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
 
 // CHECK: func @simt_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @simt_load_nd_offset_1(%src: memref<24x32xf32>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   %1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
   %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
@@ -323,7 +323,7 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
   %1 = arith.constant dense<1.0>: vector<32xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
   %2 = xegpu.create_nd_tdesc %dst : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
   // CHECK: xegpu.store_nd %[[C]], %[[R0]][%arg1] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
   xegpu.store_nd %1, %2[%x] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
@@ -356,7 +356,7 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
 gpu.func @simt_store_nd_offset_1(%src: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
   %1 = arith.constant dense<1.0>: vector<2xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
   %2 = xegpu.create_nd_tdesc %src : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
   // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
   xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>

>From a2bc905459bb9b853b3fe0ce68673f562e6fb766 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 8 Aug 2025 18:20:02 +0000
Subject: [PATCH 2/3] clean up

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 4c18b3f47ba18..1a6a34c8d775a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -189,7 +189,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     }
 
     SmallVector<OpFoldResult> getMixedSizes() {
-      Builder b(getContext());
       SmallVector<int64_t> statics;
 
       /// Get the static sizes/shape, the value passed to const_shape
@@ -199,11 +198,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       if (auto attr = getConstShapeAttr())
         statics = llvm::to_vector(attr.asArrayRef());
 
-      return getMixedValues(statics, getShape(), b);
+      return getMixedValues(statics, getShape(), getContext());
     }
 
     SmallVector<OpFoldResult> getMixedStrides() {
-      Builder b(getContext());
       SmallVector<int64_t> statics;
 
       /// Get the static strides, the value passed to const_strides
@@ -213,7 +211,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       if (auto attr = getConstStridesAttr())
         statics = llvm::to_vector(attr.asArrayRef());
 
-      return getMixedValues(statics, getStrides(), b);
+      return getMixedValues(statics, getStrides(), getContext());
     }
 
     /// Return the number of leading operands before the `offsets`,

>From b0f626b994a05af574b72c6757720d919a15e3be Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 8 Aug 2025 19:18:38 +0000
Subject: [PATCH 3/3] cleanup

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7ac885d2ed40f..b519d6ad72660 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -265,8 +265,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult CreateNdDescOp::verify() {
-  int64_t rank = getMixedSizes().size();
-  bool invalidRank = false;
+  size_t rank = getMixedSizes().size();
+  bool invalidRank = rank != getMixedStrides().size();
   bool invalidElemTy = false;
 
   // Memory space of created TensorDesc should match with the source.
@@ -280,16 +280,13 @@ LogicalResult CreateNdDescOp::verify() {
            << " Source: " << srcMemorySpace
            << ", TensorDesc: " << tdescMemorySpace;
 
-  if (int64_t offsetRank = getMixedOffsets().size())
+  if (size_t offsetRank = getMixedOffsets().size())
     invalidRank |= (offsetRank != rank);
 
   // check source type matches the rank if it is a memref.
   // It also should have the same ElementType as TensorDesc.
-  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
-  if (memrefTy) {
-    invalidRank |= (memrefTy.getRank() != rank);
+  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
     invalidElemTy |= memrefTy.getElementType() != getElementType();
-  }
 
   if (llvm::isa<IntegerType>(getSourceType())) {
     // strides and shape must present for integer source.
@@ -298,16 +295,13 @@ LogicalResult CreateNdDescOp::verify() {
                          "integer source.");
   }
 
-  // mismatches among shape, strides, and offsets are
-  // already handeled by OffsetSizeAndStrideOpInterface.
-  // So they are not check here.
   if (invalidRank)
     return emitOpError(
         "Expecting the rank of shape, strides, offsets, and source (if source "
         "is a memref) should match with each other.");
 
   // check result TensorDesc rank
-  if (getType().getRank() > rank)
+  if (getType().getRank() > (int64_t)rank)
     return emitOpError(
         "Expecting the TensorDesc rank is not greater than the "
         "ranks of shape, strides, offsets or the memref source.");



More information about the Mlir-commits mailing list