[Mlir-commits] [mlir] [MLIR][XeGPU] make offsets optional for create_nd_tdesc (PR #148335)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 11 22:55:52 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

This PR makes the offsets of XeGPU's create_nd_tdesc op optional. It is the first PR in the effort to move offsets from create_nd_tdesc to load_nd.
When creating an nd_tdesc for a dynamic-shape tensor, the `shape` and `strides` attributes must be used to describe the base tensor.
```mlir
 %2 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1]  : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
  %2 = xegpu.create_nd_tdesc %src[0, 0] shape : [%h, %w] strides : [%w, %c1]  : ui64 -> !xegpu.tensor_desc<8x16xf32>  
  ```

---
Full diff: https://github.com/llvm/llvm-project/pull/148335.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+35-10) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+242-2) 
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+43-2) 
- (modified) mlir/test/Dialect/XeGPU/subgroup-distribute.mlir (+4-4) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index bd5ea9fd83781..710fc62b032a9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -110,23 +110,27 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     Variadic<Index>: $offsets,
     Variadic<Index>: $shape,
     Variadic<Index>: $strides,
-    DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
     OptionalAttr<DenseI64ArrayAttr>: $const_shape,
     OptionalAttr<DenseI64ArrayAttr>: $const_strides
   );
   let results = (outs XeGPU_TensorDesc: $TensorDesc);
 
-  let assemblyFormat = [{
-    $source ``
-    custom<DynamicIndexList>($offsets, $const_offsets)
-    (`,` custom<DynamicIndexList>($shape, $const_shape)^
-     `,` custom<DynamicIndexList>($strides, $const_strides))?
-    attr-dict `:` type($source) `->` qualified(type($TensorDesc))
-  }];
-
   let hasVerifier = 1;
 
+  let hasCustomAssemblyFormat = 1;
+
   let builders = [
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
+
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
+                   "llvm::ArrayRef<OpFoldResult>": $shape,
+                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+                   "llvm::ArrayRef<OpFoldResult>": $shape,
+                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets)>,
 
@@ -163,9 +167,30 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     }
 
     ArrayRef<int64_t> getStaticOffsets(){
-      return getConstOffsets();
+      auto attr = getConstOffsetsAttr();
+
+      if (attr) 
+        return attr;
+
+      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
+      int rank = 0;
+      if (memrefType) {
+        //use source memref's rank, as source memref rank may be higher
+        rank = memrefType.getRank();
+      } else {
+        //nd_tdesc created from ui64, use nd_tdesc's rank
+        rank = getTensorDescShape().size();
+      };
+
+      // The offsets are allowed to be empty. The Traits verification of OffsetSizeAndStrideOpInterface interface assumes offsets being present.
+      // It is set to be MAX to indicate user not passed any value, instead of kDynamic which means offsets passed as value.
+      setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max()));
+
+      attr = getConstOffsetsAttr();
+      return attr;
     }
 
+
     /// wrapper for matching with OffsetSizeAndStrideOpInterface
     /// If source is IntegerType or `const_shape` is filled,
     /// it will return `const_shape`, such that mixes of `shape`
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ef7cd1424e7a4..9f6090ad279f5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -125,8 +125,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
         ValueRange({}) /* empty dynamic shape */,
         ValueRange({}) /* empty dynamic strides */,
-        staticOffsets /* const offsets */, {} /* empty const shape*/,
-        {} /* empty const strides*/);
+        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
+        {} /* empty const shape*/, {} /* empty const strides*/);
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
@@ -221,6 +221,246 @@ LogicalResult CreateNdDescOp::verify() {
   return success();
 }
 
+ParseResult parseOptionalDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
+    SmallVectorImpl<Type> *valueTypes = nullptr,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+  SmallVector<int64_t, 4> integerVals;
+  SmallVector<bool, 4> scalableVals;
+  auto parseIntegerOrValue = [&]() {
+    OpAsmParser::UnresolvedOperand operand;
+    auto res = parser.parseOptionalOperand(operand);
+
+    // When encountering `[`, assume that this is a scalable index.
+    scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
+
+    if (res.has_value() && succeeded(res.value())) {
+      values.push_back(operand);
+      integerVals.push_back(ShapedType::kDynamic);
+      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+        return failure();
+    } else {
+      int64_t integer;
+      if (failed(parser.parseInteger(integer)))
+        return failure();
+      integerVals.push_back(integer);
+    }
+
+    // If this is assumed to be a scalable index, verify that there's a closing
+    // `]`.
+    if (scalableVals.back() && parser.parseOptionalRSquare().failed())
+      return failure();
+    return success();
+  };
+  if (parser.parseOptionalLSquare().succeeded()) {
+    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
+        parser.parseRSquare())
+      return parser.emitError(parser.getNameLoc())
+             << "expected SSA value or integer";
+    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+    scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
+    return success();
+  }
+  return success();
+}
+
+::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
+                                          ::mlir::OperationState &result) {
+  ::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
+  ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(
+      &sourceRawOperand, 1);
+  ::llvm::SMLoc sourceOperandsLoc;
+
+  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+      offsetsOperands;
+  ::llvm::SMLoc offsetsOperandsLoc;
+  ::mlir::DenseI64ArrayAttr const_offsetsAttr;
+  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
+  ::llvm::SMLoc shapeOperandsLoc;
+  ::mlir::DenseI64ArrayAttr const_shapeAttr;
+  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+      stridesOperands;
+  ::llvm::SMLoc stridesOperandsLoc;
+  ::mlir::DenseI64ArrayAttr const_stridesAttr;
+  ::mlir::Type sourceRawType{};
+  ::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
+  ::mlir::Type TensorDescRawType{};
+  ::llvm::ArrayRef<::mlir::Type> TensorDescTypes(&TensorDescRawType, 1);
+
+  sourceOperandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperand(sourceRawOperand))
+    return ::mlir::failure();
+
+  offsetsOperandsLoc = parser.getCurrentLocation();
+
+  DenseBoolArrayAttr scalableFlags;
+  auto odsResult = parseOptionalDynamicIndexList(
+      parser, offsetsOperands, const_offsetsAttr, scalableFlags);
+
+  if (const_offsetsAttr) {
+    if (odsResult)
+      return ::mlir::failure();
+    result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets =
+        const_offsetsAttr;
+  }
+
+  if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
+    if (parser.parseColon())
+      return ::mlir::failure();
+    {
+      shapeOperandsLoc = parser.getCurrentLocation();
+      auto odsResult =
+          parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
+      if (const_shapeAttr) {
+        if (odsResult)
+          return ::mlir::failure();
+        result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape =
+            const_shapeAttr;
+      }
+    }
+
+    if (parser.parseKeyword("strides"))
+      return ::mlir::failure();
+    if (parser.parseColon())
+      return ::mlir::failure();
+    {
+      stridesOperandsLoc = parser.getCurrentLocation();
+      auto odsResult =
+          parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
+      if (const_stridesAttr) {
+        if (odsResult)
+          return ::mlir::failure();
+        result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides =
+            const_stridesAttr;
+      }
+    }
+  }
+  {
+    auto loc = parser.getCurrentLocation();
+    if (parser.parseOptionalAttrDict(result.attributes))
+      return ::mlir::failure();
+    if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+          return parser.emitError(loc)
+                 << "'" << result.name.getStringRef() << "' op ";
+        })))
+      return ::mlir::failure();
+  }
+  if (parser.parseColon())
+    return ::mlir::failure();
+
+  {
+    ::mlir::Type type;
+    if (parser.parseCustomTypeWithFallback(type))
+      return ::mlir::failure();
+    sourceRawType = type;
+  }
+  if (parser.parseArrow())
+    return ::mlir::failure();
+
+  if (parser.parseType(TensorDescRawType))
+    return ::mlir::failure();
+
+  ::llvm::copy(::llvm::ArrayRef<int32_t>(
+                   {1, static_cast<int32_t>(offsetsOperands.size()),
+                    static_cast<int32_t>(shapeOperands.size()),
+                    static_cast<int32_t>(stridesOperands.size())}),
+               result.getOrAddProperties<CreateNdDescOp::Properties>()
+                   .operandSegmentSizes.begin());
+
+  ::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
+  result.addTypes(TensorDescTypes);
+
+  if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
+                             result.operands))
+    return ::mlir::failure();
+
+  if (parser.resolveOperands(offsetsOperands, odsBuildableType0,
+                             offsetsOperandsLoc, result.operands))
+    return ::mlir::failure();
+
+  if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc,
+                             result.operands))
+    return ::mlir::failure();
+
+  if (parser.resolveOperands(stridesOperands, odsBuildableType0,
+                             stridesOperandsLoc, result.operands))
+    return ::mlir::failure();
+  return ::mlir::success();
+}
+
+void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
+  _odsPrinter << ' ';
+  _odsPrinter << getSource();
+
+  auto constOffsetsAttr = getConstOffsetsAttr();
+  bool printOffsets = false;
+  if (constOffsetsAttr && constOffsetsAttr.size() > 0) {
+    auto firstVal = constOffsetsAttr.asArrayRef()[0];
+    if (firstVal != std::numeric_limits<int64_t>::max()) {
+      printOffsets = true;
+    }
+  }
+  if (printOffsets) {
+
+    printDynamicIndexList(_odsPrinter, *this, getOffsets(),
+                          getConstOffsetsAttr());
+  }
+  if (((!getShape().empty()) || (getConstShapeAttr()))) {
+    _odsPrinter << ' ' << "shape";
+    _odsPrinter << ' ' << ":";
+    _odsPrinter << ' ';
+    printDynamicIndexList(_odsPrinter, *this, getShape(), getConstShapeAttr());
+    _odsPrinter << ' ' << "strides";
+    _odsPrinter << ' ' << ":";
+    _odsPrinter << ' ';
+    printDynamicIndexList(_odsPrinter, *this, getStrides(),
+                          getConstStridesAttr());
+  }
+  ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
+  elidedAttrs.push_back("operandSegmentSizes");
+  elidedAttrs.push_back("const_offsets");
+  elidedAttrs.push_back("const_shape");
+  elidedAttrs.push_back("const_strides");
+  _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+  _odsPrinter << ' ' << ":";
+  _odsPrinter << ' ';
+  {
+    auto type = getSource().getType();
+    if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type))
+      _odsPrinter.printStrippedAttrOrType(validType);
+    else
+      _odsPrinter << type;
+  }
+  _odsPrinter << ' ' << "->";
+  _odsPrinter << ' ';
+  // _odsPrinter << getTensorDesc().getType();
+
+  _odsPrinter << "!xegpu.tensor_desc<";
+
+  auto tDesc = getTensorDesc().getType();
+  auto shape = tDesc.getShape();
+  for (int64_t dim : shape) {
+    if (mlir::ShapedType::isDynamic(dim))
+      _odsPrinter << '?';
+    else
+      _odsPrinter << dim;
+    _odsPrinter << 'x';
+  }
+
+  _odsPrinter << tDesc.getElementType();
+
+  if (auto encoding = tDesc.getEncoding())
+    _odsPrinter << ", " << encoding;
+
+  if (auto layout = tDesc.getLayout())
+    _odsPrinter << ", " << layout;
+
+  _odsPrinter << ">";
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3bfe1fa81aa6e..0d679e519ed60 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -17,8 +17,8 @@ gpu.func @create_nd_tdesc_1(%src: memref<24x32xf32>) {
 gpu.func @create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index
   %c1 = arith.constant 1 : index
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
-  %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]],  %[[arg4]]] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides: [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
   gpu.return
 }
 
@@ -62,6 +62,47 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
 }
 
 
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>) 
+gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
+  //CHECK: %[[C:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ 
+  gpu.return
+}
+
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) 
+gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+  
+  %c1 = arith.constant 1 : index   
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+  %2 = xegpu.create_nd_tdesc %src shape : [%h, %w] strides : [%w, %c1]  : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ 
+  gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}}) 
+
+gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[%arg3, %arg4] shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1]  : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+  gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}}) 
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {  
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16> 
+  %2 = xegpu.create_nd_tdesc %src shape:[%h, %w] strides:[%w, %c1]  : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+  gpu.return
+}
+
 // CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..ba29d1ab13cae 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -150,16 +150,16 @@ gpu.module @test {
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
 // CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
 // CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
 gpu.module @test {
   gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
     %c0 = arith.constant 0 : index
-    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
-    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/148335


More information about the Mlir-commits mailing list