[Mlir-commits] [mlir] fa6f88a - [MLIR][XeGPU] Allow some nd ops to have argument shapes mismatch for the distributed IR case (#120566)
llvmlistbot at llvm.org
Wed Jan 22 09:03:40 PST 2025
Author: Petr Kurapov
Date: 2025-01-22T18:03:36+01:00
New Revision: fa6f88af102cb79a0371725b487e929cb0bcfcb2
URL: https://github.com/llvm/llvm-project/commit/fa6f88af102cb79a0371725b487e929cb0bcfcb2
DIFF: https://github.com/llvm/llvm-project/commit/fa6f88af102cb79a0371725b487e929cb0bcfcb2.diff
LOG: [MLIR][XeGPU] Allow some nd ops to have argument shapes mismatch for the distributed IR case (#120566)
This patch allows `load_nd` and `store_nd` to preserve the tensor
descriptor shape during distribution to SIMT. The verifier now expects
the distributed operation to retain the `sg_map` attribute and uses it
to check shape consistency.
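In a nutshell (a minimal sketch based on the tests added below; the memref shape
and operand names are placeholders): with wi_layout = [1, 16] and wi_data = [1, 1],
each work item keeps a value whose dimensions are the descriptor dimensions divided
by the corresponding wi_layout factor, so an 8x16 descriptor distributes to
vector<8x1xf32>:

  %td = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
  // Accepted by the new check: 1 * 8 == 8 and 16 * 1 == 16 reproduce the
  // descriptor shape from the distributed vector shape.
  %v = xegpu.load_nd %td : !xegpu.tensor_desc<8x16xf32,
         #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>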
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
mlir/test/Dialect/XeGPU/XeGPUOps.mlir
mlir/test/Dialect/XeGPU/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index f5cf3dad75d9c2..a2bfa721f2515b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
let hasVerifier = 1;
}
-def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
- AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
let summary = "stores a n-D block register region back to memory, currently only supports 2D";
let description = [{
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 9d3c4366a7bd50..15c435f1fa257b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -73,6 +73,39 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
+// Validation of nd instruction arguments is successful if any of these are
+// true:
+// - tensor descriptor and the output vector shapes exactly match.
+// - tensor descriptor has a sg_map attribute and the distributed vector shape
+// matches the tensor descriptor shape when scaled using sg_map factors on
+// each dimension.
+static bool isArgShapesValid(ArrayRef<int64_t> descShape,
+ ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
+ if (descShape == valShape) {
+ if (!sgMap)
+ return true;
+
+ // This can be relaxed if necessary by supporting non-2d shapes distribution;
+ // until the constraints are defined, this lives here instead of in the
+ // tensor descriptor type.
+ return valShape.size() == sgMap.getWiLayout().size();
+ }
+
+ if (!sgMap)
+ return false;
+
+ if (valShape.size() != descShape.size())
+ return false;
+
+ for (const auto &[factor, dim, expected] :
+ llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+ if (factor * dim != expected)
+ return false;
+ }
+
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -210,13 +243,13 @@ LogicalResult PrefetchNdOp::verify() {
return emitOpError("Expects a non-scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return success();
}
@@ -238,13 +271,13 @@ LogicalResult LoadNdOp::verify() {
return emitOpError("Invalid result, it should be a VectorType.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto array_len = tdescTy.getArrayLength();
auto tdescShape = getShapeOf(tdescTy);
@@ -280,8 +313,9 @@ LogicalResult LoadNdOp::verify() {
auto it = tdescShape.begin();
tdescShape.insert(it, array_len);
}
+ auto sgMap = tdescTy.getSGMapAttr();
- if (tdescShape != valueShape)
+ if (!isArgShapesValid(tdescShape, valueShape, sgMap))
return emitOpError() << "Result shape doesn't match TensorDesc shape."
<< "The expected shape is " << makeString(tdescShape)
<< ". But the given shape is "
@@ -303,17 +337,26 @@ LogicalResult StoreNdOp::verify() {
return emitOpError("Expects a non-scattered TensorDesc.\n");
if (!valTy)
- return emitOpError("Exepcting a VectorType result.\n");
+ return emitOpError("Expecting a VectorType result.\n");
if (!isWriteHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isWriteHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isWriteHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+ auto tdescShape = getShapeOf(dstTy);
+ auto valueShape = getShapeOf(valTy);
+ auto sgMap = dstTy.getSGMapAttr();
+ if (!isArgShapesValid(tdescShape, valueShape, sgMap))
+ return emitOpError() << "Result shape doesn't match TensorDesc shape."
+ << "The expected shape is " << makeString(tdescShape)
+ << ". But the given shape is "
+ << makeString(valueShape) << ".\n";
return success();
}
@@ -423,13 +466,13 @@ LogicalResult PrefetchOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
return success();
}
@@ -446,13 +489,13 @@ LogicalResult LoadGatherOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isReadHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isReadHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isReadHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto tdescElemTy = tdescTy.getElementType();
auto valueElemTy = getElementType();
@@ -490,13 +533,13 @@ LogicalResult StoreScatterOp::verify() {
return emitOpError("Expects a scattered TensorDesc.\n");
if (!isWriteHintOrNone(getL1HintAttr()))
- return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+ return emitOpError("invalid l1_hint: ") << getL1HintAttr();
if (!isWriteHintOrNone(getL2HintAttr()))
- return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+ return emitOpError("invalid l2_hint: ") << getL2HintAttr();
if (!isWriteHintOrNone(getL3HintAttr()))
- return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+ return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto maskTy = getMaskType();
auto valueTy = getValueType();
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index a4587faa3345cb..d7174a489888a4 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -86,6 +86,17 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
gpu.return
}
+// load_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+ gpu.return
+}
+
// CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -108,6 +119,19 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
gpu.return
}
+// store_nd args may have different shapes, validated against sg_map
+// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16>
+ %1 = arith.constant dense<1.0>: vector<24x2xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+ !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f8a0d95bd70a27..7816bff0582f81 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -32,7 +32,7 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
// -----
func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<8x16xf16>
return
}
@@ -51,7 +51,7 @@ func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) {
// -----
func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
%2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<write_back>}>
: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
return
@@ -77,11 +77,29 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
return
}
+// -----
+func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+ return
+}
+
+// -----
+func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
+ %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+ return
+}
+
// -----
func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
%1 = arith.constant dense<1.0>: vector<24x32xf16>
%2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>
- // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+ // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<streaming>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16>
return
}
@@ -147,7 +165,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
return
}
@@ -168,7 +186,7 @@ func.func @test_load_gather_vc_2(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
%1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
+ // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<write_back>}}
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>}>
: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
-> vector<4x2xf32>
@@ -193,7 +211,7 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
%1 = arith.constant dense<2.9>: vector<4x2xf32>
%2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
-> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
+ // expected-error at +1 {{invalid l1_hint: #xegpu.cache_hint<streaming>}}
xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
!xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
return