[Mlir-commits] [mlir] [MLIR][XeGPU] make offsets optional for create_nd_tdesc (PR #148335)
Jianhui Li
llvmlistbot at llvm.org
Fri Jul 11 22:55:23 PDT 2025
https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/148335
This PR makes the offsets optional for xegpu.create_nd_tdesc. This is the initial PR in a series that moves offsets from create_nd_tdesc to load_nd.
When creating an nd_tdesc for a tensor with a dynamic shape, the `shape` and `strides` keywords must be used to describe the base tensor.
```mlir
%2 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.create_nd_tdesc %src[0, 0] shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
```
>From b9a6d984765445fd17f257f936fe61a1cc94dab1 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 1 Jul 2025 23:00:42 +0000
Subject: [PATCH 1/7] init code
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 41 ++-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 261 +++++++++++++++++-
mlir/test/Dialect/XeGPU/ops.mlir | 12 +-
.../Dialect/XeGPU/subgroup-distribute.mlir | 8 +-
4 files changed, 306 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index daab65ec893b8..018c187f642d6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -110,23 +110,36 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
Variadic<Index>: $offsets,
Variadic<Index>: $shape,
Variadic<Index>: $strides,
- DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_shape,
OptionalAttr<DenseI64ArrayAttr>: $const_strides
);
let results = (outs XeGPU_TensorDesc: $TensorDesc);
- let assemblyFormat = [{
- $source ``
- custom<DynamicIndexList>($offsets, $const_offsets)
- (`,` custom<DynamicIndexList>($shape, $const_shape)^
- `,` custom<DynamicIndexList>($strides, $const_strides))?
- attr-dict `:` type($source) `->` qualified(type($TensorDesc))
- }];
+
+// let assemblyFormat = [{
+// $source
+// (custom<DynamicIndexList>($offsets, $const_offsets)^)?
+// (`base_shape` `:` custom<DynamicIndexList>($shape, $const_shape)^
+// `base_strides` `:` custom<DynamicIndexList>($strides, $const_strides))?
+// attr-dict `:` type($source) `->` qualified(type($TensorDesc))
+// }];
let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+
let builders = [
+ OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
+
+ OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
+ "llvm::ArrayRef<OpFoldResult>": $shape,
+ "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
+ OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+ "llvm::ArrayRef<OpFoldResult>": $shape,
+ "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
@@ -163,9 +176,19 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
}
ArrayRef<int64_t> getStaticOffsets(){
- return getConstOffsets();
+ auto attr = getConstOffsetsAttr();
+ if (llvm::isa<IntegerType>(getSourceType()) || attr)
+ return attr;
+
+ // The offsets are allowed to be empty. The Traits verification of OffsetSizeAndStrideOpInterface interface assumes offsets being present. So it is set to be MAX to indicate user not passed any value (kDynamic means offsets passed as variable).
+ setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), std::numeric_limits<int64_t>::max()));
+ // setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), 0));
+ //setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), mlir::ShapedType::kDynamic));
+ attr = getConstOffsetsAttr();
+ return attr;
}
+
/// wrapper for matching with OffsetSizeAndStrideOpInterface
/// If source is IntegerType or `const_shape` is filled,
/// it will return `const_shape`, such that mixes of `shape`
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2793c7a35bc97..13ef77bb4f970 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -122,7 +122,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
ValueRange({}) /* empty dynamic shape */,
ValueRange({}) /* empty dynamic strides */,
- staticOffsets /* const offsets */, {} /* empty const shape*/,
+ builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
+ {} /* empty const shape*/,
{} /* empty const strides*/);
}
@@ -220,6 +221,263 @@ LogicalResult CreateNdDescOp::verify() {
return success();
}
+
+ParseResult parseOptionalDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
+ SmallVectorImpl<Type> *valueTypes = nullptr,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline ParseResult parseOptionalDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+ DenseBoolArrayAttr scalableFlags;
+ return parseOptionalDynamicIndexList(parser, values, integers, scalableFlags,
+ valueTypes, delimiter);
+}
+
+ParseResult parseOptionalDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
+ SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
+
+ SmallVector<int64_t, 4> integerVals;
+ SmallVector<bool, 4> scalableVals;
+ auto parseIntegerOrValue = [&]() {
+ OpAsmParser::UnresolvedOperand operand;
+ auto res = parser.parseOptionalOperand(operand);
+
+ // When encountering `[`, assume that this is a scalable index.
+ scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
+
+ if (res.has_value() && succeeded(res.value())) {
+ values.push_back(operand);
+ integerVals.push_back(ShapedType::kDynamic);
+ if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+ return failure();
+ } else {
+ int64_t integer;
+ if (failed(parser.parseInteger(integer)))
+ return failure();
+ integerVals.push_back(integer);
+ }
+
+ // If this is assumed to be a scalable index, verify that there's a closing
+ // `]`.
+ if (scalableVals.back() && parser.parseOptionalRSquare().failed())
+ return failure();
+ return success();
+ };
+ if (parser.parseOptionalLSquare().succeeded()) {
+ if ( parser.parseCommaSeparatedList(parseIntegerOrValue) || parser.parseRSquare() )
+ return parser.emitError(parser.getNameLoc())
+ << "expected SSA value or integer";
+ integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+ scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
+ return success();
+ }
+ return success();
+}
+
+
+::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) {
+ ::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
+ ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(&sourceRawOperand, 1); ::llvm::SMLoc sourceOperandsLoc;
+ (void)sourceOperandsLoc;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> offsetsOperands;
+ ::llvm::SMLoc offsetsOperandsLoc;
+ (void)offsetsOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_offsetsAttr;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
+ ::llvm::SMLoc shapeOperandsLoc;
+ (void)shapeOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_shapeAttr;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> stridesOperands;
+ ::llvm::SMLoc stridesOperandsLoc;
+ (void)stridesOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_stridesAttr;
+ ::mlir::Type sourceRawType{};
+ ::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
+ ::mlir::Type TensorDescRawType{};
+ ::llvm::ArrayRef<::mlir::Type> TensorDescTypes(&TensorDescRawType, 1);
+
+ sourceOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperand(sourceRawOperand))
+ return ::mlir::failure();
+
+ auto optionalOffsetResult = [&]() -> ::mlir::OptionalParseResult {
+ {
+ // skip the "offsets :" at the begining if it exists
+ if (::mlir::succeeded(parser.parseOptionalKeyword("offsets"))) {
+ if (parser.parseColon())
+ return ::mlir::failure();
+ }
+ offsetsOperandsLoc = parser.getCurrentLocation();
+ auto odsResult = parseOptionalDynamicIndexList(parser, offsetsOperands, const_offsetsAttr);
+ // Debug print for offsets parsing using LLVM_DEBUG
+ LLVM_DEBUG(llvm::dbgs() << "parseOptionalDynamicIndexList returned: " << (odsResult ? "failure" : "success") << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "offsetsOperands size: " << offsetsOperands.size() << "\n");
+ if (const_offsetsAttr)
+ LLVM_DEBUG(llvm::dbgs() << "const_offsetsAttr: " << const_offsetsAttr << "\n");
+ if (const_offsetsAttr)
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets = const_offsetsAttr;
+ }
+ return ::mlir::success();
+ }();
+
+ if (optionalOffsetResult.has_value() && ::mlir::failed(*optionalOffsetResult)) {
+ LLVM_DEBUG(llvm::dbgs() << "optionalOffsetResult failed\n");
+ return ::mlir::failure();
+ }
+
+ if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
+ LLVM_DEBUG(llvm::dbgs() << "Parsing 'shape' keyword\n");
+ if (parser.parseColon())
+ return ::mlir::failure();
+ {
+ shapeOperandsLoc = parser.getCurrentLocation();
+ auto odsResult = parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
+ LLVM_DEBUG(llvm::dbgs() << "parseDynamicIndexList for shape returned: " << (odsResult ? "failure" : "success") << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "shapeOperands size: " << shapeOperands.size() << "\n");
+ if (const_shapeAttr)
+ LLVM_DEBUG(llvm::dbgs() << "const_shapeAttr: " << const_shapeAttr << "\n");
+ if (odsResult) return ::mlir::failure();
+ if (const_shapeAttr)
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape = const_shapeAttr;
+ }
+
+ if (parser.parseKeyword("strides"))
+ return ::mlir::failure();
+ if (parser.parseColon())
+ return ::mlir::failure();
+ {
+ stridesOperandsLoc = parser.getCurrentLocation();
+ auto odsResult = parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
+ LLVM_DEBUG(llvm::dbgs() << "parseDynamicIndexList for strides returned: " << (odsResult ? "failure" : "success") << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "stridesOperands size: " << stridesOperands.size() << "\n");
+ if (const_stridesAttr)
+ LLVM_DEBUG(llvm::dbgs() << "const_stridesAttr: " << const_stridesAttr << "\n");
+ if (odsResult) return ::mlir::failure();
+ if (const_stridesAttr)
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides = const_stridesAttr;
+ }
+ }
+ {
+ auto loc = parser.getCurrentLocation();(void)loc;
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return ::mlir::failure();
+ if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+ return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
+ })))
+ return ::mlir::failure();
+ }
+ if (parser.parseColon())
+ return ::mlir::failure();
+
+ {
+ ::mlir::Type type;
+ if (parser.parseCustomTypeWithFallback(type))
+ return ::mlir::failure();
+ sourceRawType = type;
+ }
+ if (parser.parseArrow())
+ return ::mlir::failure();
+
+ if (parser.parseType(TensorDescRawType))
+ return ::mlir::failure();
+
+ ::llvm::copy(::llvm::ArrayRef<int32_t>({1, static_cast<int32_t>(offsetsOperands.size()), static_cast<int32_t>(shapeOperands.size()), static_cast<int32_t>(stridesOperands.size())}), result.getOrAddProperties<CreateNdDescOp::Properties>().operandSegmentSizes.begin());
+
+ ::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
+ result.addTypes(TensorDescTypes);
+
+ if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc, result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(offsetsOperands, odsBuildableType0, offsetsOperandsLoc, result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc, result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(stridesOperands, odsBuildableType0, stridesOperandsLoc, result.operands))
+ return ::mlir::failure();
+ return ::mlir::success();
+}
+
+void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
+ _odsPrinter << ' ';
+ _odsPrinter << getSource();
+ // Print offsets if getConstOffsetsAttr() exists, is not empty, and its first value is not int64_t::max.
+ auto constOffsetsAttr = getConstOffsetsAttr();
+ bool printOffsets = false;
+ if (constOffsetsAttr && constOffsetsAttr.size() > 0) {
+ auto firstVal = constOffsetsAttr.asArrayRef()[0];
+ if (firstVal != std::numeric_limits<int64_t>::max()) {
+ printOffsets = true;
+ }
+ }
+ if (printOffsets) {
+
+ printDynamicIndexList(_odsPrinter, *this, getOffsets(), getConstOffsetsAttr());
+ }
+ if (((!getShape().empty()) || (getConstShapeAttr()))) {
+ _odsPrinter << ' ' << "shape";
+ _odsPrinter << ' ' << ":";
+ _odsPrinter << ' ';
+ printDynamicIndexList(_odsPrinter, *this, getShape(), getConstShapeAttr());
+ _odsPrinter << ' ' << "strides";
+ _odsPrinter << ' ' << ":";
+ _odsPrinter << ' ';
+ printDynamicIndexList(_odsPrinter, *this, getStrides(), getConstStridesAttr());
+ }
+ ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
+ elidedAttrs.push_back("operandSegmentSizes");
+ elidedAttrs.push_back("const_offsets");
+ elidedAttrs.push_back("const_shape");
+ elidedAttrs.push_back("const_strides");
+ _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+ _odsPrinter << ' ' << ":";
+ _odsPrinter << ' ';
+ {
+ auto type = getSource().getType();
+ if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type))
+ _odsPrinter.printStrippedAttrOrType(validType);
+ else
+ _odsPrinter << type;
+ }
+ _odsPrinter << ' ' << "->";
+ _odsPrinter << ' ';
+ // _odsPrinter << getTensorDesc().getType();
+
+
+ _odsPrinter << "!xegpu.tensor_desc<";
+
+ auto tDesc = getTensorDesc().getType();
+ auto shape = tDesc.getShape();
+ for (int64_t dim : shape) {
+ if (mlir::ShapedType::isDynamic(dim))
+ _odsPrinter << '?';
+ else
+ _odsPrinter << dim;
+ _odsPrinter << 'x';
+ }
+
+ _odsPrinter << tDesc.getElementType();
+
+ if (auto encoding = tDesc.getEncoding())
+ _odsPrinter << ", " << encoding;
+
+ if (auto layout = tDesc.getLayout())
+ _odsPrinter << ", " << layout;
+
+ _odsPrinter << ">";
+
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
@@ -635,6 +893,7 @@ LogicalResult ConvertLayoutOp::verify() {
return mlir::success();
}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index aff8f63adc05b..e8836b7cffbc7 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -17,8 +17,8 @@ gpu.func @create_nd_tdesc_1(%src: memref<24x32xf32>) {
gpu.func @create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
- %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides: [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
gpu.return
}
@@ -54,6 +54,14 @@ gpu.func @create_nd_tdesc_6(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
+gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+ //CHECK: %[[C:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src offsets : [%x, %y] shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ gpu.return
+}
// CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..ba29d1ab13cae 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -150,16 +150,16 @@ gpu.module @test {
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
// CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
gpu.module @test {
gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
%c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%1 = xegpu.load_nd %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
- %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
>From 2465050985b4ad0b1073a8cb36b0a462d542d3ae Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 2 Jul 2025 02:08:16 +0000
Subject: [PATCH 2/7] add tests
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 98 ++++++++-----------
mlir/test/Dialect/XeGPU/ops.mlir | 14 ++-
3 files changed, 52 insertions(+), 62 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 018c187f642d6..2cbae19ff2c05 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -182,8 +182,8 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
// The offsets are allowed to be empty. The Traits verification of OffsetSizeAndStrideOpInterface interface assumes offsets being present. So it is set to be MAX to indicate user not passed any value (kDynamic means offsets passed as variable).
setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), std::numeric_limits<int64_t>::max()));
- // setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), 0));
//setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), mlir::ShapedType::kDynamic));
+
attr = getConstOffsetsAttr();
return attr;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 13ef77bb4f970..cab4ca8a73898 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -136,8 +136,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
shape.size() == strides.size() && shape.size() == offsets.size());
Type srcTy = source.getType();
- assert(isa<IntegerType>(srcTy) ||
- isa<MemRefType>(srcTy) && "Source has to be either int or memref.");
+ assert((isa<IntegerType>(srcTy) ||
+ isa<MemRefType>(srcTy)) && "Source has to be either int or memref.");
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<Value> dynamicShape;
@@ -222,27 +222,30 @@ LogicalResult CreateNdDescOp::verify() {
}
+//ParseResult parseOptionalDynamicIndexList(
+// OpAsmParser &parser,
+// SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+// DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
+// SmallVectorImpl<Type> *valueTypes = nullptr,
+// AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+//inline ParseResult parseOptionalDynamicIndexList(
+// OpAsmParser &parser,
+// SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+// DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
+// AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+// DenseBoolArrayAttr scalableFlags;
+// return parseOptionalDynamicIndexList(parser, values, integers, scalableFlags,
+// valueTypes, delimiter);
+//}
+
+
+
ParseResult parseOptionalDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
SmallVectorImpl<Type> *valueTypes = nullptr,
- AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
-inline ParseResult parseOptionalDynamicIndexList(
- OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
- DenseBoolArrayAttr scalableFlags;
- return parseOptionalDynamicIndexList(parser, values, integers, scalableFlags,
- valueTypes, delimiter);
-}
-
-ParseResult parseOptionalDynamicIndexList(
- OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
- SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
SmallVector<int64_t, 4> integerVals;
SmallVector<bool, 4> scalableVals;
@@ -286,18 +289,15 @@ ParseResult parseOptionalDynamicIndexList(
::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) {
::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(&sourceRawOperand, 1); ::llvm::SMLoc sourceOperandsLoc;
- (void)sourceOperandsLoc;
+
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> offsetsOperands;
::llvm::SMLoc offsetsOperandsLoc;
- (void)offsetsOperandsLoc;
::mlir::DenseI64ArrayAttr const_offsetsAttr;
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
::llvm::SMLoc shapeOperandsLoc;
- (void)shapeOperandsLoc;
::mlir::DenseI64ArrayAttr const_shapeAttr;
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> stridesOperands;
::llvm::SMLoc stridesOperandsLoc;
- (void)stridesOperandsLoc;
::mlir::DenseI64ArrayAttr const_stridesAttr;
::mlir::Type sourceRawType{};
::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
@@ -308,45 +308,32 @@ ::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser, ::mlir::O
if (parser.parseOperand(sourceRawOperand))
return ::mlir::failure();
- auto optionalOffsetResult = [&]() -> ::mlir::OptionalParseResult {
- {
- // skip the "offsets :" at the begining if it exists
- if (::mlir::succeeded(parser.parseOptionalKeyword("offsets"))) {
- if (parser.parseColon())
- return ::mlir::failure();
- }
- offsetsOperandsLoc = parser.getCurrentLocation();
- auto odsResult = parseOptionalDynamicIndexList(parser, offsetsOperands, const_offsetsAttr);
- // Debug print for offsets parsing using LLVM_DEBUG
- LLVM_DEBUG(llvm::dbgs() << "parseOptionalDynamicIndexList returned: " << (odsResult ? "failure" : "success") << "\n");
- LLVM_DEBUG(llvm::dbgs() << "offsetsOperands size: " << offsetsOperands.size() << "\n");
- if (const_offsetsAttr)
- LLVM_DEBUG(llvm::dbgs() << "const_offsetsAttr: " << const_offsetsAttr << "\n");
- if (const_offsetsAttr)
- result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets = const_offsetsAttr;
- }
- return ::mlir::success();
- }();
- if (optionalOffsetResult.has_value() && ::mlir::failed(*optionalOffsetResult)) {
- LLVM_DEBUG(llvm::dbgs() << "optionalOffsetResult failed\n");
- return ::mlir::failure();
- }
+ // skip the "offsets :" at the begining if it exists
+ //if (::mlir::succeeded(parser.parseOptionalKeyword("offsets"))) {
+ // if (parser.parseColon())
+ // return ::mlir::failure();
+ //}
+ offsetsOperandsLoc = parser.getCurrentLocation();
+
+ DenseBoolArrayAttr scalableFlags;
+ auto odsResult = parseOptionalDynamicIndexList(parser, offsetsOperands, const_offsetsAttr, scalableFlags);
+
+ if (const_offsetsAttr) {
+ if (odsResult) return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets = const_offsetsAttr;
+ }
if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
- LLVM_DEBUG(llvm::dbgs() << "Parsing 'shape' keyword\n");
if (parser.parseColon())
return ::mlir::failure();
{
shapeOperandsLoc = parser.getCurrentLocation();
auto odsResult = parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
- LLVM_DEBUG(llvm::dbgs() << "parseDynamicIndexList for shape returned: " << (odsResult ? "failure" : "success") << "\n");
- LLVM_DEBUG(llvm::dbgs() << "shapeOperands size: " << shapeOperands.size() << "\n");
- if (const_shapeAttr)
- LLVM_DEBUG(llvm::dbgs() << "const_shapeAttr: " << const_shapeAttr << "\n");
- if (odsResult) return ::mlir::failure();
- if (const_shapeAttr)
+ if (const_shapeAttr) {
+ if (odsResult) return ::mlir::failure();
result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape = const_shapeAttr;
+ }
}
if (parser.parseKeyword("strides"))
@@ -356,13 +343,10 @@ ::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser, ::mlir::O
{
stridesOperandsLoc = parser.getCurrentLocation();
auto odsResult = parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
- LLVM_DEBUG(llvm::dbgs() << "parseDynamicIndexList for strides returned: " << (odsResult ? "failure" : "success") << "\n");
- LLVM_DEBUG(llvm::dbgs() << "stridesOperands size: " << stridesOperands.size() << "\n");
- if (const_stridesAttr)
- LLVM_DEBUG(llvm::dbgs() << "const_stridesAttr: " << const_stridesAttr << "\n");
- if (odsResult) return ::mlir::failure();
- if (const_stridesAttr)
+ if (const_stridesAttr) {
+ if (odsResult) return ::mlir::failure();
result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides = const_stridesAttr;
+ }
}
}
{
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index e8836b7cffbc7..d5a01e4c66b5e 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -54,12 +54,18 @@ gpu.func @create_nd_tdesc_6(%src: memref<24x32xf32>) {
gpu.return
}
-// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
-gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
+gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
- %1 = xegpu.create_nd_tdesc %src offsets : [%x, %y] shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
gpu.return
}
>From 107787193ebe82524ec5231d3c013d08d1532040 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 2 Jul 2025 02:13:20 +0000
Subject: [PATCH 3/7] git-clang-format
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 136 +++++++++++++------------
1 file changed, 70 insertions(+), 66 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index cab4ca8a73898..e6590c2ed53fa 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -122,9 +122,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
ValueRange({}) /* empty dynamic shape */,
ValueRange({}) /* empty dynamic strides */,
- builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
- {} /* empty const shape*/,
- {} /* empty const strides*/);
+ builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
+ {} /* empty const shape*/, {} /* empty const strides*/);
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
@@ -136,8 +135,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
shape.size() == strides.size() && shape.size() == offsets.size());
Type srcTy = source.getType();
- assert((isa<IntegerType>(srcTy) ||
- isa<MemRefType>(srcTy)) && "Source has to be either int or memref.");
+ assert((isa<IntegerType>(srcTy) || isa<MemRefType>(srcTy)) &&
+ "Source has to be either int or memref.");
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<Value> dynamicShape;
@@ -222,24 +221,6 @@ LogicalResult CreateNdDescOp::verify() {
}
-//ParseResult parseOptionalDynamicIndexList(
-// OpAsmParser &parser,
-// SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-// DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
-// SmallVectorImpl<Type> *valueTypes = nullptr,
-// AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
-//inline ParseResult parseOptionalDynamicIndexList(
-// OpAsmParser &parser,
-// SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-// DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
-// AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
-// DenseBoolArrayAttr scalableFlags;
-// return parseOptionalDynamicIndexList(parser, values, integers, scalableFlags,
-// valueTypes, delimiter);
-//}
-
-
-
ParseResult parseOptionalDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
@@ -275,9 +256,10 @@ ParseResult parseOptionalDynamicIndexList(
return success();
};
if (parser.parseOptionalLSquare().succeeded()) {
- if ( parser.parseCommaSeparatedList(parseIntegerOrValue) || parser.parseRSquare() )
+ if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
+ parser.parseRSquare())
return parser.emitError(parser.getNameLoc())
- << "expected SSA value or integer";
+ << "expected SSA value or integer";
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
return success();
@@ -285,18 +267,22 @@ ParseResult parseOptionalDynamicIndexList(
return success();
}
-
-::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result) {
+::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
- ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(&sourceRawOperand, 1); ::llvm::SMLoc sourceOperandsLoc;
+ ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(
+ &sourceRawOperand, 1);
+ ::llvm::SMLoc sourceOperandsLoc;
- ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> offsetsOperands;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+ offsetsOperands;
::llvm::SMLoc offsetsOperandsLoc;
::mlir::DenseI64ArrayAttr const_offsetsAttr;
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
::llvm::SMLoc shapeOperandsLoc;
::mlir::DenseI64ArrayAttr const_shapeAttr;
- ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> stridesOperands;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+ stridesOperands;
::llvm::SMLoc stridesOperandsLoc;
::mlir::DenseI64ArrayAttr const_stridesAttr;
::mlir::Type sourceRawType{};
@@ -308,31 +294,36 @@ ::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser, ::mlir::O
if (parser.parseOperand(sourceRawOperand))
return ::mlir::failure();
+ // skip the "offsets :" at the beginning if it exists
+ // if (::mlir::succeeded(parser.parseOptionalKeyword("offsets"))) {
+ // if (parser.parseColon())
+ // return ::mlir::failure();
+ //}
+ offsetsOperandsLoc = parser.getCurrentLocation();
- // skip the "offsets :" at the begining if it exists
- //if (::mlir::succeeded(parser.parseOptionalKeyword("offsets"))) {
- // if (parser.parseColon())
- // return ::mlir::failure();
- //}
- offsetsOperandsLoc = parser.getCurrentLocation();
-
- DenseBoolArrayAttr scalableFlags;
- auto odsResult = parseOptionalDynamicIndexList(parser, offsetsOperands, const_offsetsAttr, scalableFlags);
+ DenseBoolArrayAttr scalableFlags;
+ auto odsResult = parseOptionalDynamicIndexList(
+ parser, offsetsOperands, const_offsetsAttr, scalableFlags);
- if (const_offsetsAttr) {
- if (odsResult) return ::mlir::failure();
- result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets = const_offsetsAttr;
- }
+ if (const_offsetsAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets =
+ const_offsetsAttr;
+ }
if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
if (parser.parseColon())
return ::mlir::failure();
{
shapeOperandsLoc = parser.getCurrentLocation();
- auto odsResult = parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
+ auto odsResult =
+ parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
if (const_shapeAttr) {
- if (odsResult) return ::mlir::failure();
- result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape = const_shapeAttr;
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape =
+ const_shapeAttr;
}
}
@@ -342,20 +333,24 @@ ::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser, ::mlir::O
return ::mlir::failure();
{
stridesOperandsLoc = parser.getCurrentLocation();
- auto odsResult = parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
+ auto odsResult =
+ parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
if (const_stridesAttr) {
- if (odsResult) return ::mlir::failure();
- result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides = const_stridesAttr;
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides =
+ const_stridesAttr;
}
}
}
{
- auto loc = parser.getCurrentLocation();(void)loc;
+ auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDict(result.attributes))
return ::mlir::failure();
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
- return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
- })))
+ return parser.emitError(loc)
+ << "'" << result.name.getStringRef() << "' op ";
+ })))
return ::mlir::failure();
}
if (parser.parseColon())
@@ -373,21 +368,30 @@ ::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser, ::mlir::O
if (parser.parseType(TensorDescRawType))
return ::mlir::failure();
- ::llvm::copy(::llvm::ArrayRef<int32_t>({1, static_cast<int32_t>(offsetsOperands.size()), static_cast<int32_t>(shapeOperands.size()), static_cast<int32_t>(stridesOperands.size())}), result.getOrAddProperties<CreateNdDescOp::Properties>().operandSegmentSizes.begin());
+ ::llvm::copy(::llvm::ArrayRef<int32_t>(
+ {1, static_cast<int32_t>(offsetsOperands.size()),
+ static_cast<int32_t>(shapeOperands.size()),
+ static_cast<int32_t>(stridesOperands.size())}),
+ result.getOrAddProperties<CreateNdDescOp::Properties>()
+ .operandSegmentSizes.begin());
::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
result.addTypes(TensorDescTypes);
- if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc, result.operands))
+ if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
+ result.operands))
return ::mlir::failure();
- if (parser.resolveOperands(offsetsOperands, odsBuildableType0, offsetsOperandsLoc, result.operands))
+ if (parser.resolveOperands(offsetsOperands, odsBuildableType0,
+ offsetsOperandsLoc, result.operands))
return ::mlir::failure();
- if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc, result.operands))
+ if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc,
+ result.operands))
return ::mlir::failure();
- if (parser.resolveOperands(stridesOperands, odsBuildableType0, stridesOperandsLoc, result.operands))
+ if (parser.resolveOperands(stridesOperands, odsBuildableType0,
+ stridesOperandsLoc, result.operands))
return ::mlir::failure();
return ::mlir::success();
}
@@ -395,7 +399,8 @@ ::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser, ::mlir::O
void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
_odsPrinter << ' ';
_odsPrinter << getSource();
- // Print offsets if getConstOffsetsAttr() exists, is not empty, and its first value is not int64_t::max.
+ // Print offsets if getConstOffsetsAttr() exists, is not empty, and its first
+ // value is not int64_t::max.
auto constOffsetsAttr = getConstOffsetsAttr();
bool printOffsets = false;
if (constOffsetsAttr && constOffsetsAttr.size() > 0) {
@@ -406,7 +411,8 @@ void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
}
if (printOffsets) {
- printDynamicIndexList(_odsPrinter, *this, getOffsets(), getConstOffsetsAttr());
+ printDynamicIndexList(_odsPrinter, *this, getOffsets(),
+ getConstOffsetsAttr());
}
if (((!getShape().empty()) || (getConstShapeAttr()))) {
_odsPrinter << ' ' << "shape";
@@ -416,7 +422,8 @@ void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
_odsPrinter << ' ' << "strides";
_odsPrinter << ' ' << ":";
_odsPrinter << ' ';
- printDynamicIndexList(_odsPrinter, *this, getStrides(), getConstStridesAttr());
+ printDynamicIndexList(_odsPrinter, *this, getStrides(),
+ getConstStridesAttr());
}
::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
elidedAttrs.push_back("operandSegmentSizes");
@@ -430,17 +437,16 @@ void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
auto type = getSource().getType();
if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type))
_odsPrinter.printStrippedAttrOrType(validType);
- else
- _odsPrinter << type;
+ else
+ _odsPrinter << type;
}
_odsPrinter << ' ' << "->";
_odsPrinter << ' ';
// _odsPrinter << getTensorDesc().getType();
-
_odsPrinter << "!xegpu.tensor_desc<";
- auto tDesc = getTensorDesc().getType();
+ auto tDesc = getTensorDesc().getType();
auto shape = tDesc.getShape();
for (int64_t dim : shape) {
if (mlir::ShapedType::isDynamic(dim))
@@ -459,7 +465,6 @@ void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
_odsPrinter << ", " << layout;
_odsPrinter << ">";
-
}
//===----------------------------------------------------------------------===//
@@ -877,7 +882,6 @@ LogicalResult ConvertLayoutOp::verify() {
return mlir::success();
}
-
} // namespace xegpu
} // namespace mlir
>From 42baa22915a12f680b1aba6b43a6acf10e0009ad Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 2 Jul 2025 03:03:13 +0000
Subject: [PATCH 4/7] add more tests
---
mlir/test/Dialect/XeGPU/ops.mlir | 29 ++++++++++++++++++++++++++++-
1 file changed, 28 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index d5a01e4c66b5e..d746de69c4f8f 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -63,8 +63,35 @@ gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index,
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ gpu.return
+}
+
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
+gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] shape : [%arg2, %arg1] strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%2 = xegpu.create_nd_tdesc %src[0, 0] shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+ gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
+
+gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[%arg3, %arg4] shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+ gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %2 = xegpu.create_nd_tdesc %src shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
gpu.return
}
>From 204d34781cc18dbd19a640afe024245afe0c9684 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 2 Jul 2025 03:04:09 +0000
Subject: [PATCH 5/7] git-clang-format
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e6590c2ed53fa..ba788f3454d25 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -220,7 +220,6 @@ LogicalResult CreateNdDescOp::verify() {
return success();
}
-
ParseResult parseOptionalDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
>From 2793c8130b7379987f6ea451c4fc3dcd7e8a34b4 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 12 Jul 2025 05:07:52 +0000
Subject: [PATCH 6/7] add ui64 case support
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 28 ++++++++++---------
1 file changed, 15 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 2cbae19ff2c05..86c9d40575104 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -116,15 +116,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
);
let results = (outs XeGPU_TensorDesc: $TensorDesc);
-
-// let assemblyFormat = [{
-// $source
-// (custom<DynamicIndexList>($offsets, $const_offsets)^)?
-// (`base_shape` `:` custom<DynamicIndexList>($shape, $const_shape)^
-// `base_strides` `:` custom<DynamicIndexList>($strides, $const_strides))?
-// attr-dict `:` type($source) `->` qualified(type($TensorDesc))
-// }];
-
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
@@ -177,12 +168,23 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
ArrayRef<int64_t> getStaticOffsets(){
auto attr = getConstOffsetsAttr();
- if (llvm::isa<IntegerType>(getSourceType()) || attr)
+
+ if (attr)
return attr;
- // The offsets are allowed to be empty. The Traits verification of OffsetSizeAndStrideOpInterface interface assumes offsets being present. So it is set to be MAX to indicate user not passed any value (kDynamic means offsets passed as variable).
- setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), std::numeric_limits<int64_t>::max()));
- //setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), mlir::ShapedType::kDynamic));
+ auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
+ int rank = 0;
+ if (memrefType) {
+ // use the source memref's rank, as the source memref's rank may be higher
+ rank = memrefType.getRank();
+ } else {
+ // nd_tdesc created from ui64, so use the nd_tdesc's rank
+ rank = getTensorDescShape().size();
+ };
+
+ // The offsets are allowed to be empty. The Traits verification of the OffsetSizeAndStrideOpInterface assumes offsets are present,
+ // so they are set to int64_t MAX to indicate that the user did not pass any value (kDynamic would instead mean offsets were passed as SSA values).
+ setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max()));
attr = getConstOffsetsAttr();
return attr;
>From 6793689a36bf58b07fdbee24b92f1fe0fb56cff2 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 12 Jul 2025 05:49:58 +0000
Subject: [PATCH 7/7] remove unnecessary comments
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 8 +-------
1 file changed, 1 insertion(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 62db7bd858d78..9f6090ad279f5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -294,11 +294,6 @@ ::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
if (parser.parseOperand(sourceRawOperand))
return ::mlir::failure();
- // skip the "offsets :" at the begining if it exists
- // if (::mlir::succeeded(parser.parseOptionalKeyword("offsets"))) {
- // if (parser.parseColon())
- // return ::mlir::failure();
- //}
offsetsOperandsLoc = parser.getCurrentLocation();
DenseBoolArrayAttr scalableFlags;
@@ -399,8 +394,7 @@ ::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
_odsPrinter << ' ';
_odsPrinter << getSource();
- // Print offsets if getConstOffsetsAttr() exists, is not empty, and its first
- // value is not int64_t::max.
+
auto constOffsetsAttr = getConstOffsetsAttr();
bool printOffsets = false;
if (constOffsetsAttr && constOffsetsAttr.size() > 0) {
More information about the Mlir-commits
mailing list