[Mlir-commits] [mlir] [MLIR][XeGPU] Allow some nd ops to have argument shapes mismatch for … (PR #120566)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 19 04:19:32 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-gpu

Author: Petr Kurapov (kurapov-peter)

<details>
<summary>Changes</summary>

…the distributed IR case.

This patch allows `load_nd` and `store_nd` to preserve the tensor descriptor shape during distribution to SIMT. The verifier now expects the distributed instruction's tensor descriptor to retain its `sg_map` attribute and uses that attribute to check that the descriptor and vector shapes are consistent.
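As a rough illustration of the shape relation the relaxed verifier accepts, the sketch below computes a per-lane vector shape from a tensor descriptor shape and an `sg_map` `wi_layout`. It is standalone C++, not MLIR code: `distributeShape` is a hypothetical helper, and it assumes `wi_data = [1, 1]` (as in the new tests) and that every descriptor dimension divides evenly by its `wi_layout` factor.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not an MLIR API): divides each tensor descriptor
// dimension by the matching sg_map wi_layout factor to get the per-lane
// vector shape. Assumes wi_data = [1, 1]. Returns std::nullopt if the ranks
// differ or a dimension is not evenly divisible by its factor.
std::optional<std::vector<int64_t>>
distributeShape(const std::vector<int64_t> &descShape,
                const std::vector<int64_t> &wiLayout) {
  if (descShape.size() != wiLayout.size())
    return std::nullopt;
  std::vector<int64_t> laneShape;
  laneShape.reserve(descShape.size());
  for (std::size_t i = 0; i < descShape.size(); ++i) {
    // Reject non-positive factors and uneven splits.
    if (wiLayout[i] <= 0 || descShape[i] % wiLayout[i] != 0)
      return std::nullopt;
    laneShape.push_back(descShape[i] / wiLayout[i]);
  }
  return laneShape;
}
```

With `wi_layout = [1, 16]`, an `8x16` descriptor yields `8x1` and a `24x32` descriptor yields `24x2`, the vector shapes used in the new `load_nd` and `store_nd` tests. Under approach 2 described below, only the vector type is rewritten this way; the descriptor keeps its original shape.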

Some background:
We've been discussing the appropriate way to distribute XeGPU operations, and several approaches are possible. Below are the three main ones, with comments on their properties.
1. **Distribute both** the tensor descriptor type and the vector argument during the distribution transformation. This seems like the natural way of applying the transformation. The resulting IR describes what part of the computation and memory a lane "owns", and plays by the rules of the infrastructure. This is also similar to what happens to non-nd xegpu ops, so we get the benefit of a consistent representation. The downside is that when we lower the resulting operation to an LLVM intrinsic, the information about the original non-distributed tensor type must be restored (the intrinsic is cooperative and expects the complete size of the loaded/stored tile). This can technically be done (e.g., using `sg_map`), but it somewhat breaks the promise of XeGPU to represent HW abstractions 1:1 (arguably, imo). Allowing distributed tensor descriptors also means users would be able to create and use them inappropriately, so some UB cases are inevitably introduced.
2. **Distribute only the vector type**, and preserve the original tensor descriptor type. This approach addresses the concerns above by relaxing the validation of nd ops (this patch). Because the descriptor is preserved, any valid IR can still be lowered reasonably. The price is a vague mapping from IR to HW constructs: the IR no longer represents something a single logical SIMT thread owns and acts upon, so the abstraction layering is still broken. This could have unexpected implications for other transformations and analyses that we don't see today (e.g., `load_nd` and `store_nd` should have memory side effects, which are not implemented at the moment; it is unclear to me what implications this violation of "ownership" may have).
3. **Perform the distribution during lowering to a lower-level dialect**. Instead of choosing between the two unappealing options, we could resolve the layering properly by converting the tensor descriptor from the XeGPU dialect to an appropriate lower-level construct that operates on scalars instead of vectors. This way we keep all the promises of XeGPU intact and avoid the somewhat dubious type issues. The downside, of course, is that we are putting what is essentially a transformation into a conversion, which is also not ideal.
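To make the idea of what a lane "owns" more concrete, the sketch below enumerates the elements of a 2D tile assigned to one lane. The mapping is an assumption made only for illustration (`wi_data = [1, 1]`, elements dealt out to lanes round-robin along each dimension) and is not taken from the XeGPU specification; `ownedElements` is a hypothetical helper.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical illustration, NOT the XeGPU specification: lists the (row, col)
// elements of a rows x cols tile owned by lane `laneId`, ASSUMING
// wi_data = [1, 1] and a round-robin assignment where element (i, j) belongs
// to lane (i % layoutRows) * layoutCols + (j % layoutCols).
std::vector<std::pair<int64_t, int64_t>>
ownedElements(int64_t rows, int64_t cols, int64_t layoutRows,
              int64_t layoutCols, int64_t laneId) {
  std::vector<std::pair<int64_t, int64_t>> owned;
  for (int64_t i = 0; i < rows; ++i)
    for (int64_t j = 0; j < cols; ++j)
      if ((i % layoutRows) * layoutCols + (j % layoutCols) == laneId)
        owned.emplace_back(i, j);
  return owned;
}
```

Under this assumed mapping, lane 3 of an `8x16` tile with `wi_layout = [1, 16]` owns the 8 elements of column 3, and each lane of a `24x32` tile owns 48 elements (two columns), the same element count as a `vector<24x2xf16>`. Approach 1 would give each lane a descriptor describing only its share, whereas approach 2 keeps the full-tile descriptor next to the per-lane vector.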

My personal opinion is that we should take approach 1 and modify it so that the creation of the tensor descriptor survives distribution, while instructions use views into the descriptor. That way the IR describes "what a logical thread does" and there are no type mismatches (e.g., by introducing an XeGPU op that creates a view into a tensor descriptor, which would work for both the ND and the scattered case in the distribution logic). That said, I'm OK with landing this patch to experiment with the second approach, as it doesn't break anything.

@Jianhui-Li, @charithaintc, @chencha3, @adam-smnk, @rengolin

---
Full diff: https://github.com/llvm/llvm-project/pull/120566.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+1-2) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+53-20) 
- (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+24) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+6-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5910aa3f7f2dae..f3ffbd0f5a027d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
-                                       AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "stores a n-D block register region back to memory, currently only supports 2D";
 
   let description = [{
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 9d3c4366a7bd50..721cba70520758 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -73,6 +73,29 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
 }
 
+// Validation of nd instruction arguments succeeds if either of these is
+// true:
+// - the tensor descriptor and the value vector shapes match exactly.
+// - the tensor descriptor has an sg_map attribute and the distributed vector
+//   shape matches the tensor descriptor shape when scaled by the sg_map
+//   wi_layout factors in each dimension.
+static bool isArgShapesValid(ArrayRef<int64_t> descShape,
+                             ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
+  if (descShape == valShape)
+    return true;
+
+  if (!sgMap)
+    return false;
+
+  for (const auto &[factor, dim, expected] :
+       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+    if (factor * dim != expected)
+      return false;
+  }
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
@@ -210,13 +233,13 @@ LogicalResult PrefetchNdOp::verify() {
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return success();
 }
@@ -238,13 +261,13 @@ LogicalResult LoadNdOp::verify() {
     return emitOpError("Invalid result, it should be a VectorType.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto array_len = tdescTy.getArrayLength();
   auto tdescShape = getShapeOf(tdescTy);
@@ -280,8 +303,9 @@ LogicalResult LoadNdOp::verify() {
     auto it = tdescShape.begin();
     tdescShape.insert(it, array_len);
   }
+  auto sgMap = tdescTy.getSGMapAttr();
 
-  if (tdescShape != valueShape)
+  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
     return emitOpError() << "Result shape doesn't match TensorDesc shape."
                          << "The expected shape is " << makeString(tdescShape)
                          << ". But the given shape is "
@@ -303,17 +327,26 @@ LogicalResult StoreNdOp::verify() {
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!valTy)
-    return emitOpError("Exepcting a VectorType result.\n");
+    return emitOpError("Expecting a VectorType result.\n");
 
   if (!isWriteHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isWriteHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isWriteHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  auto tdescShape = getShapeOf(dstTy);
+  auto valueShape = getShapeOf(valTy);
+  auto sgMap = dstTy.getSGMapAttr();
 
+  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
+    return emitOpError() << "Result shape doesn't match TensorDesc shape."
+                         << "The expected shape is " << makeString(tdescShape)
+                         << ". But the given shape is "
+                         << makeString(valueShape) << ".\n";
   return success();
 }
 
@@ -423,13 +456,13 @@ LogicalResult PrefetchOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return success();
 }
@@ -446,13 +479,13 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto tdescElemTy = tdescTy.getElementType();
   auto valueElemTy = getElementType();
@@ -490,13 +523,13 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isWriteHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isWriteHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isWriteHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index a4587faa3345cb..d7174a489888a4 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -86,6 +86,17 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
   gpu.return
 }
 
+// load_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -108,6 +119,19 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
+// store_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
+   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16>
+  %1 = arith.constant dense<1.0>: vector<24x2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f8a0d95bd70a27..155131ba9e6d50 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -32,7 +32,7 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
 // -----
 func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<8x16xf16>
   return
 }
@@ -51,7 +51,7 @@ func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) {
 // -----
 func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>
       : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
   return
@@ -81,7 +81,7 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
 func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
   %1 = arith.constant dense<1.0>: vector<24x32xf16>
   %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<streaming>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16>
   return
 }
@@ -147,7 +147,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
           -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   return
 }
@@ -168,7 +168,7 @@ func.func @test_load_gather_vc_2(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
         -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>}>
         : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
           -> vector<4x2xf32>
@@ -193,7 +193,7 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
   %1 = arith.constant dense<2.9>: vector<4x2xf32>
   %2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
               -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
           !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   return

``````````

</details>
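The relaxed check can also be restated outside MLIR. The sketch below mirrors the logic of `isArgShapesValid` from the diff, using `std::vector` in place of `ArrayRef` and `SGMapAttr`; the name `checkArgShapes` and the `std::optional` encoding of a missing `sg_map` are assumptions of this sketch. One difference is deliberate: it compares ranks before scaling. `llvm::zip_equal` asserts that its ranges have equal lengths, so with the patch as written a rank mismatch (for example, a `load_nd` whose expected shape gained a leading `array_length` dimension) could hit that assertion in an assertions-enabled build instead of producing a verifier error.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Standalone sketch of the shape rule in this patch (not MLIR code).
// `wiLayout` is std::nullopt when the descriptor carries no sg_map.
bool checkArgShapes(const std::vector<int64_t> &descShape,
                    const std::vector<int64_t> &valShape,
                    const std::optional<std::vector<int64_t>> &wiLayout) {
  // Exact match: the non-distributed (subgroup-level) case.
  if (descShape == valShape)
    return true;
  // Without an sg_map there is nothing to scale by.
  if (!wiLayout)
    return false;
  // Rank guard: return false instead of relying on equal-length ranges.
  if (wiLayout->size() != descShape.size() ||
      valShape.size() != descShape.size())
    return false;
  // Distributed case: each value dim times its wi_layout factor must equal
  // the corresponding descriptor dim.
  for (std::size_t i = 0; i < descShape.size(); ++i)
    if ((*wiLayout)[i] * valShape[i] != descShape[i])
      return false;
  return true;
}
```

For the new tests, `8x16` against `8x1` and `24x32` against `24x2` with `wi_layout = [1, 16]` are accepted, while a rank-3 shape checked against a rank-2 `wi_layout` simply returns `false`.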


https://github.com/llvm/llvm-project/pull/120566

