[Mlir-commits] [mlir] [MLIR][XeGPU] Refine XeGPU definitions (PR #100763)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 26 08:34:38 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

<details>
<summary>Changes</summary>

This PR makes the following changes/fixes to the XeGPU definitions:
- Fix the type print format for atomic_rmw
- Remove 2D support for MaskType
- Update the LoadNd definition
   - Add 1D TensorDesc support
   - Replace the vnni_axis attribute with the packed attribute
- Update the DPAS op definition, limiting A to a 2D vector and B to either a 2D or 3D vector.

---
Full diff: https://github.com/llvm/llvm-project/pull/100763.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+9-12) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+1-1) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+25-19) 
- (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+15-7) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+11-17) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e477d9a0ca3f1..c50c55060a319 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -245,8 +245,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
 }
 
 
-def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "TensorDesc"]>,
-                                         AllElementCountsMatch<["value", "TensorDesc"]>]> {
+def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "loads a n-D block from memory (represented by TensorDesc)"
                 "to registers (represented by vector)";
   let description = [{
@@ -275,7 +274,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
-                       OptionalAttr<I64Attr>: $vnni_axis,
+                       OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -668,14 +667,12 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
     data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
     and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS
     also requires A and B to be loaded with the required data layout. Specially,
-    VNNI layout is required for B operand. It is achieved via setting `vnni_axis = 0`
-    of the corresponding `load_nd` operator. To keep both operands as 3D vector,
-    operand A is loaded via setting `vnni_axis = 1` without impacting the
-    physical layouts change in register. Due to the VNNI transformation, A and B operands
-    are represented as 3D vector, with the last dimension representing the VNNI factor,
-    which is computed as `32/bit_width_of_elem_type`. Therefore, `A: vector<8x16xf16>`
-    is represented as `A: vector<8x8x2xf16>`, and `B: vector<16x16xf16>` is
-    represented as `B: vector<8x16x2xf16>`.
+
+    VNNI layout is required for B operand. It is achieved via adding `packed`
+    attribute to the `load_nd` operator.  Due to the VNNI transformation, B operands
+    can be represented as a 3D vector, with the last dimension representing the VNNI
+    factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>`
+    can be represented as `B: vector<8x16x2xf16>`.
 
     Note: on PVC, the hardware can perform load with VNNI transformation when data
           element type is 16-bit or lower precision, taking 2 or 4 elements from
@@ -739,7 +736,7 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", [Pure,
 
   let assemblyFormat = [{
     $kind $tensorDesc `,` $mask `,` $value attr-dict `:`
-    type($tensorDesc) `,` type($mask) `,` type($value) `->` type($result)
+    qualified(type($tensorDesc)) `,` type($mask) `,` type($value) `->` type($result)
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index bab0e4afb1e5e..111a270a28b27 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -19,7 +19,7 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_BaseAddrType: AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1, 2]>, UI64, UI32, I64, I32]>;
 def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
-def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1,2], [I1]>, I1]>;
+def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
 def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
 def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 5ef47fbbe1ce0..84ccd7f6b4326 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -182,8 +182,8 @@ LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
 
-  if (tdescTy.getRank() != 2)
-    return emitOpError("Expecting a 2D TensorDesc.\n");
+  if (tdescTy.getRank() > 2)
+    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
 
   if (tdescTy.getScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -206,17 +206,27 @@ LogicalResult LoadNdOp::verify() {
 
   if (getTranspose()) {
     auto trans = getTranspose().value();
-    if (tdescShape.size() >= trans.size())
+
+    // Make sure the transpose value is valid.
+    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
+      return t >= 0 && t < tdescTy.getRank();
+    });
+
+    if (valid)
       transpose(trans, tdescShape);
     else
       emitWarning("Invalid transpose attr. It is ignored.");
   }
 
-  if (getVnniAxis()) {
-    auto axis = getVnniAxis().value();
-    auto vnni_factor = valueShape.back();
-    tdescShape[axis] /= vnni_factor;
-    tdescShape.push_back(vnni_factor);
+  if (getPacked()) {
+    if (tdescTy.getRank() == 2) {
+      const int axis = 0;
+      auto vnni_factor = valueShape.back();
+      tdescShape[axis] /= vnni_factor;
+      tdescShape.push_back(vnni_factor);
+    } else {
+      return emitWarning("Invalid Packed Attr. It is ignored (available for 2D TensorDesc only).");
+    }
   }
 
   if (array_len > 1) {
@@ -239,8 +249,8 @@ LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
 
-  if (dstTy.getRank() != 2)
-    return emitOpError("Expecting a 2D TensorDesc.\n");
+  if (dstTy.getRank() > 2)
+    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
 
   if (dstTy.getScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -413,18 +423,14 @@ LogicalResult DpasOp::verify() {
   int64_t lhsRank = getLhsType().getRank();
   int64_t rhsRank = getRhsType().getRank();
 
-  if (lhsRank != rhsRank || lhsRank != 3)
-    return emitOpError(
-        "lhs and rhs rank does not match for dpas op, or their rank is not 3.");
-
-  if (getAcc() && getAccType() != getResultType())
-    return emitOpError("Accumulator and Result for dpas op should have the "
-                       "same type (both shape and element type).");
+  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
+    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either 2D or 3D (packed) vector.");
 
   auto lhsShape = getLhsType().getShape();
   auto rhsShape = getRhsType().getShape();
-  if (lhsShape[1] != rhsShape[0] || lhsShape[2] != rhsShape[2])
-    return emitOpError("K-dimension or vnni-factor mismatch.");
+  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+  if (bK != lhsShape[1])
+    return emitOpError("K-dimension mismatch.");
 
   return success();
 }
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 00d32d2a2ee94..4f73e41e55370 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -42,8 +42,8 @@ gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
 gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, vnni_axis = 0 : i64}> : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
-  %2 = xegpu.load_nd %1 <{vnni_axis = 0, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
        : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
   gpu.return
 }
@@ -121,10 +121,18 @@ gpu.func @test_create_update_tdesc_vc(%src: ui64) {
   gpu.return
 }
 
-// CHECK: gpu.func @test_dpas_vc(%[[arg0:.*]]: vector<8x8x2xf16>, %[[arg1:.*]]: vector<8x16x2xf16>)
-gpu.func @test_dpas_vc(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
-  // CHECK: %0 = xegpu.dpas %[[arg0]], %[[arg1]] : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
-  %1 = xegpu.dpas %a, %b: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+// CHECK: gpu.func @test_dpas_vc(%[[arg0:.*]]: vector<8x16xf16>, %[[arg1:.*]]: vector<16x16xf16>)
+gpu.func @test_dpas_vc(%a : vector<8x16xf16>, %b: vector<16x16xf16>) {
+  // CHECK: %0 = xegpu.dpas %[[arg0]], %[[arg1]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %1 = xegpu.dpas %a, %b: vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  gpu.return
+}
+
+
+// CHECK: gpu.func @test_dpas_vc_with_packed_b(%[[arg0:.*]]: vector<8x16xf16>, %[[arg1:.*]]: vector<8x16x2xf16>)
+gpu.func @test_dpas_vc_with_packed_b(%a : vector<8x16xf16>, %b: vector<8x16x2xf16>) {
+  // CHECK: %0 = xegpu.dpas %[[arg0]], %[[arg1]] : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+  %1 = xegpu.dpas %a, %b: vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
   gpu.return
 }
 
@@ -132,7 +140,7 @@ gpu.func @test_dpas_vc(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
 gpu.func @test_atomic_rmw(%src: ui64, %value : vector<16xf32>, %mask : vector<16xi1>) {
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
   %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
-  //CHECK: %[[R1:.*]] = xegpu.atomic_rmw addf %[[R0]], %[[arg2]], %[[arg1]] : <16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
+  //CHECK: %[[R1:.*]] = xegpu.atomic_rmw addf %[[R0]], %[[arg2]], %[[arg1]] : !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
   xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
   gpu.return
 }
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7819ad60b97d9..ff37f5e1cca17 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -159,29 +159,23 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
 }
 
 // -----
-func.func @test_dpas_vc_1(%a : vector<8x4x2xf16>, %b: vector<8x16x2xf16>) {
-  // expected-error at +1 {{K-dimension or vnni-factor mismatch}}
-  %1 = xegpu.dpas %a, %b : vector<8x4x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+func.func @test_dpas_vc_1(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) {
+  // expected-error at +1 {{K-dimension mismatch}}
+  %1 = xegpu.dpas %a, %b : vector<8x8xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
   return
 }
 
 // -----
-func.func @test_dpas_vc_2(%a : vector<8x16xf16>, %b: vector<8x16x2xf16>) {
-  // expected-error at +1 {{lhs and rhs rank does not match for dpas op, or their rank is not 3}}
-  %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+func.func @test_dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
+  // expected-error at +1 {{expecting lhs to be a 2D vector, and rhs to be either 2D or 3D (packed) vector}}
+  %1 = xegpu.dpas %a, %b : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
   return
 }
 
 // -----
-func.func @test_dpas_vc_3(%a : vector<8x16xf16>, %b: vector<16x16xf16>) {
-  // expected-error at +1 {{lhs and rhs rank does not match for dpas op, or their rank is not 3}}
-  %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return
-}
-
-// -----
-func.func @test_dpas_vc_4(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>, %c : vector<8x16xf16>) {
-  // expected-error at +1 {{Accumulator and Result for dpas op should have the same type}}
-  %1 = xegpu.dpas %a, %b, %c : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf16> -> vector<8x16xf32>
-  return
+func.func @test_atomic_rmw(%src: ui64, %value : vector<16x8xf32>, %mask : vector<16xi1>) {
+  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] {chunk_size = 8}: ui64 -> !xegpu.tensor_desc<16x8xf32, #xegpu.tdesc_attr<scattered = true>>
+  // expected-error at +1 {{failed to verify that all of {tensorDesc, mask, value, result} have same shape}}
+  xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16x8xf32> -> vector<16x8xf32>
+  gpu.return
 }
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/100763


More information about the Mlir-commits mailing list