[Mlir-commits] [mlir] [mlir][xegpu] Tensor descriptor type verifier (PR #124548)
llvmlistbot at llvm.org
Mon Jan 27 06:21:17 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
Author: Adam Siemieniuk (adam-smnk)
Changes:
Adds an XeGPU tensor descriptor type verifier.
The checks focus on ensuring that the provided subgroup map is valid with respect to the underlying data.
Related operation verifiers are updated to account for the new descriptor validation.
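
For reference, the core rule the new type verifier enforces is that each descriptor dimension must be evenly divisible by `wi_layout[i] * wi_data[i]`, and the updated op verifiers then expect per-work-item value shapes equal to the descriptor shape divided by `wi_layout`. A minimal sketch of both rules, modeled on the tests in this patch (`%src` is assumed to be a `memref<24x32xf32>` argument; the IR is illustrative, not lifted verbatim from the diff):

```mlir
// 8x16 distributes over wi_layout = [1, 16], wi_data = [1, 1]:
// dim 0: 8 % (1 * 1) == 0; dim 1: 16 % (16 * 1) == 0, so the type verifies.
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
  !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// The per-work-item value shape is the descriptor shape divided by
// wi_layout: 8x16 / [1, 16] = 8x1.
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32,
  #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
```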
---
Full diff: https://github.com/llvm/llvm-project/pull/124548.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+1-1)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+53-3)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+15-11)
- (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+22)
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+127-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d09c5c1870d506..494f11f041b71f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
}];
let hasCustomAssemblyFormat = true;
-
+ let genVerifyDecl = 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eb01b15de75c60..ef0ea38027c450 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
if (parser.parseGreater())
return {};
- return TensorDescType::get(parser.getContext(), shape, elementType,
- encoding.value_or(mlir::Attribute()),
- sg_map.value_or(mlir::Attribute()));
+ return TensorDescType::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); },
+ parser.getContext(), shape, elementType,
+ encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
}
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, sg_map);
}
+LogicalResult TensorDescType::verify(
+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute sg_map) {
+ size_t rank = shape.size();
+ if (rank > 2)
+ return emitError() << "desc shape rank exceeds 2";
+
+ if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+ ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+ ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+ if (rank == 1) {
+ if (wiLayout[0] != 1 || wiData[0] != 1)
+ return emitError() << "outer layout and data mapping must be 1 "
+ "for 1D tensor";
+ }
+
+ // For 1D tensor, pad the shape with an outer unit dimension to allow common
+ // validation logic.
+ SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+ if (rank == 1)
+ tensorShape = {1, tensorShape.back()};
+
+ size_t dims = tensorShape.size();
+ for (size_t i = 0; i < dims; ++i) {
+ uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+ if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+ return emitError() << "cannot map " << tensorShape[i]
+ << " elements into " << wiLayout[i] << " by "
+ << wiData[i] << " tiles";
+ }
+
+ if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+ auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
+ if (wiData[0] != 1)
+ return emitError()
+ << "cannot map over non-contiguous scattered elements";
+
+ unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+ if (wiData[1] > chunkSize)
+ return emitError()
+ << "too few contiguous elements for work item mapping";
+ }
+ }
+
+ return success();
+}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81f46f941785a1..bf9eb8f7e10c3c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
// each dimension.
static bool isArgShapesValid(ArrayRef<int64_t> descShape,
ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
- if (descShape == valShape) {
- if (!sgMap)
- return true;
-
- // this can be relaxed if necessary by supporting non-2d shapes distribution
- // until the constraints are defined this lives here instead of the tensor
- // descriptor type.
- return valShape.size() == sgMap.getWiLayout().size();
- }
+ // Equal shapes with no distribution - no further verification needed.
+ if (descShape == valShape && !sgMap)
+ return true;
+ // Unknown distribution - cannot perform operation on partial shape.
if (!sgMap)
return false;
- if (valShape.size() != descShape.size())
+ // Invalid rank or mixed rank usage.
+ size_t descRank = descShape.size();
+ if (descRank > 2 || valShape.size() != descRank)
return false;
+ // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+ // Only take the distribution over the innermost dimension for validation.
+ ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+ SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+ if (descRank == 1)
+ mapLayout = {wiLayout.back()};
+
for (const auto &[factor, dim, expected] :
- llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+ llvm::zip_equal(mapLayout, valShape, descShape)) {
if (factor * dim != expected)
return false;
}
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index d7174a489888a4..729abc5d69f3d1 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+ gpu.return
+}
+
// CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
gpu.return
}
+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+ %1 = arith.constant dense<1.0>: vector<2xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+ !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7816bff0582f81..94dc15756fe4ae 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
- %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ -> vector<8x2xf32>
return
}
// -----
func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
- %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ -> vector<8xf32>
+ return
+}
+
+// -----
+func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
return
}
@@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
return
}
+// -----
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1
+ : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1
+ : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
// -----
func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
%0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
@@ -238,4 +284,81 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
// expected-error@+1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
return
-}
\ No newline at end of file
+}
+
+// -----
+func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{desc shape rank exceeds 2}}
+ !xegpu.tensor_desc<16x2x2xf32>
+ return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 8 elements into 16 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 4 elements into 2 by 4 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
+ // expected-error@+1 {{cannot map over non-contiguous scattered elements}}
+ !xegpu.tensor_desc<4x2xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+ #xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
+ %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+ // expected-error@+1 {{too few contiguous elements for work item mapping}}
+ !xegpu.tensor_desc<16xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+ #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+ return
+}
``````````
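
For scattered descriptors the verifier adds two checks on top of the divisibility rule: `wi_data[0]` must be 1 (a work item cannot map over non-contiguous scattered elements) and `wi_data[1]` must not exceed the chunk size. A minimal sketch of a mapping that passes these checks, adapted from the invalid cases above (the `%src: ui64` operand, offsets, and the unit `wi_layout` are illustrative and only exercise the checks added here):

```mlir
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
// wi_data = [1, 2] keeps the outer data mapping contiguous and stays within
// chunk_size = 2; each dim of 4x2 is divisible by wi_layout[i] * wi_data[i].
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
  !xegpu.tensor_desc<4x2xf32,
    #xegpu.scatter_tdesc_attr<chunk_size = 2>,
    #xegpu.sg_map<wi_layout = [1, 1], wi_data = [1, 2]>>
```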
https://github.com/llvm/llvm-project/pull/124548