[Mlir-commits] [mlir] [MLIR][XeGPU] Allow some nd ops to have argument shapes mismatch for the distributed IR case. (PR #120566)
Petr Kurapov
llvmlistbot at llvm.org
Tue Jan 14 04:09:18 PST 2025
https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/120566
>From b1cef75c02ee0e62beb962eee9b6a6396c29f913 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 19 Dec 2024 11:35:39 +0000
Subject: [PATCH 1/3] [MLIR][XeGPU] Allow some nd ops to have argument shapes
mismatch for the distributed IR case.
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 3 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 73 ++++++++++++++-----
mlir/test/Dialect/XeGPU/XeGPUOps.mlir | 24 ++++++
mlir/test/Dialect/XeGPU/invalid.mlir | 12 +--
4 files changed, 84 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5910aa3f7f2dae..f3ffbd0f5a027d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
let hasVerifier = 1;
}
-def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
- AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
let summary = "stores a n-D block register region back to memory, currently only supports 2D";
let description = [{
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 9d3c4366a7bd50..721cba70520758 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -73,6 +73,29 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
+// Validations for nd instruction arguments is successful if any of these are
+// true:
+// - tensor descriptor and the output vector shapes exactly match.
+// - tensor descriptor has a sg_map attribute and the distributed vector shape
+// matches the tensor descriptor shape when scaled using sg_map factors on
+// each dimension.
+static bool isArgShapesValid(ArrayRef<int64_t> descShape,
+ ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
+ if (descShape == valShape)
+ return true;
+
+ if (!sgMap)
+ return false;
+
+ for (const auto &[factor, dim, expected] :
+ llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+ if (factor * dim != expected)
+ return false;
+ }
+
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -210,13 +233,13 @@ LogicalResult PrefetchNdOp::verify() {
return emitOpError("Expects a non-scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return success();
}
@@ -238,13 +261,13 @@ LogicalResult LoadNdOp::verify() {
return emitOpError("Invalid result, it should be a VectorType.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto array_len = tdescTy.getArrayLength();
auto tdescShape = getShapeOf(tdescTy);
@@ -280,8 +303,9 @@ LogicalResult LoadNdOp::verify() {
auto it = tdescShape.begin();
tdescShape.insert(it, array_len);
}
+ auto sgMap = tdescTy.getSGMapAttr();
- if (tdescShape != valueShape)
+ if (!isArgShapesValid(tdescShape, valueShape, sgMap))
return emitOpError() << "Result shape doesn't match TensorDesc shape."
<< "The expected shape is " << makeString(tdescShape)
<< ". But the given shape is "
@@ -303,17 +327,26 @@ LogicalResult StoreNdOp::verify() {
return emitOpError("Expects a non-scattered TensorDesc.\n");
if (!valTy)
- return emitOpError("Exepcting a VectorType result.\n");
+ return emitOpError("Expecting a VectorType result.\n");
if (!isWriteHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isWriteHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isWriteHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+ auto tdescShape = getShapeOf(dstTy);
+ auto valueShape = getShapeOf(valTy);
+ auto sgMap = dstTy.getSGMapAttr();
+ if (!isArgShapesValid(tdescShape, valueShape, sgMap))
+ return emitOpError() << "Result shape doesn't match TensorDesc shape."
+ << "The expected shape is " << makeString(tdescShape)
+ << ". But the given shape is "
+ << makeString(valueShape) << ".\n";
return success();
}
@@ -423,13 +456,13 @@ LogicalResult PrefetchOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return success();
}
@@ -446,13 +479,13 @@ LogicalResult LoadGatherOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto tdescElemTy = tdescTy.getElementType();
auto valueElemTy = getElementType();
@@ -490,13 +523,13 @@ LogicalResult StoreScatterOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isWriteHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isWriteHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isWriteHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto maskTy = getMaskType();
auto valueTy = getValueType();
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index a4587faa3345cb..d7174a489888a4 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -86,6 +86,17 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
gpu.return
}
+// load_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ gpu.return
+}
+
// CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -108,6 +119,19 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
gpu.return
}
+// store_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16>
+ %1 = arith.constant dense<1.0>: vector<24x2xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+ !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f8a0d95bd70a27..155131ba9e6d50 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -32,7 +32,7 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
// -----
func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<8x16xf16>
return
}
@@ -51,7 +51,7 @@ func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) {
// -----
func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>
: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
return
@@ -81,7 +81,7 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
%1 = arith.constant dense<1.0>: vector<24x32xf16>
%2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<streaming>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16>
return
}
@@ -147,7 +147,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
return
}
@@ -168,7 +168,7 @@ func.func @test_load_gather_vc_2(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>}>
: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
-> vector<4x2xf32>
@@ -193,7 +193,7 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
%1 = arith.constant dense<2.9>: vector<4x2xf32>
%2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+ // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
!xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
return
>From 103db33de46b3fdf2957043cd629751ce397ccc9 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 14 Jan 2025 10:13:55 +0000
Subject: [PATCH 2/3] Validate tensor descriptor rank against a value rank
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 721cba70520758..d34f0c8bf9e2c5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -87,6 +87,9 @@ static bool isArgShapesValid(ArrayRef<int64_t> descShape,
if (!sgMap)
return false;
+ if (valShape.size() != descShape.size())
+ return false;
+
for (const auto &[factor, dim, expected] :
llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
if (factor * dim != expected)
>From dfca5d62ba6dcd97d8ad540db141fd4446b19570 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 14 Jan 2025 12:09:05 +0000
Subject: [PATCH 3/3] Add invalid IR test for distribution type mismatch
---
mlir/test/Dialect/XeGPU/invalid.mlir | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 155131ba9e6d50..b9255dc49d835d 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -77,6 +77,15 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
return
}
+// -----
+func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+ return
+}
+
// -----
func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
%1 = arith.constant dense<1.0>: vector<24x32xf16>
More information about the Mlir-commits
mailing list