[Mlir-commits] [mlir] [mlir][xegpu] Tensor descriptor type verifier (PR #124548)
Adam Siemieniuk
llvmlistbot@llvm.org
Mon Jan 27 06:20:41 PST 2025
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/124548
Adds XeGPU tensor descriptor type verifier.
The checks focus on ensuring that the provided subgroup map is valid with respect to the underlying data.
Related operation verifiers are updated to account for the new descriptor validation.
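To make the contract concrete, here is a sketch of what the verifier accepts and rejects; the shapes are illustrative and drawn from the test cases added below:

// Accepted: each dimension is evenly covered, since
// wi_layout[i] * wi_data[i] divides the shape (1 divides 8, 16 divides 16).
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>

// Rejected: 16-by-1 work-item tiles cannot evenly cover 8 elements
// in the innermost dimension.
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>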
From 8c2b7097920d479ae8b68d370b13393afcd23037 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk@intel.com>
Date: Thu, 23 Jan 2025 10:24:11 +0100
Subject: [PATCH 1/4] [mlir][xegpu] TensorDesc verifier
Adds XeGPU tensor descriptor type verifier.
The checks focus on ensuring that the provided subgroup map is valid
with respect to the underlying data.
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 56 ++++++++++++++++++-
2 files changed, 54 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d09c5c1870d506..494f11f041b71f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
}];
let hasCustomAssemblyFormat = true;
-
+ let genVerifyDecl = 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eb01b15de75c60..42d59da2f7a922 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
if (parser.parseGreater())
return {};
- return TensorDescType::get(parser.getContext(), shape, elementType,
- encoding.value_or(mlir::Attribute()),
- sg_map.value_or(mlir::Attribute()));
+ return TensorDescType::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); },
+ parser.getContext(), shape, elementType,
+ encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
}
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, sg_map);
}
+LogicalResult TensorDescType::verify(
+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute sg_map) {
+ size_t rank = shape.size();
+ if (rank > 2)
+ return emitError() << "desc shape rank exceeds 2";
+
+ if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+ ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+ ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+ if (rank == 1) {
+ if (wiLayout[0] != 1 || wiData[0] != 1)
+ return emitError() << "outer layout and data mapping must be 1 "
+ "for 1D tensor";
+ }
+
+ // For 1D tensor, pad the shape with an outer unit dimension to allow common
+ // validation logic.
+ SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+ if (rank == 1)
+ tensorShape = {1, tensorShape.back()};
+
+ size_t dims = tensorShape.size();
+ for (size_t i = 0; i < dims; ++i) {
+ uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+ if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+ return emitError() << "cannot map " << tensorShape[i]
+ << " elements into " << wiLayout[i] << " by "
+ << wiData[i] << " tiles";
+ }
+
+ if (llvm::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+ auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
+ if (wiData[0] != 1)
+ return emitError()
+ << "cannot map over non-contiguous scattered elements";
+
+ unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+ if (wiData[1] > chunkSize)
+ return emitError()
+ << "too few contiguous elements for work item mapping";
+ }
+ }
+
+ return success();
+}
+
} // namespace xegpu
} // namespace mlir
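Put differently, the per-dimension loop above requires shape[i] >= wiLayout[i] * wiData[i] and shape[i] % (wiLayout[i] * wiData[i]) == 0. A small worked example (the values are illustrative, not taken from the patch):

// shape = [24, 32], wi_layout = [2, 8], wi_data = [1, 2]
// dim 0: 2 * 1 = 2   -> 24 >= 2  and 24 % 2 == 0   -> ok
// dim 1: 8 * 2 = 16  -> 32 >= 16 and 32 % 16 == 0  -> ok
!xegpu.tensor_desc<24x32xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 2]>>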
From 72553d6f995ec42a31641414ce37ca853495e171 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk@intel.com>
Date: Fri, 24 Jan 2025 17:25:29 +0100
Subject: [PATCH 2/4] Test cases
---
mlir/test/Dialect/XeGPU/invalid.mlir | 79 +++++++++++++++++++++++++++-
1 file changed, 78 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7816bff0582f81..a5e615ae917a86 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -238,4 +238,81 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
// expected-error@+1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
return
-}
\ No newline at end of file
+}
+
+// -----
+func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{desc shape rank exceeds 2}}
+ !xegpu.tensor_desc<16x2x2xf32>
+ return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 8 elements into 16 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 4 elements into 2 by 4 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
+ // expected-error@+1 {{cannot map over non-contiguous scattered elements}}
+ !xegpu.tensor_desc<4x2xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+ #xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
+ %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+ // expected-error@+1 {{too few contiguous elements for work item mapping}}
+ !xegpu.tensor_desc<16xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+ #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+ return
+}
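For contrast with the two negative scatter cases above, a descriptor that satisfies both scatter rules (wi_data[0] == 1 and wi_data[1] <= chunk_size) would look as follows; the concrete values are illustrative and not part of the patch:

// Each work item maps a contiguous 1x2 tile, which fits within a chunk
// of size 2; the 4x2 shape is evenly covered by [4, 1] x [1, 2].
!xegpu.tensor_desc<4x2xf32,
  #xegpu.scatter_tdesc_attr<chunk_size = 2>,
  #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>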
From c427acb631f54672f25c7a90ebf7e55e18cadbb9 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk@intel.com>
Date: Mon, 27 Jan 2025 14:40:36 +0100
Subject: [PATCH 3/4] Update load/store verifier + tests
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 26 +++++++------
mlir/test/Dialect/XeGPU/XeGPUOps.mlir | 22 +++++++++++
mlir/test/Dialect/XeGPU/invalid.mlir | 52 ++++++++++++++++++++++++--
3 files changed, 86 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81f46f941785a1..bf9eb8f7e10c3c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
// each dimension.
static bool isArgShapesValid(ArrayRef<int64_t> descShape,
ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
- if (descShape == valShape) {
- if (!sgMap)
- return true;
-
- // this can be relaxed if necessary by supporting non-2d shapes distribution
- // until the constraints are defined this lives here instead of the tensor
- // descriptor type.
- return valShape.size() == sgMap.getWiLayout().size();
- }
+ // Equal shapes with no distribution - no further verification needed.
+ if (descShape == valShape && !sgMap)
+ return true;
+ // Unknown distribution - cannot perform operation on partial shape.
if (!sgMap)
return false;
- if (valShape.size() != descShape.size())
+ // Invalid rank or mixed rank usage.
+ size_t descRank = descShape.size();
+ if (descRank > 2 || valShape.size() != descRank)
return false;
+ // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+ // Only take the distribution over the innermost dimension for validation.
+ ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+ SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+ if (descRank == 1)
+ mapLayout = {wiLayout.back()};
+
for (const auto &[factor, dim, expected] :
- llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+ llvm::zip_equal(mapLayout, valShape, descShape)) {
if (factor * dim != expected)
return false;
}
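In short, when an sg_map is present, the value shape must equal the descriptor shape divided by the work-item layout (with the unit outer factor dropped for 1D descriptors). A minimal sketch of a load that conforms to these rules, assuming the distribution semantics of this hunk (the function name is hypothetical):

func.func @distributed_load_sketch(%src: memref<24x32xf32>) {
  %desc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
  // Per-work-item result: wi_layout[i] * valShape[i] == descShape[i],
  // i.e. [1 * 8, 16 * 1] == [8, 16].
  %val = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf32,
    #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
  return
}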
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index d7174a489888a4..729abc5d69f3d1 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+ gpu.return
+}
+
// CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
gpu.return
}
+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+ %1 = arith.constant dense<1.0>: vector<2xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+ !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index a5e615ae917a86..94dc15756fe4ae 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
- %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ -> vector<8x2xf32>
return
}
// -----
func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
- %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ -> vector<8xf32>
+ return
+}
+
+// -----
+func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
return
}
@@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
return
}
+// -----
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1
+ : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1
+ : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
// -----
func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
%0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
From 2b8dd10aae1b07c3cf9464ace9c5d9830fae873f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk@intel.com>
Date: Mon, 27 Jan 2025 15:16:33 +0100
Subject: [PATCH 4/4] Refactor
---
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 42d59da2f7a922..ef0ea38027c450 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -257,7 +257,7 @@ LogicalResult TensorDescType::verify(
<< wiData[i] << " tiles";
}
- if (llvm::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+ if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
if (wiData[0] != 1)
return emitError()