[Mlir-commits] [mlir] [MLIR][XeGPU] Extend op definitions to support 3D+: dpas, dpas_mx (PR #199809)

llvmlistbot at llvm.org
Tue May 26 19:03:50 PDT 2026


llvmorg-github-actions[bot] wrote:



@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

**Summary**

  Extend the xegpu.dpas and xegpu.dpas_mx operations to support 3D and 4D operands whose leading dimensions are batch dimensions, enabling batched matrix multiplication workloads such as [4, 128, 512] x [4, 512, 128] -> [4, 128, 128].
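  To make the batched semantics concrete, here is a minimal C++ reference sketch. It is not XeGPU or MLIR code: the function name `batchedDpas`, the flat row-major storage, and the use of `float` for every element type are assumptions for illustration. It computes D[b] = A[b] x B[b] + C[b] independently for each batch index b. A 4D operand maps onto it by flattening its two batch dimensions into one, since the batches are independent.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference semantics of a batched dpas: D[b] = A[b] x B[b] + C[b] for every
// batch index b. Tensors are row-major flat vectors (illustrative layout):
//   A: [batch, M, K], B: [batch, K, N], C and D: [batch, M, N].
// A 4D shape [B0, B1, ...] is handled by passing batch = B0 * B1.
std::vector<float> batchedDpas(const std::vector<float> &a,
                               const std::vector<float> &b,
                               const std::vector<float> &c, size_t batch,
                               size_t m, size_t k, size_t n) {
  assert(a.size() == batch * m * k);
  assert(b.size() == batch * k * n);
  assert(c.size() == batch * m * n);
  std::vector<float> d(c); // start from the accumulator
  for (size_t bi = 0; bi < batch; ++bi)
    for (size_t i = 0; i < m; ++i)
      for (size_t j = 0; j < n; ++j) {
        float sum = 0.0f;
        // Only elements of batch bi participate: batches never mix.
        for (size_t p = 0; p < k; ++p)
          sum += a[(bi * m + i) * k + p] * b[(bi * k + p) * n + j];
        d[(bi * m + i) * n + j] += sum;
      }
  return d;
}
```

  For the summary's example, the call would use batch = 4, M = 128, K = 512, N = 128.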

  **Changes**

  - Type definitions: Extended XeGPU_DpasOprType and XeGPU_DpasResType to accept vectors of rank 1-4 (previously 1-3 and 1-2, respectively)
  - Op definitions: Extended the dpas_mx scale operands (scale_a, scale_b) to accept vectors of rank 1-4 (previously 1-2)
  - Verifiers: Updated verifyDpasDimensions() to validate batch dimensions across the A, B, and result operands; updated DpasMxOp::verify() to perform batch-aware scale dimension checks
  - Documentation: Added documentation explaining batch dimensions, microarchitecture-specific matrix sizes, and 3D/4D usage examples
  - Tests: Added unit tests for 3D batched dpas and dpas_mx operations, and updated the expected error messages in invalid.mlir
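  The shape rules enforced by the updated verifyDpasDimensions() can be restated as a standalone C++ function. This is a hypothetical sketch: `checkDpasShapes` and its error strings are illustrative, not the MLIR API or its diagnostics. It follows the layout described in the diff: A is [batch..., M, K], B is [batch..., K, N] or VNNI-packed [batch..., K/vnni, N, vnni], and the result is [batch..., M, N]. As in the diff, B counts as packed exactly when it has one more dimension than A.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

// Illustrative restatement of the batch-aware dpas shape rules.
// Returns "" on success, otherwise a short (hypothetical) error message.
std::string checkDpasShapes(const Shape &a, const Shape &b, const Shape &d) {
  const int64_t aRank = static_cast<int64_t>(a.size());
  const int64_t bRank = static_cast<int64_t>(b.size());
  const int64_t dRank = static_cast<int64_t>(d.size());
  // Lane-level form: all operands are 1D fragments.
  if (aRank == 1 && bRank == 1 && dRank == 1)
    return "";
  if (aRank < 2 || bRank < 2 || dRank < 2)
    return "operands must be at least 2D";
  // B is VNNI-packed when it has exactly one more dimension than A.
  const bool bPacked = bRank == aRank + 1;
  const int64_t aBatch = aRank - 2;
  const int64_t bBatch = bPacked ? bRank - 3 : bRank - 2;
  const int64_t dBatch = dRank - 2;
  if (aBatch != bBatch || aBatch != dBatch)
    return "batch rank mismatch";
  for (int64_t i = 0; i < aBatch; ++i)
    if (a[i] != b[i] || a[i] != d[i])
      return "batch dim mismatch";
  const int64_t m = a[aBatch], k = a[aBatch + 1];
  // Effective K of a packed B is (K/vnni) * vnni.
  const int64_t bK = bPacked ? b[bBatch] * b[bBatch + 2] : b[bBatch];
  const int64_t bN = b[bBatch + 1];
  if (bK != k)
    return "K mismatch";
  if (m != d[dBatch])
    return "M mismatch";
  if (bN != d[dBatch + 1])
    return "N mismatch";
  return "";
}
```

  Because XeGPU_DpasOprType caps operand rank at 4, a VNNI-packed B can carry at most one batch dimension (a 3D A paired with a 4D packed B). Under this rule the dpas_vc_2 shapes from invalid.mlir (8x8x2, 8x16x2 -> 8x16) fail the batch-rank check, consistent with the updated expected error.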
  
  Assisted-by-Claude

---
Full diff: https://github.com/llvm/llvm-project/pull/199809.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+52-11) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+2-2) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+85-32) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+4-4) 
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+21) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index fce14999b4011..c000f7637127c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -919,6 +919,21 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
     data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
     and `C/D: vector<8x16xf32>`.
 
+    The operands can be 2D, 3D, or 4D vectors. When the vectors have more than 2
+    dimensions, the leading dimensions are treated as **batch dimensions** that
+    must match across all operands (lhs, rhs, acc, and result). The batch
+    dimensions represent independent matrix multiplications that are executed in
+    parallel. For example:
+    - 2D: `A: vector<8x16xf16>`, `B: vector<16x16xf16>` -> `C: vector<8x16xf32>`
+    - 3D: `A: vector<4x8x16xf16>`, `B: vector<4x16x16xf16>` -> `C: vector<4x8x16xf32>`
+      (4 independent 8x16 matrix multiplications)
+    - 4D: `A: vector<2x4x8x16xf16>`, `B: vector<2x4x16x16xf16>` -> `C: vector<2x4x8x16xf32>`
+      (2x4=8 independent 8x16 matrix multiplications)
+
+    The last 2 dimensions are always the core matrix multiplication dimensions
+    (M, K for lhs; K, N for rhs; M, N for result). Note that rhs (B) can have
+    one extra trailing dimension for VNNI packing (e.g., `vector<4x8x16x2xf16>`).
+
     In lane level code, each lane from a subgroup holds a data fragment for A, B, C and the result,
     which are represented as 1D vectors. Please refer to [OpenCL Intel extentions]
     (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html)
@@ -930,19 +945,21 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
     Arguments:
 
     - `lhs`: A vector value representing the left-hand-side matrix tile (A) participating in the
-      matrix multiply.
+      matrix multiply. Can be 1D, 2D, 3D, or 4D where leading dimensions are batch dimensions.
 
-    - `rhs`: A vector value representing the right-hand-side matrix tile (B).
+    - `rhs`: A vector value representing the right-hand-side matrix tile (B). Can be 1D, 2D, 3D, or 4D
+      where leading dimensions are batch dimensions, plus an optional trailing dimension for VNNI packing.
 
     - `acc`: [optional] A vector value representing the accumulator matrix tile (C). When present, the
       result is computed as `lhs * rhs + acc`; otherwise, the accumulator is implicitly assumed to be zero.
+      Must have the same batch dimensions as lhs and rhs.
 
     - `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify this
       operation as anchor for operands A, B, and the accumulator/result, enabling users to assign layouts
       that govern distribution at the subgroup and/or lane level. Only valid at workgroup and subgroup
       level.
 
-    Example 1 (Workgroup level):
+    Example 1 (Workgroup level, 2D):
 
     ```mlir
       %d = xegpu.dpas %a, %b, %c <{
@@ -952,12 +969,20 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
           : vector<64x128xf16>, vector<128x128xf16>, vector<64x128xf32> -> vector<64x128xf32>
     ```
 
-    Example 2 (Lane level):
+    Example 2 (Lane level, 1D):
 
     ```mlir
       %d = xegpu.dpas %a, %b, %c
             :  vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
     ```
+
+    Example 3 (Workgroup level, 3D with batch):
+
+    ```mlir
+      // 4 independent 8x16 matrix multiplications
+      %d = xegpu.dpas %a, %b, %c
+          : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32>
+    ```
   }];
 
   let arguments = (ins
@@ -1428,6 +1453,21 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
     size, B is of size `kxn`, and accumulate on matrix C of size `mxn` to the same size
     matrix.
 
+    The operands can be 2D, 3D, or 4D vectors. When the vectors have more than 2
+    dimensions, the leading dimensions are treated as **batch dimensions** that
+    must match across all operands (a, b, acc, result, scale_a, scale_b). The batch
+    dimensions represent independent matrix multiplications that are executed in
+    parallel. For example:
+    - 2D: `A: vector<8x32xf8E5M2>`, `B: vector<32x16xf8E5M2>` -> `C: vector<8x16xbf16>`
+    - 3D: `A: vector<4x8x32xf8E5M2>`, `B: vector<4x32x16xf8E5M2>` -> `C: vector<4x8x16xbf16>`
+      (4 independent 8x16 scaled matrix multiplications)
+    - 4D: `A: vector<2x4x8x32xf8E5M2>`, `B: vector<2x4x32x16xf8E5M2>` -> `C: vector<2x4x8x16xbf16>`
+      (2x4=8 independent 8x16 scaled matrix multiplications)
+
+    The last 2 dimensions are always the core matrix multiplication dimensions
+    (M, K for a; K, N for b; M, N for result). The scale vectors also follow the
+    same batch dimension structure.
+
     In lane level code, each lane from a subgroup holds a data fragment for A, B, Acc and the result,
     which are represented as 1D vectors.
 
@@ -1437,18 +1477,19 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
     Arguments:
 
     - `a`: A vector value representing the left-hand-side matrix tile (A) participating in the
-      matrix multiply.
+      matrix multiply. Can be 1D, 2D, 3D, or 4D where leading dimensions are batch dimensions.
 
-    - `b`: A vector value representing the right-hand-side matrix tile (B).
+    - `b`: A vector value representing the right-hand-side matrix tile (B). Can be 1D, 2D, 3D, or 4D
+      where leading dimensions are batch dimensions.
 
     - `acc`: A vector value representing the accumulator matrix tile (C). The
-      result is computed as `a * b + acc`.
+      result is computed as `a * b + acc`. Must have the same batch dimensions as a and b.
 
     - `scale_a`: A floating point vector/scalar value used to scale `a` for
-      matrix multiplication.
+      matrix multiplication. When a vector, must have matching batch dimensions.
 
     - `scale_b`: A floating point vector/scalar value used to scale `b` for
-      matrix multiplication.
+      matrix multiplication. When a vector, must have matching batch dimensions.
 
     - `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify this
       operation as anchor for operands A, B, and the accumulator/result, enabling users to assign layouts
@@ -1460,9 +1501,9 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
   let arguments = (ins XeGPU_DpasOprType:$a, XeGPU_DpasOprType:$b,
       Optional<XeGPU_DpasResType>:$acc,
       Optional<AnyTypeOf<[F8E8M0FNU,
-                          VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_a,
+                          VectorOfRankAndType<[1, 2, 3, 4], [F8E8M0FNU]>]>>:$scale_a,
       Optional<AnyTypeOf<[F8E8M0FNU,
-                          VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_b,
+                          VectorOfRankAndType<[1, 2, 3, 4], [F8E8M0FNU]>]>>:$scale_b,
       OptionalAttr<DistributeLayoutAttr>:$layout_a,
       OptionalAttr<DistributeLayoutAttr>:$layout_b,
       OptionalAttr<DistributeLayoutAttr>:$layout_cd,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 95a3e22cf803b..0423303c23493 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
     : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3, 4], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2, 3, 4], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: VectorOfRankAndType<[1,2,3,4,5,6,7,8], [XeGPU_ScalarType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 56d340141ee8f..d55a44f2279d9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -720,31 +720,78 @@ static LogicalResult verifyDpasDimensions(Operation *op,
   if (aRank == 1 && bRank == 1 && resRank == 1)
     return success();
 
-  // Validate A and B are 2D
-  if (aRank != 2)
-    return op->emitOpError("A operand must be a 2D vector.");
-  if (bRank < 2 || bRank > 3)
-    return op->emitOpError("B operand must be a 2D or 3D vector.");
-  if (resRank != 2)
-    return op->emitOpError("Result must be a 2D vector.");
+  // A must be at least 2D, B must be 2D or 3D (innermost dims), result at
+  // least 2D.
+  if (aRank < 2)
+    return op->emitOpError("A operand must be at least a 2D vector.");
+  if (bRank < 2)
+    return op->emitOpError("B operand must be at least a 2D vector.");
+  if (resRank < 2)
+    return op->emitOpError("Result must be at least a 2D vector.");
+
+  // Determine batch dimensions. For A[batch..., M, K], B[batch..., K, N] (or
+  // B[batch..., K/vnni, N, vnni]), result[batch..., M, N].
+  // B may have one extra trailing dim for VNNI packing (3D innermost).
+  // Determine how many trailing dims are the "core" matmul dims.
+  // A core dims: last 2 (M, K)
+  // B core dims: last 2 (K, N) or last 3 (K/vnni, N, vnni) for packed
+  // Result core dims: last 2 (M, N)
+  int64_t aBatchRank = aRank - 2;
+  int64_t resBatchRank = resRank - 2;
+
+  // B can have an extra trailing dim for VNNI packing. Determine B's batch
+  // rank: if bRank > aRank, the extra dim is the VNNI packing dim.
+  bool bPacked = (bRank == aRank + 1);
+  int64_t bBatchRank = bPacked ? bRank - 3 : bRank - 2;
+
+  // Batch ranks must match across A, B, and result.
+  if (aBatchRank != bBatchRank || aBatchRank != resBatchRank)
+    return op->emitOpError("Batch dimension rank mismatch among A, B, and "
+                           "result.");
+
+  // Verify batch dimensions match.
+  for (int64_t i = 0; i < aBatchRank; ++i) {
+    if (aShape[i] != resShape[i])
+      return op->emitOpError("Batch dimension mismatch at dim ")
+             << i << ": A has " << aShape[i] << " but result has "
+             << resShape[i] << ".";
+    if (aShape[i] != bShape[i])
+      return op->emitOpError("Batch dimension mismatch at dim ")
+             << i << ": A has " << aShape[i] << " but B has " << bShape[i]
+             << ".";
+  }
 
-  // Calculate effective K dimension for B (handle 3D packed case)
-  int64_t bK = bRank == 3 ? bShape[0] * bShape[2] : bShape[0];
+  // Now verify the core matmul dimensions (last 2 of A and result, last 2 or
+  // 3 of B).
+  int64_t aM = aShape[aBatchRank];
+  int64_t aK = aShape[aBatchRank + 1];
+  int64_t resM = resShape[resBatchRank];
+  int64_t resN = resShape[resBatchRank + 1];
+
+  // Calculate effective K dimension for B (handle packed case)
+  int64_t bK, bN;
+  if (bPacked) {
+    bK = bShape[bBatchRank] * bShape[bBatchRank + 2];
+    bN = bShape[bBatchRank + 1];
+  } else {
+    bK = bShape[bBatchRank];
+    bN = bShape[bBatchRank + 1];
+  }
 
   // Verify K dimension match between A and B
-  if (bK != aShape[1])
+  if (bK != aK)
     return op->emitOpError("K-dimension mismatch: A has K=")
-           << aShape[1] << " but B has K=" << bK << ".";
+           << aK << " but B has K=" << bK << ".";
 
   // Verify M dimension match between A and result
-  if (aShape[0] != resShape[0])
+  if (aM != resM)
     return op->emitOpError("M-dimension mismatch: A has M=")
-           << aShape[0] << " but result has M=" << resShape[0] << ".";
+           << aM << " but result has M=" << resM << ".";
 
   // Verify N dimension match between B and result
-  if (bShape[1] != resShape[1])
+  if (bN != resN)
     return op->emitOpError("N-dimension mismatch: B has N=")
-           << bShape[1] << " but result has N=" << resShape[1] << ".";
+           << bN << " but result has N=" << resN << ".";
 
   return success();
 }
@@ -907,6 +954,9 @@ LogicalResult DpasMxOp::verify() {
   if (failed(verifyDpasDimensions(*this, aShape, bShape, resShape)))
     return failure();
 
+  // Determine batch rank from A operand.
+  int64_t aBatchRank = aShape.size() - 2;
+
   // Validate scale_a if present
   if (getScaleA()) {
     auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
@@ -914,19 +964,20 @@ LogicalResult DpasMxOp::verify() {
     if (scaleAVecType && scaleAVecType.getRank() > 1) {
       auto scaleAShape = scaleAVecType.getShape();
 
-      if (scaleAVecType.getRank() != 2)
-        return emitOpError("Scale A must be a 2D vector when not a scalar.");
+      if (scaleAVecType.getRank() < 2)
+        return emitOpError("Scale A must be at least a 2D vector when not a "
+                           "scalar.");
 
       // Verify layout distributability for scale_a
       if (failed(verifyLayoutDistributable(*this, getLayoutAScale(),
                                            scaleAShape, "ScaleA")))
         return failure();
 
-      // Validate M dimension: scale_a[0] must match a[0]
-      if (scaleAShape[0] != aShape[0])
+      // Validate M dimension: scale_a's M must match A's M (last-1 dim)
+      if (scaleAShape[scaleAShape.size() - 2] != aShape[aBatchRank])
         return emitOpError("Scale A M dimension [")
-               << scaleAShape[0] << "] must match A M dimension [" << aShape[0]
-               << "].";
+               << scaleAShape[scaleAShape.size() - 2]
+               << "] must match A M dimension [" << aShape[aBatchRank] << "].";
     }
   }
 
@@ -937,19 +988,21 @@ LogicalResult DpasMxOp::verify() {
     if (scaleBVecType && scaleBVecType.getRank() > 1) {
       auto scaleBShape = scaleBVecType.getShape();
 
-      if (scaleBVecType.getRank() != 2)
-        return emitOpError("Scale B must be a 2D vector when not a scalar.");
+      if (scaleBVecType.getRank() < 2)
+        return emitOpError("Scale B must be at least a 2D vector when not a "
+                           "scalar.");
 
       // Verify layout distributability for scale_b
       if (failed(verifyLayoutDistributable(*this, getLayoutBScale(),
                                            scaleBShape, "ScaleB")))
         return failure();
 
-      // Validate N dimension: scale_b[1] must match b[1]
-      if (scaleBShape[1] != bShape[1])
+      // Validate N dimension: scale_b's N (last dim) must match B's N (last
+      // dim)
+      if (scaleBShape.back() != bShape.back())
         return emitOpError("Scale B N dimension [")
-               << scaleBShape[1] << "] must match B N dimension [" << bShape[1]
-               << "].";
+               << scaleBShape.back() << "] must match B N dimension ["
+               << bShape.back() << "].";
     }
   }
 
@@ -964,12 +1017,12 @@ LogicalResult DpasMxOp::verify() {
       auto scaleAShape = scaleAVecType.getShape();
       auto scaleBShape = scaleBVecType.getShape();
 
-      // Validate scale K dimension compatibility: scale_a[1] must match
-      // scale_b[0]
-      if (scaleAShape[1] != scaleBShape[0])
+      // Validate scale K dimension compatibility: scale_a's last dim must
+      // match scale_b's second-to-last dim
+      if (scaleAShape.back() != scaleBShape[scaleBShape.size() - 2])
         return emitOpError("Scale K dimension mismatch: scale_a has K=")
-               << scaleAShape[1] << " but scale_b has K=" << scaleBShape[0]
-               << ".";
+               << scaleAShape.back() << " but scale_b has K="
+               << scaleBShape[scaleBShape.size() - 2] << ".";
     }
   }
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 341427e7e1231..438f675ecc3f4 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -394,7 +394,7 @@ func.func @dpas_vc_1(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) {
 
 // -----
 func.func @dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
-  // expected-error @+1 {{op A operand must be a 2D vector}}
+  // expected-error @+1 {{'xegpu.dpas' op Batch dimension rank mismatch among A, B, and result}}
   %1 = xegpu.dpas %a, %b : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
   return
 }
@@ -702,21 +702,21 @@ func.func @dpas_mx_acc_result_type_mismatch(%a : vector<8x16xf8E5M2>, %b: vector
 
 // -----
 func.func @dpas_mx_a_not_2d(%a : vector<128xf8E5M2>, %b: vector<16x16xf8E5M2>) {
-  // expected-error @+1 {{A operand must be a 2D vector.}}
+  // expected-error @+1 {{A operand must be at least a 2D vector.}}
   %1 = xegpu.dpas_mx %a, %b : (vector<128xf8E5M2>, vector<16x16xf8E5M2>) -> vector<8x16xf32>
   return
 }
 
 // -----
 func.func @dpas_mx_b_not_2d(%a : vector<8x16xf8E5M2>, %b: vector<256xf8E5M2>) {
-  // expected-error @+1 {{B operand must be a 2D or 3D vector.}}
+  // expected-error @+1 {{B operand must be at least a 2D vector.}}
   %1 = xegpu.dpas_mx %a, %b : (vector<8x16xf8E5M2>, vector<256xf8E5M2>) -> vector<8x16xf32>
   return
 }
 
 // -----
 func.func @dpas_mx_result_not_2d(%a : vector<8x16xf8E5M2>, %b: vector<16x16xf8E5M2>) {
-  // expected-error @+1 {{Result must be a 2D vector.}}
+  // expected-error @+1 {{Result must be at least a 2D vector.}}
   %1 = xegpu.dpas_mx %a, %b : (vector<8x16xf8E5M2>, vector<16x16xf8E5M2>) -> vector<128xf32>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 198dc64d2814b..5cc74cdf7c9e4 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -649,4 +649,25 @@ gpu.func @dpas_mx(%a : vector<8x32xf8E5M2>, %b: vector<32x16xf8E5M2>, %acc: vect
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @dpas_3d_batch
+gpu.func @dpas_3d_batch(%a : vector<4x8x16xf16>, %b: vector<4x16x16xf16>) {
+  // CHECK: %{{.+}} = xegpu.dpas %{{.+}}, %{{.+}} : vector<4x8x16xf16>, vector<4x16x16xf16> -> vector<4x8x16xf32>
+  %1 = xegpu.dpas %a, %b: vector<4x8x16xf16>, vector<4x16x16xf16> -> vector<4x8x16xf32>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @dpas_3d_batch_with_acc
+gpu.func @dpas_3d_batch_with_acc(%a : vector<4x8x16xf16>, %b: vector<4x16x16xf16>, %acc: vector<4x8x16xf32>) {
+  // CHECK: %{{.+}} = xegpu.dpas %{{.+}}, %{{.+}}, %{{.+}} : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32>
+  %1 = xegpu.dpas %a, %b, %acc : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @dpas_mx_3d_batch
+gpu.func @dpas_mx_3d_batch(%a : vector<4x8x32xf8E5M2>, %b: vector<4x32x16xf8E5M2>, %acc: vector<4x8x16xbf16>, %a_scale: vector<4x8x1xf8E8M0FNU>, %b_scale: vector<4x1x16xf8E8M0FNU>) {
+  // CHECK: %{{.+}} = xegpu.dpas_mx %{{.+}}, %{{.+}}, %{{.+}} scale_a = %{{.+}} scale_b = %{{.+}} : (vector<4x8x32xf8E5M2>, vector<4x32x16xf8E5M2>, vector<4x8x16xbf16>, vector<4x8x1xf8E8M0FNU>, vector<4x1x16xf8E8M0FNU>) -> vector<4x8x16xbf16>
+  %1 = xegpu.dpas_mx %a, %b, %acc scale_a = %a_scale scale_b = %b_scale : (vector<4x8x32xf8E5M2>, vector<4x32x16xf8E5M2>, vector<4x8x16xbf16>, vector<4x8x1xf8E8M0FNU>, vector<4x1x16xf8E8M0FNU>) -> vector<4x8x16xbf16>
+  gpu.return
+}
+
 }

``````````
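For a packed B, the verifier computes the effective K as (K/vnni) * vnni, i.e. bShape[bBatchRank] * bShape[bBatchRank + 2]. The sketch below illustrates the layout that formula assumes: element (k, n) of a row-major K x N matrix moves to position (k / vnni, n, k % vnni), so vnni consecutive K elements of each column become contiguous. It is an illustrative C++ sketch assuming this common VNNI convention; `packVnni` and `effectiveK` are hypothetical helpers, and the actual XeGPU lowering may differ in detail. A factor of 2 for f16 matches the `vector<4x8x16x2xf16>` example in the documentation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative VNNI-style repacking of a row-major K x N matrix B into a
// [K/vnni, N, vnni] layout: element (k, n) moves to (k / vnni, n, k % vnni).
// Assumes K is divisible by vnni.
template <typename T>
std::vector<T> packVnni(const std::vector<T> &b, size_t k, size_t n,
                        size_t vnni) {
  assert(b.size() == k * n && k % vnni == 0);
  std::vector<T> packed(b.size());
  for (size_t row = 0; row < k; ++row)
    for (size_t col = 0; col < n; ++col)
      packed[((row / vnni) * n + col) * vnni + row % vnni] = b[row * n + col];
  return packed;
}

// Effective K of a packed shape [K/vnni, N, vnni], as used by the K check.
size_t effectiveK(size_t kOuter, size_t vnni) { return kOuter * vnni; }
```

For instance, a 16x16 f16 B tile packs to 8x16x2, and effectiveK(8, 2) recovers K = 16, which is how `vector<8x16x2xf16>` is matched against an A with K = 16.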

</details>
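For dpas_mx, the updated DpasMxOp::verify() compares scale_a's second-to-last dimension with A's M, scale_b's last dimension with B's last dimension, and scale_a's last dimension with scale_b's second-to-last dimension (the K blocks). The C++ sketch below restates those checks and additionally compares the scales' leading dimensions with A's batch dimensions, as the updated documentation states. It assumes rank >= 2 scale vectors with the same rank as their operands and an unpacked B; `checkScaleShapes` is a hypothetical name, not part of MLIR.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

// Illustrative dpas_mx scale-shape rules (assumes unpacked B and rank >= 2
// scales with the same rank as their operands):
//   a = [batch..., M, K]          b = [batch..., K, N]
//   scale_a = [batch..., M, Ka]   scale_b = [batch..., Kb, N], Ka == Kb
std::string checkScaleShapes(const Shape &a, const Shape &b,
                             const Shape &scaleA, const Shape &scaleB) {
  assert(a.size() >= 2 && b.size() == a.size());
  const size_t batch = a.size() - 2;
  if (scaleA.size() != a.size() || scaleB.size() != b.size())
    return "scale rank mismatch";
  // Documented rule: scale batch dimensions match the operands' batch dims.
  for (size_t i = 0; i < batch; ++i)
    if (scaleA[i] != a[i] || scaleB[i] != a[i])
      return "scale batch dim mismatch";
  if (scaleA[batch] != a[batch])
    return "scale_a M mismatch";
  if (scaleB.back() != b.back())
    return "scale_b N mismatch";
  if (scaleA.back() != scaleB[scaleB.size() - 2])
    return "scale K mismatch";
  return "";
}
```

With the shapes from the new dpas_mx_3d_batch test (A 4x8x32, B 4x32x16, scale_a 4x8x1, scale_b 4x1x16) the sketch returns an empty string, i.e. the shapes are accepted.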


https://github.com/llvm/llvm-project/pull/199809


More information about the Mlir-commits mailing list