[Mlir-commits] [mlir] [MLIR][XeGPU] Allow some nd ops to have argument shapes mismatch for … (PR #120566)

Petr Kurapov llvmlistbot at llvm.org
Tue Jan 14 04:09:18 PST 2025


https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/120566

>From b1cef75c02ee0e62beb962eee9b6a6396c29f913 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 19 Dec 2024 11:35:39 +0000
Subject: [PATCH 1/3] [MLIR][XeGPU] Allow some nd ops to have argument shapes
 mismatch for the distributed IR case.

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  3 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 73 ++++++++++++++-----
 mlir/test/Dialect/XeGPU/XeGPUOps.mlir         | 24 ++++++
 mlir/test/Dialect/XeGPU/invalid.mlir          | 12 +--
 4 files changed, 84 insertions(+), 28 deletions(-)
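
The change, in short: the load_nd/store_nd verifiers now accept a vector
operand/result that describes only the per-work-item fragment of the tensor
descriptor, i.e. the shapes may differ as long as scaling each vector
dimension by the corresponding sg_map wi_layout factor reproduces the
descriptor shape. Below is a minimal standalone C++ sketch of that rule, with
hypothetical helper and parameter names and no MLIR types; the real check is
the isArgShapesValid helper added in XeGPUOps.cpp below, and the rank guard
mirrors the second commit in this series.

  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // Accept the value shape if it equals the descriptor shape, or if the
  // descriptor is distributed (carries an sg_map) and scaling each value
  // dimension by the matching wi_layout factor gives back the descriptor
  // shape.
  bool argShapesCompatible(const std::vector<int64_t> &descShape,
                           const std::vector<int64_t> &valShape,
                           const std::vector<int64_t> &wiLayout) {
    if (descShape == valShape)
      return true;                 // exact match: non-distributed case
    if (wiLayout.empty())
      return false;                // no sg_map: only exact matches allowed
    if (valShape.size() != descShape.size() ||
        wiLayout.size() != descShape.size())
      return false;                // rank mismatch cannot be reconciled
    for (std::size_t i = 0; i < descShape.size(); ++i)
      if (wiLayout[i] * valShape[i] != descShape[i])
        return false;              // this dimension does not scale back
    return true;
  }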

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5910aa3f7f2dae..f3ffbd0f5a027d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
-                                       AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "stores a n-D block register region back to memory, currently only supports 2D";
 
   let description = [{
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 9d3c4366a7bd50..721cba70520758 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -73,6 +73,29 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
 }
 
+// Validation of nd instruction arguments succeeds if any of the following
+// holds:
+// - the tensor descriptor and the output vector shapes match exactly.
+// - the tensor descriptor has an sg_map attribute, and the distributed vector
+//   shape matches the tensor descriptor shape when scaled by the sg_map
+//   factors on each dimension.
+static bool isArgShapesValid(ArrayRef<int64_t> descShape,
+                             ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
+  if (descShape == valShape)
+    return true;
+
+  if (!sgMap)
+    return false;
+
+  for (const auto &[factor, dim, expected] :
+       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+    if (factor * dim != expected)
+      return false;
+  }
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
@@ -210,13 +233,13 @@ LogicalResult PrefetchNdOp::verify() {
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return success();
 }
@@ -238,13 +261,13 @@ LogicalResult LoadNdOp::verify() {
     return emitOpError("Invalid result, it should be a VectorType.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto array_len = tdescTy.getArrayLength();
   auto tdescShape = getShapeOf(tdescTy);
@@ -280,8 +303,9 @@ LogicalResult LoadNdOp::verify() {
     auto it = tdescShape.begin();
     tdescShape.insert(it, array_len);
   }
+  auto sgMap = tdescTy.getSGMapAttr();
 
-  if (tdescShape != valueShape)
+  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
     return emitOpError() << "Result shape doesn't match TensorDesc shape."
                          << "The expected shape is " << makeString(tdescShape)
                          << ". But the given shape is "
@@ -303,17 +327,26 @@ LogicalResult StoreNdOp::verify() {
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!valTy)
-    return emitOpError("Exepcting a VectorType result.\n");
+    return emitOpError("Expecting a VectorType result.\n");
 
   if (!isWriteHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isWriteHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isWriteHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  auto tdescShape = getShapeOf(dstTy);
+  auto valueShape = getShapeOf(valTy);
+  auto sgMap = dstTy.getSGMapAttr();
 
+  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
+    return emitOpError() << "Result shape doesn't match TensorDesc shape."
+                         << "The expected shape is " << makeString(tdescShape)
+                         << ". But the given shape is "
+                         << makeString(valueShape) << ".\n";
   return success();
 }
 
@@ -423,13 +456,13 @@ LogicalResult PrefetchOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return success();
 }
@@ -446,13 +479,13 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isReadHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isReadHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto tdescElemTy = tdescTy.getElementType();
   auto valueElemTy = getElementType();
@@ -490,13 +523,13 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isWriteHintOrNone(getL1HintAttr()))
-    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
   if (!isWriteHintOrNone(getL2HintAttr()))
-    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
   if (!isWriteHintOrNone(getL3HintAttr()))
-    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index a4587faa3345cb..d7174a489888a4 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -86,6 +86,17 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
   gpu.return
 }
 
+// load_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -108,6 +119,19 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
+// store_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16>
+  %1 = arith.constant dense<1.0>: vector<24x2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f8a0d95bd70a27..155131ba9e6d50 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -32,7 +32,7 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
 // -----
 func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<8x16xf16>
   return
 }
@@ -51,7 +51,7 @@ func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) {
 // -----
 func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>
       : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
   return
@@ -81,7 +81,7 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
 func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
   %1 = arith.constant dense<1.0>: vector<24x32xf16>
   %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<streaming>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16>
   return
 }
@@ -147,7 +147,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
           -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   return
 }
@@ -168,7 +168,7 @@ func.func @test_load_gather_vc_2(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
         -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>}>
         : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
           -> vector<4x2xf32>
@@ -193,7 +193,7 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
   %1 = arith.constant dense<2.9>: vector<4x2xf32>
   %2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
               -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+  // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
           !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   return
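
As a quick cross-check of the shapes exercised by the new tests in this first
patch, here is a small self-contained C++ snippet that replays the same
arithmetic (the fits lambda is an illustrative stand-in for the verifier
logic, not the MLIR implementation):

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  int main() {
    // True when the value shape matches the descriptor shape exactly, or
    // scales back to it via the per-dimension wi_layout factors.
    auto fits = [](std::vector<int64_t> desc, std::vector<int64_t> val,
                   std::vector<int64_t> layout) {
      if (desc == val)
        return true;
      if (val.size() != desc.size() || layout.size() != desc.size())
        return false;
      for (std::size_t i = 0; i < desc.size(); ++i)
        if (layout[i] * val[i] != desc[i])
          return false;
      return true;
    };
    // XeGPUOps.mlir, test_load_nd_vc_3: desc 8x16 with wi_layout [1, 16]
    // distributes to vector<8x1> per work item.
    assert(fits({8, 16}, {8, 1}, {1, 16}));
    // XeGPUOps.mlir, test_store_nd_vc_3: desc 24x32 distributes to vector<24x2>.
    assert(fits({24, 32}, {24, 2}, {1, 16}));
    // invalid.mlir, test_load_nd_vc_4: vector<8x2> against desc 8x16 is
    // rejected, since 16 * 2 != 16.
    assert(!fits({8, 16}, {8, 2}, {1, 16}));
    return 0;
  }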

>From 103db33de46b3fdf2957043cd629751ce397ccc9 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 14 Jan 2025 10:13:55 +0000
Subject: [PATCH 2/3] Validate tensor descriptor rank against a value rank

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 721cba70520758..d34f0c8bf9e2c5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -87,6 +87,9 @@ static bool isArgShapesValid(ArrayRef<int64_t> descShape,
   if (!sgMap)
     return false;
 
+  if (valShape.size() != descShape.size())
+    return false;
+
   for (const auto &[factor, dim, expected] :
        llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
     if (factor * dim != expected)
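
A note on why the rank guard matters here: as far as I can tell, llvm::zip_equal
requires all zipped ranges to have the same length, so without this early check
a value whose rank differs from the descriptor's (for example a vector<16xf32>
verified against a rank-2 tensor_desc with an sg_map) would trip that
constraint inside the loop instead of being rejected with the intended
verifier error.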

>From dfca5d62ba6dcd97d8ad540db141fd4446b19570 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 14 Jan 2025 12:09:05 +0000
Subject: [PATCH 3/3] Add invalid IR test for distribution type mismatch

---
 mlir/test/Dialect/XeGPU/invalid.mlir | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 155131ba9e6d50..b9255dc49d835d 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -77,6 +77,15 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
   return
 }
 
+// -----
+func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+  return
+}
+
 // -----
 func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
   %1 = arith.constant dense<1.0>: vector<24x32xf16>


