[Mlir-commits] [mlir] [MLIR][XeGPU] make offsets optional for create_nd_tdesc (PR #148335)
llvmlistbot at llvm.org
Fri Jul 11 22:55:52 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Jianhui Li (Jianhui-Li)
Changes:
This PR makes the offsets optional when building xegpu.create_nd_tdesc. It is the initial PR in a series that moves the offsets from create_nd_tdesc to load_nd.
When creating an nd_tdesc for a tensor with a dynamic shape, the `shape` and `strides` arguments must be used to describe the base tensor:
```mlir
%2 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.create_nd_tdesc %src[0, 0] shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
```
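With this change the offsets can also be omitted entirely. For example (forms taken from the updated ops.mlir tests in this diff, where %src2 is a memref<24x32xf32>, %src is a ui64 base address, and %h, %w, %c1 are index values):
```mlir
// Static-shape memref source: no offsets needed at all.
%3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// ui64 source without offsets: shape and strides are still required.
%2 = xegpu.create_nd_tdesc %src shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
```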
---
Full diff: https://github.com/llvm/llvm-project/pull/148335.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+35-10)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+242-2)
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+43-2)
- (modified) mlir/test/Dialect/XeGPU/subgroup-distribute.mlir (+4-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index bd5ea9fd83781..710fc62b032a9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -110,23 +110,27 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
Variadic<Index>: $offsets,
Variadic<Index>: $shape,
Variadic<Index>: $strides,
- DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_shape,
OptionalAttr<DenseI64ArrayAttr>: $const_strides
);
let results = (outs XeGPU_TensorDesc: $TensorDesc);
- let assemblyFormat = [{
- $source ``
- custom<DynamicIndexList>($offsets, $const_offsets)
- (`,` custom<DynamicIndexList>($shape, $const_shape)^
- `,` custom<DynamicIndexList>($strides, $const_strides))?
- attr-dict `:` type($source) `->` qualified(type($TensorDesc))
- }];
-
let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+
let builders = [
+ OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
+
+ OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
+ "llvm::ArrayRef<OpFoldResult>": $shape,
+ "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
+ OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+ "llvm::ArrayRef<OpFoldResult>": $shape,
+ "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
@@ -163,9 +167,30 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
}
ArrayRef<int64_t> getStaticOffsets(){
- return getConstOffsets();
+ auto attr = getConstOffsetsAttr();
+
+ if (attr)
+ return attr;
+
+ auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
+ int rank = 0;
+  if (memrefType) {
+    // Use the source memref's rank, which may be higher than the tdesc rank.
+    rank = memrefType.getRank();
+  } else {
+    // nd_tdesc created from a ui64 source: use the tensor_desc's rank.
+    rank = getTensorDescShape().size();
+  }
+
+  // The offsets are allowed to be empty, but the trait verification of OffsetSizeAndStrideOpInterface assumes offsets are present.
+  // INT64_MAX is used as a sentinel meaning "no offsets were given", as opposed to kDynamic, which marks an offset passed as an SSA value.
+  setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max()));
+
+ attr = getConstOffsetsAttr();
+ return attr;
}
+
/// wrapper for matching with OffsetSizeAndStrideOpInterface
/// If source is IntegerType or `const_shape` is filled,
/// it will return `const_shape`, such that mixes of `shape`
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ef7cd1424e7a4..9f6090ad279f5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -125,8 +125,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
ValueRange({}) /* empty dynamic shape */,
ValueRange({}) /* empty dynamic strides */,
- staticOffsets /* const offsets */, {} /* empty const shape*/,
- {} /* empty const strides*/);
+ builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
+ {} /* empty const shape*/, {} /* empty const strides*/);
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
@@ -221,6 +221,246 @@ LogicalResult CreateNdDescOp::verify() {
return success();
}
+ParseResult parseOptionalDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
+ SmallVectorImpl<Type> *valueTypes = nullptr,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+ SmallVector<int64_t, 4> integerVals;
+ SmallVector<bool, 4> scalableVals;
+ auto parseIntegerOrValue = [&]() {
+ OpAsmParser::UnresolvedOperand operand;
+ auto res = parser.parseOptionalOperand(operand);
+
+ // When encountering `[`, assume that this is a scalable index.
+ scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
+
+ if (res.has_value() && succeeded(res.value())) {
+ values.push_back(operand);
+ integerVals.push_back(ShapedType::kDynamic);
+ if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+ return failure();
+ } else {
+ int64_t integer;
+ if (failed(parser.parseInteger(integer)))
+ return failure();
+ integerVals.push_back(integer);
+ }
+
+ // If this is assumed to be a scalable index, verify that there's a closing
+ // `]`.
+ if (scalableVals.back() && parser.parseOptionalRSquare().failed())
+ return failure();
+ return success();
+ };
+ if (parser.parseOptionalLSquare().succeeded()) {
+ if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
+ parser.parseRSquare())
+ return parser.emitError(parser.getNameLoc())
+ << "expected SSA value or integer";
+ integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+ scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
+ return success();
+ }
+ return success();
+}
+
+::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ ::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
+ ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(
+ &sourceRawOperand, 1);
+ ::llvm::SMLoc sourceOperandsLoc;
+
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+ offsetsOperands;
+ ::llvm::SMLoc offsetsOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_offsetsAttr;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
+ ::llvm::SMLoc shapeOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_shapeAttr;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+ stridesOperands;
+ ::llvm::SMLoc stridesOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_stridesAttr;
+ ::mlir::Type sourceRawType{};
+ ::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
+ ::mlir::Type TensorDescRawType{};
+ ::llvm::ArrayRef<::mlir::Type> TensorDescTypes(&TensorDescRawType, 1);
+
+ sourceOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperand(sourceRawOperand))
+ return ::mlir::failure();
+
+ offsetsOperandsLoc = parser.getCurrentLocation();
+
+ DenseBoolArrayAttr scalableFlags;
+ auto odsResult = parseOptionalDynamicIndexList(
+ parser, offsetsOperands, const_offsetsAttr, scalableFlags);
+
+ if (const_offsetsAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets =
+ const_offsetsAttr;
+ }
+
+ if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
+ if (parser.parseColon())
+ return ::mlir::failure();
+ {
+ shapeOperandsLoc = parser.getCurrentLocation();
+ auto odsResult =
+ parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
+ if (const_shapeAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape =
+ const_shapeAttr;
+ }
+ }
+
+ if (parser.parseKeyword("strides"))
+ return ::mlir::failure();
+ if (parser.parseColon())
+ return ::mlir::failure();
+ {
+ stridesOperandsLoc = parser.getCurrentLocation();
+ auto odsResult =
+ parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
+ if (const_stridesAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides =
+ const_stridesAttr;
+ }
+ }
+ }
+ {
+ auto loc = parser.getCurrentLocation();
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return ::mlir::failure();
+ if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+ return parser.emitError(loc)
+ << "'" << result.name.getStringRef() << "' op ";
+ })))
+ return ::mlir::failure();
+ }
+ if (parser.parseColon())
+ return ::mlir::failure();
+
+ {
+ ::mlir::Type type;
+ if (parser.parseCustomTypeWithFallback(type))
+ return ::mlir::failure();
+ sourceRawType = type;
+ }
+ if (parser.parseArrow())
+ return ::mlir::failure();
+
+ if (parser.parseType(TensorDescRawType))
+ return ::mlir::failure();
+
+ ::llvm::copy(::llvm::ArrayRef<int32_t>(
+ {1, static_cast<int32_t>(offsetsOperands.size()),
+ static_cast<int32_t>(shapeOperands.size()),
+ static_cast<int32_t>(stridesOperands.size())}),
+ result.getOrAddProperties<CreateNdDescOp::Properties>()
+ .operandSegmentSizes.begin());
+
+ ::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
+ result.addTypes(TensorDescTypes);
+
+ if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
+ result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(offsetsOperands, odsBuildableType0,
+ offsetsOperandsLoc, result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc,
+ result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(stridesOperands, odsBuildableType0,
+ stridesOperandsLoc, result.operands))
+ return ::mlir::failure();
+ return ::mlir::success();
+}
+
+void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
+ _odsPrinter << ' ';
+ _odsPrinter << getSource();
+
+ auto constOffsetsAttr = getConstOffsetsAttr();
+ bool printOffsets = false;
+ if (constOffsetsAttr && constOffsetsAttr.size() > 0) {
+ auto firstVal = constOffsetsAttr.asArrayRef()[0];
+ if (firstVal != std::numeric_limits<int64_t>::max()) {
+ printOffsets = true;
+ }
+ }
+  if (printOffsets) {
+    printDynamicIndexList(_odsPrinter, *this, getOffsets(),
+                          getConstOffsetsAttr());
+  }
+ if (((!getShape().empty()) || (getConstShapeAttr()))) {
+ _odsPrinter << ' ' << "shape";
+ _odsPrinter << ' ' << ":";
+ _odsPrinter << ' ';
+ printDynamicIndexList(_odsPrinter, *this, getShape(), getConstShapeAttr());
+ _odsPrinter << ' ' << "strides";
+ _odsPrinter << ' ' << ":";
+ _odsPrinter << ' ';
+ printDynamicIndexList(_odsPrinter, *this, getStrides(),
+ getConstStridesAttr());
+ }
+ ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
+ elidedAttrs.push_back("operandSegmentSizes");
+ elidedAttrs.push_back("const_offsets");
+ elidedAttrs.push_back("const_shape");
+ elidedAttrs.push_back("const_strides");
+ _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+ _odsPrinter << ' ' << ":";
+ _odsPrinter << ' ';
+ {
+ auto type = getSource().getType();
+ if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type))
+ _odsPrinter.printStrippedAttrOrType(validType);
+ else
+ _odsPrinter << type;
+ }
+ _odsPrinter << ' ' << "->";
+ _odsPrinter << ' ';
+ // _odsPrinter << getTensorDesc().getType();
+
+ _odsPrinter << "!xegpu.tensor_desc<";
+
+ auto tDesc = getTensorDesc().getType();
+ auto shape = tDesc.getShape();
+ for (int64_t dim : shape) {
+ if (mlir::ShapedType::isDynamic(dim))
+ _odsPrinter << '?';
+ else
+ _odsPrinter << dim;
+ _odsPrinter << 'x';
+ }
+
+ _odsPrinter << tDesc.getElementType();
+
+ if (auto encoding = tDesc.getEncoding())
+ _odsPrinter << ", " << encoding;
+
+ if (auto layout = tDesc.getLayout())
+ _odsPrinter << ", " << layout;
+
+ _odsPrinter << ">";
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3bfe1fa81aa6e..0d679e519ed60 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -17,8 +17,8 @@ gpu.func @create_nd_tdesc_1(%src: memref<24x32xf32>) {
gpu.func @create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
- %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides: [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
gpu.return
}
@@ -62,6 +62,47 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
}
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
+gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
+ //CHECK: %[[C:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ gpu.return
+}
+
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
+gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %src shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+ gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
+
+gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[%arg3, %arg4] shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+ gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %2 = xegpu.create_nd_tdesc %src shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+ gpu.return
+}
+
// CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..ba29d1ab13cae 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -150,16 +150,16 @@ gpu.module @test {
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
// CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
gpu.module @test {
gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
%c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%1 = xegpu.load_nd %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
- %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
``````````
https://github.com/llvm/llvm-project/pull/148335