[Mlir-commits] [mlir] [MLIR][XeGPU] Account for sg_map in LoadNdOp verification (PR #123928)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 22 03:14:14 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

<details>
<summary>Changes</summary>

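XeGPU currently defines the `sg_map` attribute, but it cannot be used with loads: the `LoadNdOp` verifier compares the full tensor-descriptor shape with the result vector, which under `sg_map` holds only one work-item's fragment (e.g. `vector<8x1xf32>` for an `8x16` descriptor with `wi_layout = [1, 16]`), so verification fails with a shape mismatch.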

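To accept valid loads from an `sg_map`-annotated tensor descriptor and to enforce basic rules for `sg_map` usage, this PR adds `sg_map`-related checks to the verifiers of `CreateNdDescOp` and `LoadNdOp`. `CreateNdDescOp` now requires every `wi_data` entry to be at least 1, with at most one entry greater than 1. It also requires each tensor-descriptor dimension to be divisible by the corresponding `wi_layout` entry. `LoadNdOp` now divides the descriptor shape by `wi_layout` before comparing it with the result shape.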


---
Full diff: https://github.com/llvm/llvm-project/pull/123928.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+23) 
- (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+27) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 9d3c4366a7bd50..0c5a1ce0e96a38 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -198,6 +198,22 @@ LogicalResult CreateNdDescOp::verify() {
       tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
     return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
 
+  if (auto attr = getType().getSGMapAttr()) {
+    auto wiLayout = attr.getWiLayout();
+    auto wiData = attr.getWiData();
+    if (wiData[0] < 1 || wiData[1] < 1 || (wiData[0] > 1 && wiData[1] > 1))
+      return emitOpError() << "`wi_data` values must be >=1 and can only be >1 "
+                              "along one dimension."
+                           << "\n";
+    auto tdescShape = getType().getShape();
+    for (size_t i = 0; i < tdescShape.size(); i++) {
+      if (tdescShape[i] % wiLayout[i])
+        return emitOpError() << "Work-items must uniformly divide a tile "
+                                "(tdescShape[i] % wiLayout[i] == 0)"
+                             << "\n";
+    }
+  }
+
   return success();
 }
 
@@ -250,6 +266,13 @@ LogicalResult LoadNdOp::verify() {
   auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
 
+  if (auto attr = getTensorDescType().getSGMapAttr()) {
+    auto wiLayout = attr.getWiLayout();
+    for (size_t i = 0; i < tdescShape.size(); i++) {
+      tdescShape[i] /= wiLayout[i];
+    }
+  }
+
   if (getTranspose()) {
     auto trans = getTranspose().value();
 
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index a4587faa3345cb..0f92e9cb68db68 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -21,6 +21,33 @@ gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<32x32xi8>) {
+gpu.func @test_load_nd_tdesc_with_sg_map(%src: memref<32x32xi8>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>> -> vector<8x1x4xi8>
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [4, 1]>> -> vector<8x1x4xi8>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_2(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_tdesc_with_sg_map_2(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_3(%[[arg0:.*]]: memref<32x32xf32>) {
+gpu.func @test_load_nd_tdesc_with_sg_map_3(%src: memref<32x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>
+  // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
 gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/123928


More information about the Mlir-commits mailing list