[Mlir-commits] [mlir] [mlir][xegpu] Remove OffsetSizeAndStrideOpInterface from CreateNdDescOp (PR #152773)
Chao Chen
llvmlistbot at llvm.org
Fri Aug 8 12:18:54 PDT 2025
https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/152773
From 2bc35d8b56877479d4b7ae202a800cedb2da33f9 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 8 Aug 2025 18:12:30 +0000
Subject: [PATCH 1/3] remove OffsetSizeAndStrideOpInterface from CreateNdDescOp
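CreateNdDescOp implemented OffsetSizeAndStrideOpInterface only through wrappers
that faked the interface's invariants (e.g., materializing all-zero
const_offsets on demand). This patch drops the interface and instead gives the
op its own getMixedOffsets/getMixedSizes/getMixedStrides helpers built on
getMixedValues, making offsets genuinely optional. A minimal sketch of the two
accepted forms, adapted from the updated tests in this patch:

```mlir
// Offsets elided; sizes and strides are taken from the memref type.
%3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// Integer (raw pointer) source: explicit shape and strides are required,
// since there is no memref type to derive them from.
%2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
```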
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 173 +++++++-----------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 7 +-
mlir/test/Dialect/XeGPU/invalid.mlir | 10 +-
mlir/test/Dialect/XeGPU/ops.mlir | 36 ++--
4 files changed, 93 insertions(+), 133 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 75b16a87e03c6..4c18b3f47ba18 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
void printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-
+
DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
// filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}
if (!filteredAttrs.empty()) {
- p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
+ p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
}
}
@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}
-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
- AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {
let summary = "Create nd-tensor descriptor operation";
let description = [{
@@ -181,82 +180,40 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
return getType().getShape();
}
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
- OperandRange getSizes() {
- return getShape();
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+ auto dynamics = getOffsets();
+      if (statics.empty() && dynamics.empty())
+ return {};
+ return getMixedValues(statics, dynamics, getContext());
}
- ArrayRef<int64_t> getStaticOffsets(){
- auto attr = getConstOffsetsAttr();
-
- if (attr)
- return attr;
+ SmallVector<OpFoldResult> getMixedSizes() {
+ Builder b(getContext());
+ SmallVector<int64_t> statics;
- int64_t rank = getMixedSizes().size();
-
- setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
+      /// Get the static sizes/shape; the value passed to const_shape
+      /// overrides the shape from the memref type.
+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+ statics = llvm::to_vector(memrefTy.getShape());
+ if (auto attr = getConstShapeAttr())
+ statics = llvm::to_vector(attr.asArrayRef());
- attr = getConstOffsetsAttr();
- return attr;
+ return getMixedValues(statics, getShape(), b);
}
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
- /// If source is IntegerType or `const_shape` is filled,
- /// it will return `const_shape`, such that mixes of `shape`
- /// and `const_shape` will be used to represent the shape of
- /// source operand. They overide static shape from source memref type.
- ArrayRef<int64_t> getStaticSizes() {
- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
- static llvm::SmallVector<int64_t, 4> emptyShape;
-
- auto attr = getConstShapeAttr();
- if (attr)
- return attr;
-
- if (llvm::isa<IntegerType>(getSourceType()))
- return emptyShape;
-
- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
- assert(memrefType && "Incorrect use of getStaticSizes");
- return memrefType.getShape();
- }
+ SmallVector<OpFoldResult> getMixedStrides() {
+ Builder b(getContext());
+ SmallVector<int64_t> statics;
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
- /// If source is IntegerType or `const_strides` is filled, it
- /// will return `const_strides`, such that mixes of `strides`
- /// and `const_strides` will be used to represent the strides of
- /// source operand. They overide static strides from source memref type.
- ArrayRef<int64_t> getStaticStrides() {
- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
- static llvm::SmallVector<int64_t, 4> emptyStrides;
-
- auto attr = getConstStridesAttr();
- if (attr)
- return attr;
-
- if (llvm::isa<IntegerType>(getSourceType()))
- return emptyStrides;
-
- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
- assert(memrefType && "Incorrect use of getStaticStrides");
- auto [strides, _] = memrefType.getStridesAndOffset();
- // reuse the storage of ConstStridesAttr since strides from
- // memref is not persistant
- setConstStrides(strides);
- attr = getConstStridesAttr();
- return attr;
- }
+      /// Get the static strides; the value passed to const_strides
+      /// overrides the strides from the memref type.
+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+ statics = memrefTy.getStridesAndOffset().first;
+ if (auto attr = getConstStridesAttr())
+ statics = llvm::to_vector(attr.asArrayRef());
- /// Return the expected rank of each of the`static_offsets`,
- /// `static_shape` and `static_strides` attributes.
- std::array<unsigned, 3> getArrayAttrMaxRanks() {
- unsigned rank;
- if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
- rank = ty.getRank();
- } else {
- rank = (unsigned)getMixedOffsets().size();
- }
- return {rank, rank, rank};
+ return getMixedValues(statics, getStrides(), b);
}
/// Return the number of leading operands before the `offsets`,
@@ -314,15 +271,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
}];
let assemblyFormat = [{
- $TensorDesc ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc))
}];
let builders = [
- OpBuilder<(ins "Value": $TensorDesc,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ OpBuilder<(ins "Value": $TensorDesc,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -370,7 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<UnitAttr>: $packed,
OptionalAttr<DenseI64ArrayAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +347,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
}];
let assemblyFormat = [{
- $TensorDesc ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
}];
let builders = [
- OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+ OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -442,7 +399,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
let arguments = (ins XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +415,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
}];
let assemblyFormat = [{
- $value `,`
- $TensorDesc ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $value `,`
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
}];
let builders = [
- OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -635,12 +592,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```
-
+
Example 2:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
The source operand could be a raw pointer (uint64_t).
- Please refer to create_tdesc for the restriction of memref.
+      Please refer to create_tdesc for the restrictions on memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +633,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
}];
let assemblyFormat = [{
- $source
+ $source
(`[` $offsets^ `]`)?
prop-dict
- attr-dict `:` type(operands)
+ attr-dict `:` type(operands)
}];
-
+
let builders = [
OpBuilder<(ins "Value": $source,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -723,7 +680,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```
-
+
Example 3 (SIMT mode):
```mlir
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +689,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
vector<16xi1> -> vector<8xf32>
```
-
+
Example 4:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
- for the restriction of memref.
+      for the restrictions on memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
@@ -794,14 +751,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)? `,`
- $mask prop-dict
+ $mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];
let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -848,7 +805,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
The dest operand could be a raw pointer (uint64_t).
- Please refer to create_tdesc for the restriction of memref.
+      Please refer to create_tdesc for the restrictions on memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +858,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
$value `,`
$dest
(`[` $offsets^ `]`)? `,`
- $mask
- prop-dict
+ $mask
+ prop-dict
attr-dict `:` type(operands)
}];
let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3fa229e..7ac885d2ed40f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -265,7 +265,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult CreateNdDescOp::verify() {
- auto rank = (int64_t)getMixedOffsets().size();
+ int64_t rank = getMixedSizes().size();
bool invalidRank = false;
bool invalidElemTy = false;
@@ -280,6 +280,9 @@ LogicalResult CreateNdDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;
+ if (int64_t offsetRank = getMixedOffsets().size())
+ invalidRank |= (offsetRank != rank);
+
// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
auto memrefTy = dyn_cast<MemRefType>(getSourceType());
@@ -291,7 +294,7 @@ LogicalResult CreateNdDescOp::verify() {
if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
if (getMixedStrides().empty() || getMixedSizes().empty())
- return emitOpError("Expecting strides and shape to be present for "
+ return emitOpError("expecting strides and shape to be present for "
"integer source.");
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index dff3ffab39ecf..cdf147a9fdd0e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -52,14 +52,14 @@ func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {
// -----
func.func @create_nd_tdesc_8(%src: ui64) {
- // expected-error at +1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
+ // expected-error at +1 {{'xegpu.create_nd_tdesc' op expecting strides and shape to be present for integer source}}
%1 = xegpu.create_nd_tdesc %src : ui64-> !xegpu.tensor_desc<128x128xf32>
return
}
// -----
func.func @create_nd_tdesc_9(%src: ui64) {
- // expected-error at +1 {{expected mixed offsets rank to match mixed sizes rank}}
+ // expected-error at +1 {{expecting strides and shape to be present for integer source}}
%1 = xegpu.create_nd_tdesc %src[0, 0] : ui64-> !xegpu.tensor_desc<128x128xf32>
return
}
@@ -149,7 +149,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
}
// -----
-func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
+func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// expected-error at +1 {{Mismatched ranks between offsets and tensor descriptor}}
@@ -418,7 +418,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error at +1 {{value elements must match chunk size}}
- xegpu.store %val, %src[%offsets], %mask
+ xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
return
}
@@ -429,7 +429,7 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error at +1 {{Expecting the dest is a 1D memref or pointer}}
- xegpu.store %val, %src[%offsets], %mask
+ xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 6be2371d4d7b2..67c00f5a9cc2f 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -62,28 +62,28 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
}
-// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
-
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
+
gpu.return
}
-// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
-
- %c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
-
+
gpu.return
}
-// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
@@ -94,10 +94,10 @@ gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index,
gpu.return
}
-// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
-gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
gpu.return
@@ -123,7 +123,7 @@ gpu.func @prefetch_nd_2(%src: memref<48x64xf16>) {
// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<48x64xf16>, %arg1: index, %arg2: index) {
gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %src : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK: xegpu.prefetch_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x16xf16>
@@ -271,7 +271,7 @@ gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
// CHECK: func @subgroup_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>, %arg1: index, %arg2: index) {
gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>, %x : index, %y : index) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
%1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
// CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
%2 = xegpu.load_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
@@ -290,7 +290,7 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
// CHECK: func @simt_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @simt_load_nd_offset_1(%src: memref<24x32xf32>) {
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
%1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
// CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
%2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
@@ -323,7 +323,7 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
%1 = arith.constant dense<1.0>: vector<32xf16>
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
%2 = xegpu.create_nd_tdesc %dst : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
// CHECK: xegpu.store_nd %[[C]], %[[R0]][%arg1] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
xegpu.store_nd %1, %2[%x] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
@@ -356,7 +356,7 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
gpu.func @simt_store_nd_offset_1(%src: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
%1 = arith.constant dense<1.0>: vector<2xf16>
- // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
%2 = xegpu.create_nd_tdesc %src : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
// CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>
From a2bc905459bb9b853b3fe0ce68673f562e6fb766 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 8 Aug 2025 18:20:02 +0000
Subject: [PATCH 2/3] clean up
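getMixedValues accepts the MLIRContext directly, so the temporary Builder in
getMixedSizes()/getMixedStrides() was unnecessary. As a rough sketch of the
assumed semantics (not the actual upstream implementation): each static entry
becomes an index attribute, and each dynamic sentinel consumes the next SSA
value:

```c++
// Hypothetical illustration; `statics`, `dynamics`, and `ctx` stand in for
// the helper's arguments.
SmallVector<OpFoldResult> mixed;
unsigned dynIdx = 0;
for (int64_t s : statics) {
  if (ShapedType::isDynamic(s))
    mixed.push_back(dynamics[dynIdx++]); // dynamic dim: use the SSA value
  else
    mixed.push_back(IntegerAttr::get(IndexType::get(ctx), s)); // static dim
}
```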
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 4c18b3f47ba18..1a6a34c8d775a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -189,7 +189,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
}
SmallVector<OpFoldResult> getMixedSizes() {
- Builder b(getContext());
SmallVector<int64_t> statics;
/// Get the static sizes/shape; the value passed to const_shape
@@ -199,11 +198,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
if (auto attr = getConstShapeAttr())
statics = llvm::to_vector(attr.asArrayRef());
- return getMixedValues(statics, getShape(), b);
+ return getMixedValues(statics, getShape(), getContext());
}
SmallVector<OpFoldResult> getMixedStrides() {
- Builder b(getContext());
SmallVector<int64_t> statics;
/// Get the static strides; the value passed to const_strides
@@ -213,7 +211,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
if (auto attr = getConstStridesAttr())
statics = llvm::to_vector(attr.asArrayRef());
- return getMixedValues(statics, getStrides(), b);
+ return getMixedValues(statics, getStrides(), getContext());
}
/// Return the number of leading operands before the `offsets`,
From b0f626b994a05af574b72c6757720d919a15e3be Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 8 Aug 2025 19:18:38 +0000
Subject: [PATCH 3/3] cleanup
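Tighten CreateNdDescOp::verify(): derive the rank from the mixed sizes, require
the mixed strides (and the mixed offsets, when present) to match it, and drop
the memref rank check, since getMixedSizes() already derives the sizes from the
memref shape. For example, this case from invalid.mlir in this series (shown
here slightly reformatted) is still rejected:

```mlir
// expected-error {{expecting strides and shape to be present for integer source}}
%1 = xegpu.create_nd_tdesc %src : ui64 -> !xegpu.tensor_desc<128x128xf32>
```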
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 16 +++++-----------
1 file changed, 5 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7ac885d2ed40f..b519d6ad72660 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -265,8 +265,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult CreateNdDescOp::verify() {
- int64_t rank = getMixedSizes().size();
- bool invalidRank = false;
+ size_t rank = getMixedSizes().size();
+ bool invalidRank = rank != getMixedStrides().size();
bool invalidElemTy = false;
// Memory space of created TensorDesc should match with the source.
@@ -280,16 +280,13 @@ LogicalResult CreateNdDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;
- if (int64_t offsetRank = getMixedOffsets().size())
+ if (size_t offsetRank = getMixedOffsets().size())
invalidRank |= (offsetRank != rank);
-  // check source type matches the rank if it is a memref.
-  // It also should have the same ElementType as TensorDesc.
+  // If the source is a memref, it should have the same ElementType as the
+  // TensorDesc.
- auto memrefTy = dyn_cast<MemRefType>(getSourceType());
- if (memrefTy) {
- invalidRank |= (memrefTy.getRank() != rank);
+ if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
invalidElemTy |= memrefTy.getElementType() != getElementType();
- }
if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
@@ -298,16 +295,13 @@ LogicalResult CreateNdDescOp::verify() {
"integer source.");
}
- // mismatches among shape, strides, and offsets are
- // already handeled by OffsetSizeAndStrideOpInterface.
- // So they are not check here.
if (invalidRank)
return emitOpError(
"Expecting the rank of shape, strides, offsets, and source (if source "
"is a memref) should match with each other.");
// check result TensorDesc rank
- if (getType().getRank() > rank)
+ if (getType().getRank() > (int64_t)rank)
return emitOpError(
"Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides, offsets or the memref source.");