[Mlir-commits] [mlir] [mlir][xegpu] Tensor descriptor type verifier (PR #124548)
Adam Siemieniuk
llvmlistbot at llvm.org
Fri Feb 7 06:36:51 PST 2025
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/124548
From d23c97d170efbc1cf9ec5a2e4b10f057d88edb23 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 23 Jan 2025 10:24:11 +0100
Subject: [PATCH 1/8] [mlir][xegpu] TensorDesc verifier
Adds XeGPU tensor descriptor type verifier.
The checks focus on ensuring that the provided subgroup map is valid
with respect to the underlying data.
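
For illustration, a mapping is valid when every tensor dimension divides
evenly into wi_layout[i] * wi_data[i] elements. A minimal sketch (the 8x16
shape is illustrative, not part of this patch):

  // Accepted: 8 % (1 * 1) == 0 and 16 % (16 * 1) == 0, so the 8x16 tile
  // distributes cleanly over 16 work items, one 8x1 slice each.
  !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>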
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 56 ++++++++++++++++++-
2 files changed, 54 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d09c5c1870d506f..494f11f041b71ff 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
}];
let hasCustomAssemblyFormat = true;
-
+ let genVerifyDecl = 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eb01b15de75c606..42d59da2f7a922a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
if (parser.parseGreater())
return {};
- return TensorDescType::get(parser.getContext(), shape, elementType,
- encoding.value_or(mlir::Attribute()),
- sg_map.value_or(mlir::Attribute()));
+ return TensorDescType::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); },
+ parser.getContext(), shape, elementType,
+ encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
}
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, sg_map);
}
+LogicalResult TensorDescType::verify(
+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute sg_map) {
+ size_t rank = shape.size();
+ if (rank > 2)
+ return emitError() << "desc shape rank exceeds 2";
+
+ if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+ ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+ ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+ if (rank == 1) {
+ if (wiLayout[0] != 1 || wiData[0] != 1)
+ return emitError() << "outer layout and data mapping must be 1 "
+ "for 1D tensor";
+ }
+
+ // For 1D tensor, pad the shape with an outer unit dimension to allow common
+ // validation logic.
+ SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+ if (rank == 1)
+ tensorShape = {1, tensorShape.back()};
+
+ size_t dims = tensorShape.size();
+ for (size_t i = 0; i < dims; ++i) {
+ uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+ if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+ return emitError() << "cannot map " << tensorShape[i]
+ << " elements into " << wiLayout[i] << " by "
+ << wiData[i] << " tiles";
+ }
+
+ if (llvm::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+ auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
+ if (wiData[0] != 1)
+ return emitError()
+ << "cannot map over non-contiguous scattered elements";
+
+ unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+ if (wiData[1] > chunkSize)
+ return emitError()
+ << "too few contiguous elements for work item mapping";
+ }
+ }
+
+ return success();
+}
+
} // namespace xegpu
} // namespace mlir
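To sketch the arithmetic with numbers from the tests added below: for a
4x8xf32 descriptor with wi_layout = [2, 8] and wi_data = [4, 1], the outer
dimension would need wi_layout[0] * wi_data[0] = 8 elements per subgroup
tile, but the shape only provides 4, so verification fails:

  // Rejected: dim 0 holds 4 elements, fewer than the 2 * 4 = 8 the map consumes.
  !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>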
From 5e3e7a80e907c0e4160cc00b76cd3f3f6d6c5e56 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 24 Jan 2025 17:25:29 +0100
Subject: [PATCH 2/8] Test cases
---
mlir/test/Dialect/XeGPU/invalid.mlir | 79 +++++++++++++++++++++++++++-
1 file changed, 78 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 201f72120cf2c5d..975b4aea84fe250 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -315,4 +315,81 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
// expected-error at +1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
return
-}
\ No newline at end of file
+}
+
+// -----
+func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error at +1 {{desc shape rank exceeds 2}}
+ !xegpu.tensor_desc<16x2x2xf32>
+ return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error at +1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error at +1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error at +1 {{cannot map 8 elements into 16 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error at +1 {{cannot map 4 elements into 8 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error at +1 {{cannot map 4 elements into 2 by 4 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error at +1 {{cannot map 4 elements into 8 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
+ // expected-error at +1 {{cannot map over non-contiguous scattered elements}}
+ !xegpu.tensor_desc<4x2xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+ #xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
+ %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+ // expected-error at +1 {{too few contiguous elements for work item mapping}}
+ !xegpu.tensor_desc<16xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+ #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+ return
+}
From 7aba5a69eca63bd4f0924f48df4bbe0f3c5a8439 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 27 Jan 2025 14:40:36 +0100
Subject: [PATCH 3/8] Update load/store verifier + tests
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 26 +++++++------
mlir/test/Dialect/XeGPU/XeGPUOps.mlir | 22 +++++++++++
mlir/test/Dialect/XeGPU/invalid.mlir | 52 ++++++++++++++++++++++++--
3 files changed, 86 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index cd883baa986b85e..996fb3638203311 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
// each dimension.
static bool isArgShapesValid(ArrayRef<int64_t> descShape,
ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
- if (descShape == valShape) {
- if (!sgMap)
- return true;
-
- // this can be relaxed if necessary by supporting non-2d shapes distribution
- // until the constraints are defined this lives here instead of the tensor
- // descriptor type.
- return valShape.size() == sgMap.getWiLayout().size();
- }
+ // Equal shapes with no distribution - no further verification needed.
+ if (descShape == valShape && !sgMap)
+ return true;
+ // Unknown distribution - cannot perform operation on partial shape.
if (!sgMap)
return false;
- if (valShape.size() != descShape.size())
+ // Invalid rank or mixed rank usage.
+ size_t descRank = descShape.size();
+ if (descRank > 2 || valShape.size() != descRank)
return false;
+ // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+ // Only take the distribution over the innermost dimension for validation.
+ ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+ SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+ if (descRank == 1)
+ mapLayout = {wiLayout.back()};
+
for (const auto &[factor, dim, expected] :
- llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+ llvm::zip_equal(mapLayout, valShape, descShape)) {
if (factor * dim != expected)
return false;
}
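As a worked example (mirroring the tests below, with cache hints omitted
for brevity): a 32xf32 descriptor with wi_layout = [1, 16] distributes only
over the innermost dimension, so each work item receives 32 / 16 = 2
elements and the check factor * dim == expected becomes 16 * 2 == 32:

  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
    !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
  // Each work item loads its own 2-element slice of the 32-element row.
  %2 = xegpu.load_nd %1 : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>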
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index dcd6b01974cf306..8af1b600ad0a4e2 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+ gpu.return
+}
+
// CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
gpu.return
}
+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+ %1 = arith.constant dense<1.0>: vector<2xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+ !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 975b4aea84fe250..5ed93db2be502dd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
- %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ -> vector<8x2xf32>
return
}
// -----
func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
- %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ -> vector<8xf32>
+ return
+}
+
+// -----
+func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32>
+ // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
return
}
@@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
return
}
+// -----
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1
+ : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1
+ : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32>
+ // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
// -----
func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
%0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
From fa82d893c6b8ee3ab9ebaa5fd97bf00553f10504 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 27 Jan 2025 15:16:33 +0100
Subject: [PATCH 4/8] Refactor
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 42d59da2f7a922a..ef0ea38027c4503 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -257,7 +257,7 @@ LogicalResult TensorDescType::verify(
<< wiData[i] << " tiles";
}
- if (llvm::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+ if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
if (wiData[0] != 1)
return emitError()
From 5df9cf0089676d2650effacba226ecfdf4fdd523 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 7 Feb 2025 13:53:20 +0100
Subject: [PATCH 5/8] Improve scattered verification
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 56 ++++++++++++++--------
mlir/test/Dialect/XeGPU/invalid.mlir | 34 ++++++++-----
2 files changed, 58 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index ef0ea38027c4503..077a924dfad266a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -229,8 +229,24 @@ LogicalResult TensorDescType::verify(
llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
mlir::Attribute encoding, mlir::Attribute sg_map) {
size_t rank = shape.size();
- if (rank > 2)
- return emitError() << "desc shape rank exceeds 2";
+ if (rank != 1 && rank != 2)
+ return emitError() << "expected 1D or 2D tensor";
+
+ // Scattered attribute imposes extra restriction on tensor descriptor.
+ // Block attribute can only be validated further against data transfer
+ // operations.
+ auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
+ if (scatterAttr) {
+ // Expected tensor ranks for scattered data:
+ // - 1D tensor for fully non-contiguous elements (chunk size == 1)
+ // - 2D tensor for scattered blocks (chunk size > 1)
+ IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+ unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+ if (rank == 1 && chunkSize != 1)
+ return emitError() << "expected non-contiguous elements for 1D tensor";
+ if (rank == 2 && chunkSize < 2)
+ return emitError() << "expected chunk blocks for 2D tensor";
+ }
if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
@@ -238,8 +254,22 @@ LogicalResult TensorDescType::verify(
if (rank == 1) {
if (wiLayout[0] != 1 || wiData[0] != 1)
- return emitError() << "outer layout and data mapping must be 1 "
- "for 1D tensor";
+ return emitError()
+ << "outer layout distribution and data mapping must be 1 "
+ "for 1D tensor";
+ }
+
+ if (scatterAttr) {
+ // Validate subgroup mapping rules for scattered tensors.
+ if (wiData[0] != 1)
+ return emitError()
+ << "cannot map over non-contiguous scattered row elements";
+
+ IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+ unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+ if (wiData[1] != chunkSize)
+ return emitError() << "work item data mapping must match the number of "
+ "contiguous elements";
}
// For 1D tensor, pad the shape with an outer unit dimension to allow common
@@ -252,21 +282,9 @@ LogicalResult TensorDescType::verify(
for (size_t i = 0; i < dims; ++i) {
uint32_t numElemPerWi = wiLayout[i] * wiData[i];
if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
- return emitError() << "cannot map " << tensorShape[i]
- << " elements into " << wiLayout[i] << " by "
- << wiData[i] << " tiles";
- }
-
- if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
- auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
- if (wiData[0] != 1)
- return emitError()
- << "cannot map over non-contiguous scattered elements";
-
- unsigned chunkSize = scatterAttr.getChunkSize().getInt();
- if (wiData[1] > chunkSize)
- return emitError()
- << "too few contiguous elements for work item mapping";
+ return emitError() << "cannot distribute " << tensorShape[i] << " over "
+ << wiLayout[i] << " work items with " << wiData[i]
+ << " elements each";
}
}
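Putting the rules together, a sketch of a scattered descriptor that passes
all three checks (this exact combination is assumed for illustration, not
taken from the tests): rank 2 with chunk_size > 1, wi_data[0] == 1 so rows
stay non-contiguous across work items, and wi_data[1] equal to chunk_size:

  // Each of the 4 work items owns one scattered row of 2 contiguous elements.
  !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,
                     #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>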
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 5ed93db2be502dd..733eb1559d6fb14 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -183,8 +183,8 @@ func.func @test_create_tdesc_vc_1(%src: ui64) {
// -----
func.func @test_create_tdesc_vc_2(%src: ui64) {
%0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
- // expected-error at +1 {{Incorrect TensorDesc shape}}
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
+ // expected-error at +1 {{expected chunk blocks for 2D tensor}}
-> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
return
}
@@ -219,7 +219,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
// -----
func.func @test_create_tdesc_sg_map_1(%src: ui64) {
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- // expected-error at +1 {{Detected a conflict between SG map's work-item layout and TensorDesc shape. Check the index of `subgroup_size` in WI layout map}}
+ // expected-error at +1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
return
}
@@ -227,7 +227,7 @@ func.func @test_create_tdesc_sg_map_1(%src: ui64) {
// -----
func.func @test_create_tdesc_sg_map_2(%src: ui64) {
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- // expected-error at +1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}}
+ // expected-error at +1 {{cannot map over non-contiguous scattered row elements}}
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [2, 1]>>
return
}
@@ -235,7 +235,7 @@ func.func @test_create_tdesc_sg_map_2(%src: ui64) {
// -----
func.func @test_create_tdesc_sg_map_3(%src: ui64) {
%cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- // expected-error at +1 {{TensorDesc's chunkSize must match WI's data mapping}}
+ // expected-error at +1 {{work item data mapping must match the number of contiguous elements}}
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
return
}
@@ -366,15 +366,23 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
// -----
func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error at +1 {{desc shape rank exceeds 2}}
+ // expected-error at +1 {{expected 1D or 2D tensor}}
!xegpu.tensor_desc<16x2x2xf32>
return
}
+// -----
+func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error at +1 {{expected 1D or 2D tensor}}
+ !xegpu.tensor_desc<f32>
+ return
+}
+
// -----
func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error at +1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ // expected-error at +1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
return
}
@@ -382,7 +390,7 @@ func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
// -----
func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error at +1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ // expected-error at +1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
!xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
return
}
@@ -390,7 +398,7 @@ func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
// -----
func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error at +1 {{cannot map 8 elements into 16 by 1 tiles}}
+ // expected-error at +1 {{cannot distribute 8 over 16 work items with 1 elements each}}
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
return
}
@@ -398,7 +406,7 @@ func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
// -----
func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error at +1 {{cannot map 4 elements into 8 by 1 tiles}}
+ // expected-error at +1 {{cannot distribute 4 over 8 work items with 1 elements each}}
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
return
}
@@ -406,7 +414,7 @@ func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
// -----
func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error at +1 {{cannot map 4 elements into 2 by 4 tiles}}
+ // expected-error at +1 {{cannot distribute 4 over 2 work items with 4 elements each}}
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
return
}
@@ -414,7 +422,7 @@ func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
// -----
func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- // expected-error at +1 {{cannot map 4 elements into 8 by 1 tiles}}
+ // expected-error at +1 {{cannot distribute 4 over 8 work items with 1 elements each}}
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
return
}
@@ -423,7 +431,7 @@ func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
- // expected-error at +1 {{cannot map over non-contiguous scattered elements}}
+ // expected-error at +1 {{cannot map over non-contiguous scattered row elements}}
!xegpu.tensor_desc<4x2xf32,
#xegpu.scatter_tdesc_attr<chunk_size = 2>,
#xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
@@ -433,7 +441,7 @@ func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
// -----
func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
- // expected-error at +1 {{too few contiguous elements for work item mapping}}
+ // expected-error at +1 {{work item data mapping must match the number of contiguous elements}}
!xegpu.tensor_desc<16xf32,
#xegpu.scatter_tdesc_attr<chunk_size = 1>,
#xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
From f502c66de3e230a364ead78f3d3404f7368ca10d Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 7 Feb 2025 14:01:52 +0100
Subject: [PATCH 6/8] Remove TensorDesc invariant checks from op verifier
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 3 +++
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 17 +----------------
2 files changed, 4 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 077a924dfad266a..0f17c6a8fa98d63 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -261,6 +261,9 @@ LogicalResult TensorDescType::verify(
if (scatterAttr) {
// Validate subgroup mapping rules for scattered tensors.
+ // A work-item's slice of the tensor with shape [sg_size] or
+ // [sg_size, chunk_size] will be [1] or [1, chunk_size] respectively,
+ // the mapping should reflect that.
if (wiData[0] != 1)
return emitError()
<< "cannot map over non-contiguous scattered row elements";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 996fb3638203311..476689fae4e254f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -458,22 +458,7 @@ LogicalResult CreateDescOp::verify() {
if (shape != tdescShape)
return emitOpError("Incorrect TensorDesc shape. ")
<< "Expected is " << makeString(shape) << "\n";
- if (auto sgMap = tdescTy.getSGMapAttr()) {
- // A work-item's slice of the TensorDesc with shape [sg_size] or
- // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
- // the mapping should reflect that.
- if (sgMap.getWiData()[0] > 1)
- return emitOpError("TensorDesc's SG map only supports multiple elements "
- "contiguous along rows.");
- if (chunkSize != static_cast<int>(sgMap.getWiData()[1]))
- return emitOpError(
- "TensorDesc's chunkSize must match WI's data mapping.");
- if (int rank = tdescTy.getRank();
- (sgMap.getWiLayout()[2 - rank] != tdescShape[0]))
- return emitOpError("Detected a conflict between SG map's work-item "
- "layout and TensorDesc shape. Check the index of "
- "`subgroup_size` in WI layout map.");
- }
+
return success();
}
From 50c628323506d2af53cfa8ea62259485cadc348c Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 7 Feb 2025 14:20:45 +0100
Subject: [PATCH 7/8] Add more chunk_size test cases
---
mlir/test/Dialect/XeGPU/invalid.mlir | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 733eb1559d6fb14..48e8c2808abdae6 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -447,3 +447,23 @@ func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<1
#xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
return
}
+
+// -----
+func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vector<16xindex>) {
+ %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+ // expected-error at +1 {{expected non-contiguous elements for 1D tensor}}
+ !xegpu.tensor_desc<16xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+ #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) {
+ %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+ // expected-error at +1 {{expected chunk blocks for 2D tensor}}
+ !xegpu.tensor_desc<16x2xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+ #xegpu.sg_map<wi_layout = [8, 1], wi_data = [1, 2]>>
+ return
+}
From a72c12f8f05f8f7456e442f40882f809ac10cdd1 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 7 Feb 2025 14:59:55 +0100
Subject: [PATCH 8/8] Move memory space check to type verifier
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 11 ++++++++---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 ----
mlir/test/Dialect/XeGPU/invalid.mlir | 2 +-
3 files changed, 9 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 0f17c6a8fa98d63..becc32d1226973d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -232,9 +232,6 @@ LogicalResult TensorDescType::verify(
if (rank != 1 && rank != 2)
return emitError() << "expected 1D or 2D tensor";
- // Scattered attribute imposes extra restriction on tensor descriptor.
- // Block attribute can only be validated further against data transfer
- // operations.
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
// Expected tensor ranks for scattered data:
@@ -248,6 +245,14 @@ LogicalResult TensorDescType::verify(
return emitError() << "expected chunk blocks for 2D tensor";
}
+ if (auto blockAttr =
+ mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
+ MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
+ if (rank == 2 && memorySpaceAttr &&
+ memorySpaceAttr.getValue() == MemorySpace::SLM)
+ return emitError() << "SLM is not supported for 2D block tensor";
+ }
+
if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 476689fae4e254f..e06d99ac20bb736 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -231,10 +231,6 @@ LogicalResult CreateNdDescOp::verify() {
if (getType().isScattered())
return emitOpError("Expects a non-scattered TensorDesc.\n");
- if (getType().getRank() == 2 &&
- tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
- return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
-
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 48e8c2808abdae6..9162e0012f6d56d 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -17,7 +17,7 @@ func.func @test_create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
// -----
func.func @test_create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
- // expected-error at +1 {{SLM is not supported for 2D Block TensorDesc}}
+ // expected-error at +1 {{SLM is not supported for 2D block tensor}}
%1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
return
}
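For contrast, a sketch of what the relocated check still accepts at the
type level (assuming the same memory_space attribute): the SLM restriction
only fires for rank-2 descriptors, so a 1D block descriptor in shared local
memory remains valid:

  // Accepted by the type verifier: the rank is 1, so the SLM check does not apply.
  !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>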