[Mlir-commits] [mlir] [MLIR][XeGPU] Extend op definitions to support 3D+: dpas, dpas_mx (PR #199809)
llvmlistbot at llvm.org
Tue May 26 19:03:50 PDT 2026
llvmorg-github-actions[bot] wrote:
@llvm/pr-subscribers-mlir
Author: Jianhui Li (Jianhui-Li)
<details>
<summary>Changes</summary>
**Summary**
Extend the `xegpu.dpas` and `xegpu.dpas_mx` operations to accept 3D and 4D operands whose leading dimensions are batch dimensions. This enables batched matrix multiplication workloads such as [4, 128, 512] x [4, 512, 128] -> [4, 128, 128].
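To make the batch semantics concrete, here is a minimal reference sketch in plain C++ (it is not XeGPU or MLIR code). It assumes row-major flat buffers and treats each batch index as an independent `M x K` by `K x N` multiplication. The name `batchedMatmul` and the flat layout are illustrative assumptions, not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference semantics only: C[b] = A[b] * B[b] for every batch index b.
// A is [batch, M, K], B is [batch, K, N], C is [batch, M, N], all row-major.
// (Hypothetical helper for illustration; not an XeGPU API.)
std::vector<float> batchedMatmul(const std::vector<float> &a,
                                 const std::vector<float> &b,
                                 std::size_t batch, std::size_t m,
                                 std::size_t k, std::size_t n) {
  assert(a.size() == batch * m * k && "A must hold batch*M*K elements");
  assert(b.size() == batch * k * n && "B must hold batch*K*N elements");
  std::vector<float> c(batch * m * n, 0.0f);
  for (std::size_t bi = 0; bi < batch; ++bi)      // independent per batch
    for (std::size_t i = 0; i < m; ++i)
      for (std::size_t j = 0; j < n; ++j) {
        float sum = 0.0f;
        for (std::size_t p = 0; p < k; ++p)
          sum += a[(bi * m + i) * k + p] * b[(bi * k + p) * n + j];
        c[(bi * m + i) * n + j] = sum;
      }
  return c;
}
```

With `batch = 4`, `m = 128`, `k = 512`, `n = 128` this corresponds to the [4, 128, 512] x [4, 512, 128] -> [4, 128, 128] shape quoted above; no data flows between batch indices.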
**Changes**
- Type definitions: Extended XeGPU_DpasOprType and XeGPU_DpasResType to accept ranks 1-4 (previously 1-3 and 1-2, respectively)
- Op definitions: Extended dpas_mx scale operands to accept rank 1-4 vectors (previously rank 1-2)
- Verifiers: Updated verifyDpasDimensions() to validate batch dimensions across A, B, and result operands; updated DpasMxOp::verify() for batch-aware scale dimension checks
- Documentation: Added comprehensive documentation explaining batch dimensions, microarchitecture-specific matrix sizes, and 3D/4D usage examples
- Tests: Added unit tests for 3D batched dpas and dpas_mx operations; updated the expected diagnostics in invalid.mlir to match the new verifier messages
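
The dimension rules described in the Verifiers item can be approximated outside MLIR roughly as follows. This is a standalone sketch modeled on the checks visible in the diff for `verifyDpasDimensions()`; `Shape`, `checkDpasShapes`, and the `std::optional<std::string>` diagnostic are hypothetical stand-ins, and the messages are shortened relative to the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>; // hypothetical stand-in for a vector shape

// Returns std::nullopt if the shapes are consistent, otherwise a diagnostic.
std::optional<std::string> checkDpasShapes(const Shape &a, const Shape &b,
                                           const Shape &res) {
  assert(!a.empty() && !b.empty() && !res.empty() &&
         "rank-0 shapes are not modeled");
  const auto aRank = static_cast<int64_t>(a.size());
  const auto bRank = static_cast<int64_t>(b.size());
  const auto resRank = static_cast<int64_t>(res.size());

  // Lane-level form: every operand is a 1D fragment; nothing more to check.
  if (aRank == 1 && bRank == 1 && resRank == 1)
    return std::nullopt;
  if (aRank < 2)
    return std::string("A operand must be at least a 2D vector.");
  if (bRank < 2)
    return std::string("B operand must be at least a 2D vector.");
  if (resRank < 2)
    return std::string("Result must be at least a 2D vector.");

  // B counts as VNNI-packed when it has exactly one more dim than A.
  const bool bPacked = (bRank == aRank + 1);
  const int64_t aBatch = aRank - 2;
  const int64_t bBatch = bPacked ? bRank - 3 : bRank - 2;
  const int64_t resBatch = resRank - 2;
  if (aBatch != bBatch || aBatch != resBatch)
    return std::string("Batch dimension rank mismatch among A, B, and result.");

  // Leading (batch) extents must agree across A, B, and the result.
  for (int64_t i = 0; i < aBatch; ++i)
    if (a[i] != res[i] || a[i] != b[i])
      return "Batch dimension mismatch at dim " + std::to_string(i) + ".";

  // Core matmul dims: A is [M, K], B is [K, N] or [K/vnni, N, vnni].
  const int64_t aM = a[aBatch], aK = a[aBatch + 1];
  const int64_t bK = bPacked ? b[bBatch] * b[bBatch + 2] : b[bBatch];
  const int64_t bN = b[bBatch + 1];
  const int64_t resM = res[resBatch], resN = res[resBatch + 1];
  if (bK != aK)
    return std::string("K-dimension mismatch.");
  if (aM != resM)
    return std::string("M-dimension mismatch.");
  if (bN != resN)
    return std::string("N-dimension mismatch.");
  return std::nullopt;
}
```

Under this sketch, A `8x8x2` with B `8x16x2` and result `8x16` is rejected with the batch-rank message, which matches the updated expectation for `dpas_vc_2` in invalid.mlir. Note that packing is inferred purely from `bRank == aRank + 1`, so an unpacked B with a mismatched rank also surfaces as a batch-rank error.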
Assisted-by-Claude
---
Full diff: https://github.com/llvm/llvm-project/pull/199809.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+52-11)
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+2-2)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+85-32)
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+4-4)
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+21)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index fce14999b4011..c000f7637127c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -919,6 +919,21 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
and `C/D: vector<8x16xf32>`.
+ The operands can be 2D, 3D, or 4D vectors. When the vectors have more than 2
+ dimensions, the leading dimensions are treated as **batch dimensions** that
+ must match across all operands (lhs, rhs, acc, and result). The batch
+ dimensions represent independent matrix multiplications that are executed in
+ parallel. For example:
+ - 2D: `A: vector<8x16xf16>`, `B: vector<16x16xf16>` -> `C: vector<8x16xf32>`
+ - 3D: `A: vector<4x8x16xf16>`, `B: vector<4x16x16xf16>` -> `C: vector<4x8x16xf32>`
+ (4 independent 8x16 matrix multiplications)
+ - 4D: `A: vector<2x4x8x16xf16>`, `B: vector<2x4x16x16xf16>` -> `C: vector<2x4x8x16xf32>`
+ (2x4=8 independent 8x16 matrix multiplications)
+
+ The last 2 dimensions are always the core matrix multiplication dimensions
+ (M, K for lhs; K, N for rhs; M, N for result). Note that rhs (B) can have
+ one extra trailing dimension for VNNI packing (e.g., `vector<4x8x16x2xf16>`).
+
In lane level code, each lane from a subgroup holds a data fragment for A, B, C and the result,
which are represented as 1D vectors. Please refer to [OpenCL Intel extentions]
(https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html)
@@ -930,19 +945,21 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
Arguments:
- `lhs`: A vector value representing the left-hand-side matrix tile (A) participating in the
- matrix multiply.
+ matrix multiply. Can be 1D, 2D, 3D, or 4D where leading dimensions are batch dimensions.
- - `rhs`: A vector value representing the right-hand-side matrix tile (B).
+ - `rhs`: A vector value representing the right-hand-side matrix tile (B). Can be 1D, 2D, 3D, or 4D
+ where leading dimensions are batch dimensions, plus an optional trailing dimension for VNNI packing.
- `acc`: [optional] A vector value representing the accumulator matrix tile (C). When present, the
result is computed as `lhs * rhs + acc`; otherwise, the accumulator is implicitly assumed to be zero.
+ Must have the same batch dimensions as lhs and rhs.
- `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify this
operation as anchor for operands A, B, and the accumulator/result, enabling users to assign layouts
that govern distribution at the subgroup and/or lane level. Only valid at workgroup and subgroup
level.
- Example 1 (Workgroup level):
+ Example 1 (Workgroup level, 2D):
```mlir
%d = xegpu.dpas %a, %b, %c <{
@@ -952,12 +969,20 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>,
: vector<64x128xf16>, vector<128x128xf16>, vector<64x128xf32> -> vector<64x128xf32>
```
- Example 2 (Lane level):
+ Example 2 (Lane level, 1D):
```mlir
%d = xegpu.dpas %a, %b, %c
: vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
```
+
+ Example 3 (Workgroup level, 3D with batch):
+
+ ```mlir
+ // 4 independent 8x16 matrix multiplications
+ %d = xegpu.dpas %a, %b, %c
+ : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32>
+ ```
}];
let arguments = (ins
@@ -1428,6 +1453,21 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
size, B is of size `kxn`, and accumulate on matrix C of size `mxn` to the same size
matrix.
+ The operands can be 2D, 3D, or 4D vectors. When the vectors have more than 2
+ dimensions, the leading dimensions are treated as **batch dimensions** that
+ must match across all operands (a, b, acc, result, scale_a, scale_b). The batch
+ dimensions represent independent matrix multiplications that are executed in
+ parallel. For example:
+ - 2D: `A: vector<8x32xf8E5M2>`, `B: vector<32x16xf8E5M2>` -> `C: vector<8x16xbf16>`
+ - 3D: `A: vector<4x8x32xf8E5M2>`, `B: vector<4x32x16xf8E5M2>` -> `C: vector<4x8x16xbf16>`
+ (4 independent 8x16 scaled matrix multiplications)
+ - 4D: `A: vector<2x4x8x32xf8E5M2>`, `B: vector<2x4x32x16xf8E5M2>` -> `C: vector<2x4x8x16xbf16>`
+ (2x4=8 independent 8x16 scaled matrix multiplications)
+
+ The last 2 dimensions are always the core matrix multiplication dimensions
+ (M, K for a; K, N for b; M, N for result). The scale vectors also follow the
+ same batch dimension structure.
+
In lane level code, each lane from a subgroup holds a data fragment for A, B, Acc and the result,
which are represented as 1D vectors.
@@ -1437,18 +1477,19 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
Arguments:
- `a`: A vector value representing the left-hand-side matrix tile (A) participating in the
- matrix multiply.
+ matrix multiply. Can be 1D, 2D, 3D, or 4D where leading dimensions are batch dimensions.
- - `b`: A vector value representing the right-hand-side matrix tile (B).
+ - `b`: A vector value representing the right-hand-side matrix tile (B). Can be 1D, 2D, 3D, or 4D
+ where leading dimensions are batch dimensions.
- `acc`: A vector value representing the accumulator matrix tile (C). The
- result is computed as `a * b + acc`.
+ result is computed as `a * b + acc`. Must have the same batch dimensions as a and b.
- `scale_a`: A floating point vector/scalar value used to scale `a` for
- matrix multiplication.
+ matrix multiplication. When a vector, must have matching batch dimensions.
- `scale_b`: A floating point vector/scalar value used to scale `b` for
- matrix multiplication.
+ matrix multiplication. When a vector, must have matching batch dimensions.
- `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify this
operation as anchor for operands A, B, and the accumulator/result, enabling users to assign layouts
@@ -1460,9 +1501,9 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
let arguments = (ins XeGPU_DpasOprType:$a, XeGPU_DpasOprType:$b,
Optional<XeGPU_DpasResType>:$acc,
Optional<AnyTypeOf<[F8E8M0FNU,
- VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_a,
+ VectorOfRankAndType<[1, 2, 3, 4], [F8E8M0FNU]>]>>:$scale_a,
Optional<AnyTypeOf<[F8E8M0FNU,
- VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_b,
+ VectorOfRankAndType<[1, 2, 3, 4], [F8E8M0FNU]>]>>:$scale_b,
OptionalAttr<DistributeLayoutAttr>:$layout_a,
OptionalAttr<DistributeLayoutAttr>:$layout_b,
OptionalAttr<DistributeLayoutAttr>:$layout_cd,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 95a3e22cf803b..0423303c23493 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -21,8 +21,8 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType
: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
-def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3, 4], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2, 3, 4], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: VectorOfRankAndType<[1,2,3,4,5,6,7,8], [XeGPU_ScalarType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 56d340141ee8f..d55a44f2279d9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -720,31 +720,78 @@ static LogicalResult verifyDpasDimensions(Operation *op,
if (aRank == 1 && bRank == 1 && resRank == 1)
return success();
- // Validate A and B are 2D
- if (aRank != 2)
- return op->emitOpError("A operand must be a 2D vector.");
- if (bRank < 2 || bRank > 3)
- return op->emitOpError("B operand must be a 2D or 3D vector.");
- if (resRank != 2)
- return op->emitOpError("Result must be a 2D vector.");
+ // A must be at least 2D, B must be 2D or 3D (innermost dims), result at
+ // least 2D.
+ if (aRank < 2)
+ return op->emitOpError("A operand must be at least a 2D vector.");
+ if (bRank < 2)
+ return op->emitOpError("B operand must be at least a 2D vector.");
+ if (resRank < 2)
+ return op->emitOpError("Result must be at least a 2D vector.");
+
+ // Determine batch dimensions. For A[batch..., M, K], B[batch..., K, N] (or
+ // B[batch..., K/vnni, N, vnni]), result[batch..., M, N].
+ // B may have one extra trailing dim for VNNI packing (3D innermost).
+ // Determine how many trailing dims are the "core" matmul dims.
+ // A core dims: last 2 (M, K)
+ // B core dims: last 2 (K, N) or last 3 (K/vnni, N, vnni) for packed
+ // Result core dims: last 2 (M, N)
+ int64_t aBatchRank = aRank - 2;
+ int64_t resBatchRank = resRank - 2;
+
+ // B can have an extra trailing dim for VNNI packing. Determine B's batch
+ // rank: if bRank > aRank, the extra dim is the VNNI packing dim.
+ bool bPacked = (bRank == aRank + 1);
+ int64_t bBatchRank = bPacked ? bRank - 3 : bRank - 2;
+
+ // Batch ranks must match across A, B, and result.
+ if (aBatchRank != bBatchRank || aBatchRank != resBatchRank)
+ return op->emitOpError("Batch dimension rank mismatch among A, B, and "
+ "result.");
+
+ // Verify batch dimensions match.
+ for (int64_t i = 0; i < aBatchRank; ++i) {
+ if (aShape[i] != resShape[i])
+ return op->emitOpError("Batch dimension mismatch at dim ")
+ << i << ": A has " << aShape[i] << " but result has "
+ << resShape[i] << ".";
+ if (aShape[i] != bShape[i])
+ return op->emitOpError("Batch dimension mismatch at dim ")
+ << i << ": A has " << aShape[i] << " but B has " << bShape[i]
+ << ".";
+ }
- // Calculate effective K dimension for B (handle 3D packed case)
- int64_t bK = bRank == 3 ? bShape[0] * bShape[2] : bShape[0];
+ // Now verify the core matmul dimensions (last 2 of A and result, last 2 or
+ // 3 of B).
+ int64_t aM = aShape[aBatchRank];
+ int64_t aK = aShape[aBatchRank + 1];
+ int64_t resM = resShape[resBatchRank];
+ int64_t resN = resShape[resBatchRank + 1];
+
+ // Calculate effective K dimension for B (handle packed case)
+ int64_t bK, bN;
+ if (bPacked) {
+ bK = bShape[bBatchRank] * bShape[bBatchRank + 2];
+ bN = bShape[bBatchRank + 1];
+ } else {
+ bK = bShape[bBatchRank];
+ bN = bShape[bBatchRank + 1];
+ }
// Verify K dimension match between A and B
- if (bK != aShape[1])
+ if (bK != aK)
return op->emitOpError("K-dimension mismatch: A has K=")
- << aShape[1] << " but B has K=" << bK << ".";
+ << aK << " but B has K=" << bK << ".";
// Verify M dimension match between A and result
- if (aShape[0] != resShape[0])
+ if (aM != resM)
return op->emitOpError("M-dimension mismatch: A has M=")
- << aShape[0] << " but result has M=" << resShape[0] << ".";
+ << aM << " but result has M=" << resM << ".";
// Verify N dimension match between B and result
- if (bShape[1] != resShape[1])
+ if (bN != resN)
return op->emitOpError("N-dimension mismatch: B has N=")
- << bShape[1] << " but result has N=" << resShape[1] << ".";
+ << bN << " but result has N=" << resN << ".";
return success();
}
@@ -907,6 +954,9 @@ LogicalResult DpasMxOp::verify() {
if (failed(verifyDpasDimensions(*this, aShape, bShape, resShape)))
return failure();
+ // Determine batch rank from A operand.
+ int64_t aBatchRank = aShape.size() - 2;
+
// Validate scale_a if present
if (getScaleA()) {
auto scaleAVecType = dyn_cast<VectorType>(getScaleAType());
@@ -914,19 +964,20 @@ LogicalResult DpasMxOp::verify() {
if (scaleAVecType && scaleAVecType.getRank() > 1) {
auto scaleAShape = scaleAVecType.getShape();
- if (scaleAVecType.getRank() != 2)
- return emitOpError("Scale A must be a 2D vector when not a scalar.");
+ if (scaleAVecType.getRank() < 2)
+ return emitOpError("Scale A must be at least a 2D vector when not a "
+ "scalar.");
// Verify layout distributability for scale_a
if (failed(verifyLayoutDistributable(*this, getLayoutAScale(),
scaleAShape, "ScaleA")))
return failure();
- // Validate M dimension: scale_a[0] must match a[0]
- if (scaleAShape[0] != aShape[0])
+ // Validate M dimension: scale_a's M must match A's M (last-1 dim)
+ if (scaleAShape[scaleAShape.size() - 2] != aShape[aBatchRank])
return emitOpError("Scale A M dimension [")
- << scaleAShape[0] << "] must match A M dimension [" << aShape[0]
- << "].";
+ << scaleAShape[scaleAShape.size() - 2]
+ << "] must match A M dimension [" << aShape[aBatchRank] << "].";
}
}
@@ -937,19 +988,21 @@ LogicalResult DpasMxOp::verify() {
if (scaleBVecType && scaleBVecType.getRank() > 1) {
auto scaleBShape = scaleBVecType.getShape();
- if (scaleBVecType.getRank() != 2)
- return emitOpError("Scale B must be a 2D vector when not a scalar.");
+ if (scaleBVecType.getRank() < 2)
+ return emitOpError("Scale B must be at least a 2D vector when not a "
+ "scalar.");
// Verify layout distributability for scale_b
if (failed(verifyLayoutDistributable(*this, getLayoutBScale(),
scaleBShape, "ScaleB")))
return failure();
- // Validate N dimension: scale_b[1] must match b[1]
- if (scaleBShape[1] != bShape[1])
+ // Validate N dimension: scale_b's N (last dim) must match B's N (last
+ // dim)
+ if (scaleBShape.back() != bShape.back())
return emitOpError("Scale B N dimension [")
- << scaleBShape[1] << "] must match B N dimension [" << bShape[1]
- << "].";
+ << scaleBShape.back() << "] must match B N dimension ["
+ << bShape.back() << "].";
}
}
@@ -964,12 +1017,12 @@ LogicalResult DpasMxOp::verify() {
auto scaleAShape = scaleAVecType.getShape();
auto scaleBShape = scaleBVecType.getShape();
- // Validate scale K dimension compatibility: scale_a[1] must match
- // scale_b[0]
- if (scaleAShape[1] != scaleBShape[0])
+ // Validate scale K dimension compatibility: scale_a's last dim must
+ // match scale_b's second-to-last dim
+ if (scaleAShape.back() != scaleBShape[scaleBShape.size() - 2])
return emitOpError("Scale K dimension mismatch: scale_a has K=")
- << scaleAShape[1] << " but scale_b has K=" << scaleBShape[0]
- << ".";
+ << scaleAShape.back() << " but scale_b has K="
+ << scaleBShape[scaleBShape.size() - 2] << ".";
}
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 341427e7e1231..438f675ecc3f4 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -394,7 +394,7 @@ func.func @dpas_vc_1(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) {
// -----
func.func @dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
- // expected-error at +1 {{op A operand must be a 2D vector}}
+ // expected-error at +1 {{'xegpu.dpas' op Batch dimension rank mismatch among A, B, and result}}
%1 = xegpu.dpas %a, %b : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
return
}
@@ -702,21 +702,21 @@ func.func @dpas_mx_acc_result_type_mismatch(%a : vector<8x16xf8E5M2>, %b: vector
// -----
func.func @dpas_mx_a_not_2d(%a : vector<128xf8E5M2>, %b: vector<16x16xf8E5M2>) {
- // expected-error at +1 {{A operand must be a 2D vector.}}
+ // expected-error at +1 {{A operand must be at least a 2D vector.}}
%1 = xegpu.dpas_mx %a, %b : (vector<128xf8E5M2>, vector<16x16xf8E5M2>) -> vector<8x16xf32>
return
}
// -----
func.func @dpas_mx_b_not_2d(%a : vector<8x16xf8E5M2>, %b: vector<256xf8E5M2>) {
- // expected-error at +1 {{B operand must be a 2D or 3D vector.}}
+ // expected-error at +1 {{B operand must be at least a 2D vector.}}
%1 = xegpu.dpas_mx %a, %b : (vector<8x16xf8E5M2>, vector<256xf8E5M2>) -> vector<8x16xf32>
return
}
// -----
func.func @dpas_mx_result_not_2d(%a : vector<8x16xf8E5M2>, %b: vector<16x16xf8E5M2>) {
- // expected-error at +1 {{Result must be a 2D vector.}}
+ // expected-error at +1 {{Result must be at least a 2D vector.}}
%1 = xegpu.dpas_mx %a, %b : (vector<8x16xf8E5M2>, vector<16x16xf8E5M2>) -> vector<128xf32>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 198dc64d2814b..5cc74cdf7c9e4 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -649,4 +649,25 @@ gpu.func @dpas_mx(%a : vector<8x32xf8E5M2>, %b: vector<32x16xf8E5M2>, %acc: vect
gpu.return
}
+// CHECK-LABEL: gpu.func @dpas_3d_batch
+gpu.func @dpas_3d_batch(%a : vector<4x8x16xf16>, %b: vector<4x16x16xf16>) {
+ // CHECK: %{{.+}} = xegpu.dpas %{{.+}}, %{{.+}} : vector<4x8x16xf16>, vector<4x16x16xf16> -> vector<4x8x16xf32>
+ %1 = xegpu.dpas %a, %b: vector<4x8x16xf16>, vector<4x16x16xf16> -> vector<4x8x16xf32>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @dpas_3d_batch_with_acc
+gpu.func @dpas_3d_batch_with_acc(%a : vector<4x8x16xf16>, %b: vector<4x16x16xf16>, %acc: vector<4x8x16xf32>) {
+ // CHECK: %{{.+}} = xegpu.dpas %{{.+}}, %{{.+}}, %{{.+}} : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32>
+ %1 = xegpu.dpas %a, %b, %acc : vector<4x8x16xf16>, vector<4x16x16xf16>, vector<4x8x16xf32> -> vector<4x8x16xf32>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @dpas_mx_3d_batch
+gpu.func @dpas_mx_3d_batch(%a : vector<4x8x32xf8E5M2>, %b: vector<4x32x16xf8E5M2>, %acc: vector<4x8x16xbf16>, %a_scale: vector<4x8x1xf8E8M0FNU>, %b_scale: vector<4x1x16xf8E8M0FNU>) {
+ // CHECK: %{{.+}} = xegpu.dpas_mx %{{.+}}, %{{.+}}, %{{.+}} scale_a = %{{.+}} scale_b = %{{.+}} : (vector<4x8x32xf8E5M2>, vector<4x32x16xf8E5M2>, vector<4x8x16xbf16>, vector<4x8x1xf8E8M0FNU>, vector<4x1x16xf8E8M0FNU>) -> vector<4x8x16xbf16>
+ %1 = xegpu.dpas_mx %a, %b, %acc scale_a = %a_scale scale_b = %b_scale : (vector<4x8x32xf8E5M2>, vector<4x32x16xf8E5M2>, vector<4x8x16xbf16>, vector<4x8x1xf8E8M0FNU>, vector<4x1x16xf8E8M0FNU>) -> vector<4x8x16xbf16>
+ gpu.return
+}
+
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/199809