[Mlir-commits] [mlir] [MLIR][XeGPU] Extend op definitions to support 3D+: dpas, dpas_mx (PR #199809)

Jianhui Li llvmlistbot at llvm.org
Tue May 26 19:08:32 PDT 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/199809

From cea5d934c04441d3ebc140e77964023f716ead73 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 26 May 2026 21:26:41 +0000
Subject: [PATCH 1/3] [MLIR][XeGPU] Extend op definitions to support 3D+: dpas,
 dpas_mx

Extend DpasOp and DpasMxOp to support higher-dimensional (3D+) operands
where leading dimensions represent batch dimensions. This enables batched
matrix multiplication operations with shapes like [4, 128, 512] x [4, 512, 128]
-> [4, 128, 128].
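For reference, the batched semantics above (each leading index selects an independent GEMM) can be written as a naive loop nest. This is an illustrative sketch under stated assumptions, not the lowering: it assumes dense, row-major, flattened buffers and folds all batch dims into one `batch` count. `batchedMatmul` is a hypothetical name, and the sketch says nothing about how hardware or later passes schedule the work.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference semantics of a batched GEMM with accumulator:
//   D[b][m][n] = C[b][m][n] + sum_k A[b][m][k] * B[b][k][n]
// All buffers are dense, row-major, and flattened; each batch index
// selects an independent (M x K) * (K x N) multiplication.
std::vector<float> batchedMatmul(const std::vector<float> &a,
                                 const std::vector<float> &b,
                                 const std::vector<float> &acc, int64_t batch,
                                 int64_t m, int64_t k, int64_t n) {
  assert(a.size() == static_cast<std::size_t>(batch * m * k));
  assert(b.size() == static_cast<std::size_t>(batch * k * n));
  assert(acc.size() == static_cast<std::size_t>(batch * m * n));
  std::vector<float> d(acc); // start from the accumulator C
  for (int64_t bi = 0; bi < batch; ++bi)
    for (int64_t mi = 0; mi < m; ++mi)
      for (int64_t ni = 0; ni < n; ++ni) {
        float sum = 0.0f;
        for (int64_t ki = 0; ki < k; ++ki)
          sum += a[(bi * m + mi) * k + ki] * b[(bi * k + ki) * n + ni];
        d[(bi * m + mi) * n + ni] += sum;
      }
  return d;
}
```

For the shapes above the call would be `batchedMatmul(a, b, acc, 4, 128, 512, 128)`, with a zero-filled `acc` when there is no accumulator. Under this row-major layout two leading dims such as 2x4 simply flatten to `batch = 8`.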

Key changes:
- Extend XeGPU_DpasOprType to ranks 1-4 (was 1-3)
- Extend XeGPU_DpasResType to ranks 1-4 (was 1-2)
- Extend dpas_mx scale_a/scale_b to accept rank 1-4 vectors (was 1-2)
- Update verifyDpasDimensions() to validate batch dimensions across A, B, and
  result, treating the trailing dims as the core matmul dims (M, K for A;
  K, N for B, or K/vnni, N, vnni when B is VNNI-packed; M, N for result) and
  checking that the leading batch dims match
- Update DpasMxOp::verify() to use batch-aware indexing when validating
  scale dimensions (second-to-last dim for M, last dim for N, etc.)
- Update test expectations to reflect the new "at least 2D" and batch-rank
  mismatch error messages, and add 3D batched tests to ops.mlir
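The operand rules in the verifyDpasDimensions() item can be restated as a standalone sketch. It is not the patch's code: it works on plain `std::vector<int64_t>` shapes instead of MLIR types, returns an error string instead of emitting a diagnostic, and `checkDpasShapes` is a hypothetical name. Like the patch, it treats B as VNNI-packed exactly when B has one more dim than A.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

// Sketch of the batch-aware dpas shape rules on plain shape vectors.
//   A:      [batch..., M, K]
//   B:      [batch..., K, N], or VNNI-packed [batch..., K/vnni, N, vnni]
//   result: [batch..., M, N]
// Returns an error message on failure and std::nullopt on success.
std::optional<std::string> checkDpasShapes(const Shape &a, const Shape &b,
                                           const Shape &res) {
  const auto aRank = static_cast<int64_t>(a.size());
  const auto bRank = static_cast<int64_t>(b.size());
  const auto resRank = static_cast<int64_t>(res.size());

  // Lane-level form: all-1D operands are accepted without further checks.
  if (aRank == 1 && bRank == 1 && resRank == 1)
    return std::nullopt;
  if (aRank < 2 || bRank < 2 || resRank < 2)
    return "A, B, and result must be at least 2D";

  // B counts as VNNI-packed when it has exactly one more dim than A.
  const bool bPacked = bRank == aRank + 1;
  const int64_t aBatch = aRank - 2;
  const int64_t bBatch = bPacked ? bRank - 3 : bRank - 2;
  const int64_t resBatch = resRank - 2;
  assert(bBatch >= 0 && "rank checks above keep the batch rank non-negative");

  if (aBatch != bBatch || aBatch != resBatch)
    return "batch dimension rank mismatch";
  for (int64_t i = 0; i < aBatch; ++i)
    if (a[i] != b[i] || a[i] != res[i])
      return "batch dimension mismatch at dim " + std::to_string(i);

  // Core matmul dims follow the batch prefix.
  const int64_t aM = a[aBatch], aK = a[aBatch + 1];
  const int64_t bK = bPacked ? b[bBatch] * b[bBatch + 2] : b[bBatch];
  const int64_t bN = b[bBatch + 1];
  const int64_t resM = res[resBatch], resN = res[resBatch + 1];

  if (aK != bK)
    return "K-dimension mismatch";
  if (aM != resM)
    return "M-dimension mismatch";
  if (bN != resN)
    return "N-dimension mismatch";
  return std::nullopt;
}
```

With these rules the `dpas_vc_2` case (A 8x8x2, B 8x16x2, result 8x16) fails on batch rank rather than on A's rank, which is why its expected error changes in invalid.mlir.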

This change enables expressing batched GEMMs and MXFP scaled GEMMs directly
at the operation definition level. Transform passes and lowering support will
be added in a follow-up PR.

Co-Authored-By: Claude Sonnet 4.5 <noreply at anthropic.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |   4 +-
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |   4 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 117 +++++++++++++-----
 mlir/test/Dialect/XeGPU/invalid.mlir          |   8 +-
 mlir/test/Dialect/XeGPU/ops.mlir              |  21 ++++
 5 files changed, 114 insertions(+), 40 deletions(-)

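The batch-aware scale indexing from the DpasMxOp::verify() item (second-to-last dim for M, last dim for N, and scale_a's last dim against scale_b's second-to-last dim for K) can be sketched the same way. `checkScaleShapes` is a hypothetical helper on plain shape vectors. One assumption differs from the diff below: the sketch takes B's N from the second-to-last dim when B carries a trailing VNNI dim, following verifyDpasDimensions(), whereas the patch compares scale_b against `bShape.back()`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

// Sketch of the batch-aware dpas_mx scale checks on plain shape vectors.
// As in the patch, only vector scales of rank >= 2 are checked; pass an
// empty shape for a scalar (or absent) scale to skip its checks.
std::optional<std::string> checkScaleShapes(const Shape &a, const Shape &b,
                                            const Shape &scaleA,
                                            const Shape &scaleB) {
  assert(a.size() >= 2 && b.size() >= 2 && "expects A and B of rank >= 2");
  // A is [batch..., M, K], so M sits right after the batch prefix.
  const std::size_t aBatch = a.size() - 2;
  // Assumption: with a trailing VNNI dim, B's N is the second-to-last dim.
  const bool bPacked = b.size() == a.size() + 1;
  const int64_t bN = bPacked ? b[b.size() - 2] : b.back();

  const bool aVec = scaleA.size() >= 2, bVec = scaleB.size() >= 2;
  if (aVec && scaleA[scaleA.size() - 2] != a[aBatch])
    return "scale_a M dimension mismatch";
  if (bVec && scaleB.back() != bN)
    return "scale_b N dimension mismatch";
  // K: scale_a's last dim must match scale_b's second-to-last dim.
  if (aVec && bVec && scaleA.back() != scaleB[scaleB.size() - 2])
    return "scale K dimension mismatch";
  return std::nullopt;
}
```

The shapes from the new `dpas_mx_3d_batch` test (A 4x8x32, B 4x32x16, scale_a 4x8x1, scale_b 4x1x16) pass all three checks.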
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index fce14999b4011..dcb51352edf13 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1460,9 +1460,9 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
   let arguments = (ins XeGPU_DpasOprType:$a, XeGPU_DpasOprType:$b,
       Optional<XeGPU_DpasResType>:$acc,
       Optional<AnyTypeOf<[F8E8M0FNU,
-                          VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_a,
+                          VectorOfRankAndType<[1, 2, 3, 4], [F8E8M0FNU]>]>>:$scale_a,
       Optional<AnyTypeOf<[F8E8M0FNU,
-                          VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_b,
+                          VectorOfRankAndType<[1, 2, 3, 4], [F8E8M0FNU]>]>>:$scale_b,
       OptionalAttr<DistributeLayoutAttr>:$layout_a,
       OptionalAttr<DistributeLayoutAttr>:$layout_b,
       OptionalAttr<DistributeLayoutAttr>:$layout_cd,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 95a3e22cf803b..0423303c23493 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
     : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3, 4], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2, 3, 4], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: VectorOfRankAndType<[1,2,3,4,5,6,7,8], [XeGPU_ScalarType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 56d340141ee8f..d55a44f2279d9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -720,31 +720,78 @@ static LogicalResult verifyDpasDimensions(Operation *op,
   if (aRank == 1 && bRank == 1 && resRank == 1)
     return success();
 
-  // Validate A and B are 2D
-  if (aRank != 2)
-    return op->emitOpError("A operand must be a 2D vector.");
-  if (bRank < 2 || bRank > 3)
-    return op->emitOpError("B operand must be a 2D or 3D vector.");
-  if (resRank != 2)
-    return op->emitOpError("Result must be a 2D vector.");
+  // A, B, and the result must each be at least 2D; B may also carry one
+  // extra trailing dim for VNNI packing.
+  if (aRank < 2)
+    return op->emitOpError("A operand must be at least a 2D vector.");
+  if (bRank < 2)
+    return op->emitOpError("B operand must be at least a 2D vector.");
+  if (resRank < 2)
+    return op->emitOpError("Result must be at least a 2D vector.");
+
+  // Determine batch dimensions. For A[batch..., M, K], B[batch..., K, N] (or
+  // B[batch..., K/vnni, N, vnni]), result[batch..., M, N].
+  // B may have one extra trailing dim for VNNI packing (3D innermost).
+  // Determine how many trailing dims are the "core" matmul dims.
+  // A core dims: last 2 (M, K)
+  // B core dims: last 2 (K, N) or last 3 (K/vnni, N, vnni) for packed
+  // Result core dims: last 2 (M, N)
+  int64_t aBatchRank = aRank - 2;
+  int64_t resBatchRank = resRank - 2;
+
+  // B can have an extra trailing dim for VNNI packing. Determine B's batch
+  // rank: if bRank == aRank + 1, the extra dim is the VNNI packing dim.
+  bool bPacked = (bRank == aRank + 1);
+  int64_t bBatchRank = bPacked ? bRank - 3 : bRank - 2;
+
+  // Batch ranks must match across A, B, and result.
+  if (aBatchRank != bBatchRank || aBatchRank != resBatchRank)
+    return op->emitOpError("Batch dimension rank mismatch among A, B, and "
+                           "result.");
+
+  // Verify batch dimensions match.
+  for (int64_t i = 0; i < aBatchRank; ++i) {
+    if (aShape[i] != resShape[i])
+      return op->emitOpError("Batch dimension mismatch at dim ")
+             << i << ": A has " << aShape[i] << " but result has "
+             << resShape[i] << ".";
+    if (aShape[i] != bShape[i])
+      return op->emitOpError("Batch dimension mismatch at dim ")
+             << i << ": A has " << aShape[i] << " but B has " << bShape[i]
+             << ".";
+  }
 
-  // Calculate effective K dimension for B (handle 3D packed case)
-  int64_t bK = bRank == 3 ? bShape[0] * bShape[2] : bShape[0];
+  // Now verify the core matmul dimensions (last 2 of A and result, last 2 or
+  // 3 of B).
+  int64_t aM = aShape[aBatchRank];
+  int64_t aK = aShape[aBatchRank + 1];
+  int64_t resM = resShape[resBatchRank];
+  int64_t resN = resShape[resBatchRank + 1];
+
+  // Calculate effective K dimension for B (handle packed case)
+  int64_t bK, bN;
+  if (bPacked) {
+    bK = bShape[bBatchRank] * bShape[bBatchRank + 2];
+    bN = bShape[bBatchRank + 1];
+  } else {
+    bK = bShape[bBatchRank];
+    bN = bShape[bBatchRank + 1];
+  }
 
   // Verify K dimension match between A and B
-  if (bK != aShape[1])
+  if (bK != aK)
     return op->emitOpError("K-dimension mismatch: A has K=")
-           << aShape[1] << " but B has K=" << bK << ".";
+           << aK << " but B has K=" << bK << ".";
 
   // Verify M dimension match between A and result
-  if (aShape[0] != resShape[0])
+  if (aM != resM)
     return op->emitOpError("M-dimension mismatch: A has M=")
-           << aShape[0] << " but result has M=" << resShape[0] << ".";
+           << aM << " but result has M=" << resM << ".";
 
   // Verify N dimension match between B and result
-  if (bShape[1] != resShape[1])
+  if (bN != resN)
     return op->emitOpError("N-dimension mismatch: B has N=")
-           << bShape[1] << " but result has N=" << resShape[1] << ".";
+           << bN << " but result has N=" << resN << ".";
 
   return success();
 }
@@ -907,6 +954,9 @@ LogicalResult DpasMxOp::verify() {
   if (failed(verifyDpasDimensions(*this, aShape, bShape, resShape)))
     return failure();
 
+  // Determine batch rank from A operand.
+  int64_t aBatchRank = aShape.size() - 2;
+
   // Validate scale_a if present
   if (getScaleA()) {
     auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
@@ -914,19 +964,20 @@ LogicalResult DpasMxOp::verify() {
     if (scaleAVecType && scaleAVecType.getRank() > 1) {
       auto scaleAShape = scaleAVecType.getShape();
 
-      if (scaleAVecType.getRank() != 2)
-        return emitOpError("Scale A must be a 2D vector when not a scalar.");
+      if (scaleAVecType.getRank() < 2)
+        return emitOpError("Scale A must be at least a 2D vector when not a "
+                           "scalar.");
 
       // Verify layout distributability for scale_a
       if (failed(verifyLayoutDistributable(*this, getLayoutAScale(),
                                            scaleAShape, "ScaleA")))
         return failure();
 
-      // Validate M dimension: scale_a[0] must match a[0]
-      if (scaleAShape[0] != aShape[0])
+      // Validate M dimension: scale_a's M must match A's M (last-1 dim)
+      if (scaleAShape[scaleAShape.size() - 2] != aShape[aBatchRank])
         return emitOpError("Scale A M dimension [")
-               << scaleAShape[0] << "] must match A M dimension [" << aShape[0]
-               << "].";
+               << scaleAShape[scaleAShape.size() - 2]
+               << "] must match A M dimension [" << aShape[aBatchRank] << "].";
     }
   }
 
@@ -937,19 +988,21 @@ LogicalResult DpasMxOp::verify() {
     if (scaleBVecType && scaleBVecType.getRank() > 1) {
       auto scaleBShape = scaleBVecType.getShape();
 
-      if (scaleBVecType.getRank() != 2)
-        return emitOpError("Scale B must be a 2D vector when not a scalar.");
+      if (scaleBVecType.getRank() < 2)
+        return emitOpError("Scale B must be at least a 2D vector when not a "
+                           "scalar.");
 
       // Verify layout distributability for scale_b
       if (failed(verifyLayoutDistributable(*this, getLayoutBScale(),
                                            scaleBShape, "ScaleB")))
         return failure();
 
-      // Validate N dimension: scale_b[1] must match b[1]
-      if (scaleBShape[1] != bShape[1])
+      // Validate N dimension: scale_b's N (last dim) must match B's N (last
+      // dim)
+      if (scaleBShape.back() != bShape.back())
         return emitOpError("Scale B N dimension [")
-               << scaleBShape[1] << "] must match B N dimension [" << bShape[1]
-               << "].";
+               << scaleBShape.back() << "] must match B N dimension ["
+               << bShape.back() << "].";
     }
   }
 
@@ -964,12 +1017,12 @@ LogicalResult DpasMxOp::verify() {
       auto scaleAShape = scaleAVecType.getShape();
       auto scaleBShape = scaleBVecType.getShape();
 
-      // Validate scale K dimension compatibility: scale_a[1] must match
-      // scale_b[0]
-      if (scaleAShape[1] != scaleBShape[0])
+      // Validate scale K dimension compatibility: scale_a's last dim must
+      // match scale_b's second-to-last dim
+      if (scaleAShape.back() != scaleBShape[scaleBShape.size() - 2])
         return emitOpError("Scale K dimension mismatch: scale_a has K=")
-               << scaleAShape[1] << " but scale_b has K=" << scaleBShape[0]
-               << ".";
+               << scaleAShape.back() << " but scale_b has K="
+               << scaleBShape[scaleBShape.size() - 2] << ".";
     }
   }
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 341427e7e1231..438f675ecc3f4 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -394,7 +394,7 @@ func.func @dpas_vc_1(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) {
 
 // -----
 func.func @dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
-  // expected-error at +1 {{op A operand must be a 2D vector}}
+  // expected-error at +1 {{'xegpu.dpas' op Batch dimension rank mismatch among A, B, and result}}
   %1 = xegpu.dpas %a, %b : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
   return
 }
@@ -702,21 +702,21 @@ func.func @dpas_mx_acc_result_type_mismatch(%a : vector<8x16xf8E5M2>, %b: vector
 
 // -----
 func.func @dpas_mx_a_not_2d(%a : vector<128xf8E5M2>, %b: vector<16x16xf8E5M2>) {
-  // expected-error at +1 {{A operand must be a 2D vector.}}
+  // expected-error at +1 {{A operand must be at least a 2D vector.}}
   %1 = xegpu.dpas_mx %a, %b : (vector<128xf8E5M2>, vector<16x16xf8E5M2>) -> vector<8x16xf32>
   return
 }
 
 // -----
 func.func @dpas_mx_b_not_2d(%a : vector<8x16xf8E5M2>, %b: vector<256xf8E5M2>) {
-  // expected-error at +1 {{B operand must be a 2D or 3D vector.}}
+  // expected-error at +1 {{B operand must be at least a 2D vector.}}
   %1 = xegpu.dpas_mx %a, %b : (vector<8x16xf8E5M2>, vector<256xf8E5M2>) -> vector<8x16xf32>
   return
 }
 
 // -----
 func.func @dpas_mx_result_not_2d(%a : vector<8x16xf8E5M2>, %b: vector<16x16xf8E5M2>) {
-  // expected-error at +1 {{Result must be a 2D vector.}}
+  // expected-error at +1 {{Result must be at least a 2D vector.}}
   %1 = xegpu.dpas_mx %a, %b : (vector<8x16xf8E5M2>, vector<16x16xf8E5M2>) -> vector<128xf32>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 198dc64d2814b..5cc74cdf7c9e4 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -649,4 +649,25 @@ gpu.func @dpas_mx(%a : vector<8x32xf8E5M2>, %b: vector<32x16xf8E5M2>, %acc: vect
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @dpas_3d_batch
+gpu.func @dpas_3d_batch(%a : vector<4x8x16xf16>, %b: vector<4x16x16xf16>) {
+  // CHECK: %{{.+}} = xegpu.dpas %{{.+}}, %{{.+}} : vector<4x8x16xf16>, vector<4x16x16xf16> -> vector<4x8x16xf32>
+  %1 = xegpu.dpas %a, %b: vector<4x8x16xf16>, vector<4x16x16xf16> -> vector<4x8x16xf32>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @dpas_3d_batch_with_acc
+gpu.func @dpas_3d_batch_with_acc(%a : vector<4x8x16xf16>, %b: vector<4x16x16xf16>, %acc: vector<4x8x16xf32>) {
+  // CHECK: %{{.+}} = xegpu.dpas %{{.+}}, %{{.+}}, %{{.+}} : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32>
+  %1 = xegpu.dpas %a, %b, %acc : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @dpas_mx_3d_batch
+gpu.func @dpas_mx_3d_batch(%a : vector<4x8x32xf8E5M2>, %b: vector<4x32x16xf8E5M2>, %acc: vector<4x8x16xbf16>, %a_scale: vector<4x8x1xf8E8M0FNU>, %b_scale: vector<4x1x16xf8E8M0FNU>) {
+  // CHECK: %{{.+}} = xegpu.dpas_mx %{{.+}}, %{{.+}}, %{{.+}} scale_a = %{{.+}} scale_b = %{{.+}} : (vector<4x8x32xf8E5M2>, vector<4x32x16xf8E5M2>, vector<4x8x16xbf16>, vector<4x8x1xf8E8M0FNU>, vector<4x1x16xf8E8M0FNU>) -> vector<4x8x16xbf16>
+  %1 = xegpu.dpas_mx %a, %b, %acc scale_a = %a_scale scale_b = %b_scale : (vector<4x8x32xf8E5M2>, vector<4x32x16xf8E5M2>, vector<4x8x16xbf16>, vector<4x8x1xf8E8M0FNU>, vector<4x1x16xf8E8M0FNU>) -> vector<4x8x16xbf16>
+  gpu.return
+}
+
 }

From 732a78bcda7bc7294ae5b204cd038b5dbcd9ce81 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 27 May 2026 02:00:14 +0000
Subject: [PATCH 2/3] [MLIR][XeGPU] Add 3D/4D batch dimension documentation and
 unit tests for dpas/dpas_mx

Enhance the documentation for xegpu.dpas and xegpu.dpas_mx operations to
clearly explain support for 3D and 4D vectors with batch dimensions. The 3D
batched unit tests that exercise these operations were added in patch 1/3.

Documentation updates:
- Clarify that matrix dimensions (m, n, k) are microarchitecture-specific
  (Xe2/Xe3: n=16, k=8*32/bit_width, m can be 1 to 8); this wording is
  applied to xegpu.dpas in patch 3/3
- Explain that operands can be 2D, 3D, or 4D vectors
- Clarify that leading dimensions are batch dimensions that must match
  across all operands
- Describe batch dimensions as independent matrix multiplications executed
  in parallel
- Add example showing 3D batched dpas operation
- Update argument descriptions to mention batch dimension support

Unit tests (ops.mlir, added in patch 1/3):
- dpas_3d_batch: Basic 3D dpas without accumulator
- dpas_3d_batch_with_acc: 3D dpas with accumulator
- dpas_mx_3d_batch: 3D dpas_mx with vector scales
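The Xe2/Xe3 tile dimensions listed above (n=16, k=8*32/bit_width, m from 1 to 8) reduce to a small calculation. A minimal sketch with hypothetical names `DpasTile` and `xe2DpasTile`; rejecting element widths that do not divide 32 bits is an extra assumption of the sketch, not something the documentation states.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical tile descriptor: A is m x k, B is k x n, C/D is m x n.
struct DpasTile {
  int64_t m;
  int64_t n;
  int64_t k;
};

// Xe2/Xe3 DPAS tile per the op documentation: n = 16,
// k = 8 * 32 / element bit width, and 1 <= m <= 8.
std::optional<DpasTile> xe2DpasTile(int64_t m, int64_t elemBitWidth) {
  // Assumption (not from the documentation): only widths dividing 32 bits.
  if (elemBitWidth <= 0 || 32 % elemBitWidth != 0)
    return std::nullopt;
  if (m < 1 || m > 8)
    return std::nullopt;
  return DpasTile{m, 16, 8 * 32 / elemBitWidth};
}
```

For 16-bit elements and m=8 this yields the documented `A: vector<8x16xf16>`, `B: vector<16x16xf16>`, and `C/D: vector<8x16xf32>` tiles; 8-bit elements double k to 32.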

Co-Authored-By: Claude Sonnet 4.5 <noreply at anthropic.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 59 ++++++++++++++++---
 1 file changed, 50 insertions(+), 9 deletions(-)

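The rhs description in the diff below allows an optional trailing VNNI dimension, and verifyDpasDimensions() from patch 1/3 recovers K as (K/vnni) * vnni from a [K/vnni, N, vnni] B. A minimal sketch of how a row-major K x N tile maps onto that shape, assuming the conventional VNNI interleave (`vnni` consecutive K elements of one column stored together); the patch itself only constrains shapes, and `packVnni` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative VNNI packing of a row-major K x N matrix into a
// [K/vnni, N, vnni] buffer: groups of `vnni` consecutive K-rows are
// interleaved so each (K-group, column) pair is contiguous.
//   packed[kb][col][v] = b[kb * vnni + v][col]
template <typename T>
std::vector<T> packVnni(const std::vector<T> &b, int64_t k, int64_t n,
                        int64_t vnni) {
  assert(vnni > 0 && k % vnni == 0 && "K must be a multiple of vnni");
  assert(static_cast<int64_t>(b.size()) == k * n);
  std::vector<T> packed(b.size());
  for (int64_t kb = 0; kb < k / vnni; ++kb)
    for (int64_t col = 0; col < n; ++col)
      for (int64_t v = 0; v < vnni; ++v)
        packed[(kb * n + col) * vnni + v] = b[(kb * vnni + v) * n + col];
  return packed;
}
```

With `vnni = 2`, a 16x16 f16 tile becomes 8x16x2, the same shape as the packed `vector<8x16x2xf16>` operands used in the invalid.mlir tests.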
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index dcb51352edf13..c000f7637127c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -919,6 +919,21 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
     data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
     and `C/D: vector<8x16xf32>`.
 
+    The operands can be 2D, 3D, or 4D vectors. When the vectors have more than 2
+    dimensions, the leading dimensions are treated as **batch dimensions** that
+    must match across all operands (lhs, rhs, acc, and result). The batch
+    dimensions represent independent matrix multiplications that are executed in
+    parallel. For example:
+    - 2D: `A: vector<8x16xf16>`, `B: vector<16x16xf16>` -> `C: vector<8x16xf32>`
+    - 3D: `A: vector<4x8x16xf16>`, `B: vector<4x16x16xf16>` -> `C: vector<4x8x16xf32>`
+      (4 independent 8x16 matrix multiplications)
+    - 4D: `A: vector<2x4x8x16xf16>`, `B: vector<2x4x16x16xf16>` -> `C: vector<2x4x8x16xf32>`
+      (2x4=8 independent 8x16 matrix multiplications)
+
+    The last 2 dimensions are always the core matrix multiplication dimensions
+    (M, K for lhs; K, N for rhs; M, N for result). Note that rhs (B) can have
+    one extra trailing dimension for VNNI packing (e.g., `vector<4x8x16x2xf16>`).
+
     In lane level code, each lane from a subgroup holds a data fragment for A, B, C and the result,
     which are represented as 1D vectors. Please refer to [OpenCL Intel extentions]
     (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html)
@@ -930,19 +945,21 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
     Arguments:
 
     - `lhs`: A vector value representing the left-hand-side matrix tile (A) participating in the
-      matrix multiply.
+      matrix multiply. Can be 1D, 2D, 3D, or 4D where leading dimensions are batch dimensions.
 
-    - `rhs`: A vector value representing the right-hand-side matrix tile (B).
+    - `rhs`: A vector value representing the right-hand-side matrix tile (B). Can be 1D, 2D, 3D, or 4D
+      where leading dimensions are batch dimensions, plus an optional trailing dimension for VNNI packing.
 
     - `acc`: [optional] A vector value representing the accumulator matrix tile (C). When present, the
       result is computed as `lhs * rhs + acc`; otherwise, the accumulator is implicitly assumed to be zero.
+      Must have the same batch dimensions as lhs and rhs.
 
     - `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify this
       operation as anchor for operands A, B, and the accumulator/result, enabling users to assign layouts
       that govern distribution at the subgroup and/or lane level. Only valid at workgroup and subgroup
       level.
 
-    Example 1 (Workgroup level):
+    Example 1 (Workgroup level, 2D):
 
     ```mlir
       %d = xegpu.dpas %a, %b, %c <{
@@ -952,12 +969,20 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
           : vector<64x128xf16>, vector<128x128xf16>, vector<64x128xf32> -> vector<64x128xf32>
     ```
 
-    Example 2 (Lane level):
+    Example 2 (Lane level, 1D):
 
     ```mlir
       %d = xegpu.dpas %a, %b, %c
             :  vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
     ```
+
+    Example 3 (Workgroup level, 3D with batch):
+
+    ```mlir
+      // 4 independent 8x16 matrix multiplications
+      %d = xegpu.dpas %a, %b, %c
+          : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32>
+    ```
   }];
 
   let arguments = (ins
@@ -1428,6 +1453,21 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
     size, B is of size `kxn`, and accumulate on matrix C of size `mxn` to the same size
     matrix.
 
+    The operands can be 2D, 3D, or 4D vectors. When the vectors have more than 2
+    dimensions, the leading dimensions are treated as **batch dimensions** that
+    must match across all operands (a, b, acc, result, scale_a, scale_b). The batch
+    dimensions represent independent matrix multiplications that are executed in
+    parallel. For example:
+    - 2D: `A: vector<8x32xf8E5M2>`, `B: vector<32x16xf8E5M2>` -> `C: vector<8x16xbf16>`
+    - 3D: `A: vector<4x8x32xf8E5M2>`, `B: vector<4x32x16xf8E5M2>` -> `C: vector<4x8x16xbf16>`
+      (4 independent 8x16 scaled matrix multiplications)
+    - 4D: `A: vector<2x4x8x32xf8E5M2>`, `B: vector<2x4x32x16xf8E5M2>` -> `C: vector<2x4x8x16xbf16>`
+      (2x4=8 independent 8x16 scaled matrix multiplications)
+
+    The last 2 dimensions are always the core matrix multiplication dimensions
+    (M, K for a; K, N for b; M, N for result). The scale vectors also follow the
+    same batch dimension structure.
+
     In lane level code, each lane from a subgroup holds a data fragment for A, B, Acc and the result,
     which are represented as 1D vectors.
 
@@ -1437,18 +1477,19 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
     Arguments:
 
     - `a`: A vector value representing the left-hand-side matrix tile (A) participating in the
-      matrix multiply.
+      matrix multiply. Can be 1D, 2D, 3D, or 4D where leading dimensions are batch dimensions.
 
-    - `b`: A vector value representing the right-hand-side matrix tile (B).
+    - `b`: A vector value representing the right-hand-side matrix tile (B). Can be 1D, 2D, 3D, or 4D
+      where leading dimensions are batch dimensions.
 
     - `acc`: A vector value representing the accumulator matrix tile (C). The
-      result is computed as `a * b + acc`.
+      result is computed as `a * b + acc`. Must have the same batch dimensions as a and b.
 
     - `scale_a`: A floating point vector/scalar value used to scale `a` for
-      matrix multiplication.
+      matrix multiplication. When a vector, must have matching batch dimensions.
 
     - `scale_b`: A floating point vector/scalar value used to scale `b` for
-      matrix multiplication.
+      matrix multiplication. When a vector, must have matching batch dimensions.
 
     - `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify this
       operation as anchor for operands A, B, and the accumulator/result, enabling users to assign layouts

From 80814d4088eb656f3dbf04f1ef68e7b9d193366d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 27 May 2026 02:08:19 +0000
Subject: [PATCH 3/3] add documentation and fix format issue

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 7 ++++---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp         | 5 +++--
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c000f7637127c..d1e024787bc73 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -915,9 +915,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
 
   let description = [{DPAS performs matrix multiplication on matrix A of `mxk`
     size, B of `kxn` size, and accumulate on matrix C of `mxn` to the same size
-    matrix , `m=8`, `n=16` and `k=8 * 32/bit_width_of_elem_type`. So for fp16
-    data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
-    and `C/D: vector<8x16xf32>`.
+    matrix. The dimensions are microarchitecture-specific: for Xe2 and Xe3,
+    `n=16` and `k=8 * 32/bit_width_of_elem_type`, while `m` can be 1 to 8.
+    For example, for fp16 data type on Xe2/Xe3, typical matrices are
+    `A: vector<8x16xf16>`, `B: vector<16x16xf16>`, and `C/D: vector<8x16xf32>`.
 
     The operands can be 2D, 3D, or 4D vectors. When the vectors have more than 2
     dimensions, the leading dimensions are treated as **batch dimensions** that
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index d55a44f2279d9..bebdd22e1c087 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1021,8 +1021,9 @@ LogicalResult DpasMxOp::verify() {
       // match scale_b's second-to-last dim
       if (scaleAShape.back() != scaleBShape[scaleBShape.size() - 2])
         return emitOpError("Scale K dimension mismatch: scale_a has K=")
-               << scaleAShape.back() << " but scale_b has K="
-               << scaleBShape[scaleBShape.size() - 2] << ".";
+               << scaleAShape.back()
+               << " but scale_b has K=" << scaleBShape[scaleBShape.size() - 2]
+               << ".";
     }
   }
 
