[Mlir-commits] [mlir] [MLIR][XeGPU] Add truncf and dpas_mx ops (PR #180059)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 5 14:54:29 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Sang Ik Lee (silee2)
<details>
<summary>Changes</summary>
Add `xegpu.truncf` and `xegpu.dpas_mx` ops to support low-precision floating-point MMA computation.
---
Full diff: https://github.com/llvm/llvm-project/pull/180059.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+138)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+26)
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+14)
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+14)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 2cbec50772b98..8cbe6ee047c61 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1719,4 +1719,142 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
let hasVerifier = 1;
}
+def XeGPU_TruncfOp
+    : XeGPU_Op<"truncf", [AllRanksMatch<["source", "result"]>,
+                          AllShapesMatch<["source", "result"]>, Pure]> {
+  let summary = "It performs floating point truncation from higher precision "
+                "to lower precision.";
+
+  let description = [{
+    The `xegpu.truncf` operation truncates floating point values from a higher
+    precision type to a lower precision type. It is intended to convert `f16`
+    and `bf16` to microscaling float types; the verifier only requires the
+    result element type to be strictly narrower than the source element type.
+    Rounding mode defaults to round to nearest even.
+
+    Example:
+    ```mlir
+    %res = xegpu.truncf %src : vector<16xf16> -> vector<16xf8E5M2>
+    ```
+  }];
+
+  let arguments = (ins FixedVectorOfNonZeroRankOf<[XeGPU_FloatType]>:$source);
+  let results = (outs FixedVectorOfNonZeroRankOf<[XeGPU_FloatType]>:$result);
+
+  let extraClassDeclaration = [{
+    // Source and result are constrained to fixed vectors, so these casts
+    // cannot fail; returning VectorType spares callers a cast.
+    VectorType getSourceType() {
+      return ::llvm::cast<VectorType>(getSource().getType());
+    }
+
+    VectorType getResultType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+
+  let assemblyFormat = [{
+    $source attr-dict `:` type(operands) `->` type(results)
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
+                                           AllElementTypesMatch<["a", "b"]>,
+                                           AnchorLayoutInterface]> {
+  let summary = "It performs scaled mma computation";
+
+  let description = [{
+    DPAS MX performs matrix multiplication on matrix A and matrix B of a low
+    precision data type. A is of size `mxk` and B is of size `kxn`; the product
+    is accumulated onto matrix C of size `mxn`, producing a result matrix of the
+    same size.
+
+    In lane level code, each lane from a subgroup holds a data fragment for A,
+    B, Acc and the result, which are represented as 1D vectors.
+
+    This operation serves as an anchor through which users assign a layout
+    attribute to govern computation distribution.
+
+    Arguments:
+
+    - `a`: A vector value representing the left-hand-side matrix tile (A)
+      participating in the matrix multiply.
+
+    - `b`: A vector value representing the right-hand-side matrix tile (B).
+
+    - `acc`: [optional] A vector value representing the accumulator matrix
+      tile (C). The result is computed as `a * b + acc`.
+
+    - `scale_a`: [optional] An `f8E8M0FNU` scalar, or a 1-D/2-D vector of
+      `f8E8M0FNU`, used to scale `a` for matrix multiplication.
+
+    - `scale_b`: [optional] An `f8E8M0FNU` scalar, or a 1-D/2-D vector of
+      `f8E8M0FNU`, used to scale `b` for matrix multiplication.
+
+    - `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify
+      this operation as anchor for operands A, B, and the accumulator/result,
+      enabling users to assign layouts that govern distribution at the subgroup
+      and/or lane level. Only valid at workgroup and subgroup level.
+
+    The types after `:` are listed in operand order, for present operands only.
+  }];
+
+  let arguments = (ins XeGPU_DpasOprType:$a, XeGPU_DpasOprType:$b,
+      Optional<XeGPU_DpasResType>:$acc,
+      Optional<AnyTypeOf<[F8E8M0FNU,
+          VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_a,
+      Optional<AnyTypeOf<[F8E8M0FNU,
+          VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_b,
+      OptionalAttr<DistributeLayoutAttr>:$layout_a,
+      OptionalAttr<DistributeLayoutAttr>:$layout_b,
+      OptionalAttr<DistributeLayoutAttr>:$layout_cd);
+  let results = (outs XeGPU_DpasResType:$result);
+  let extraClassDeclaration = [{
+
+    xegpu::DistributeLayoutAttr getAnchorLayout() {
+      return getLayoutCd().value_or(nullptr);
+    }
+
+    void setAnchorLayout(xegpu::DistributeLayoutAttr anchorLayout) {
+      setLayoutCdAttr(anchorLayout);
+    }
+
+    VectorType getAType() {
+      return getA().getType();
+    }
+
+    VectorType getBType() {
+      return getB().getType();
+    }
+
+    // `acc`, `scale_a` and `scale_b` are optional operands: return a null
+    // type when the operand is absent instead of dereferencing a null Value.
+    VectorType getAccType() {
+      return getAcc() ? getAcc().getType() : VectorType();
+    }
+
+    Type getScaleAType() {
+      return getScaleA() ? getScaleA().getType() : Type();
+    }
+
+    Type getScaleBType() {
+      return getScaleB() ? getScaleB().getType() : Type();
+    }
+
+    VectorType getResultType() {
+      return getResult().getType();
+    }
+  }];
+  // `type(operands)` resolves the listed types against the operands that were
+  // actually parsed, so an absent `acc` cannot make a scale type be consumed
+  // as the accumulator type (as separate optional `,` type groups would).
+  let assemblyFormat = [{
+    $a `,` $b (`,` $acc^)? (`scale_a` `=` $scale_a^)? (`scale_b` `=` $scale_b^)?
+    attr-dict `:` type(operands) `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..38a591c627479 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1187,6 +1187,32 @@ LogicalResult StoreMatrixOp::verify() {
getLayoutAttr(), [&]() { return emitError(); });
}
+//===----------------------------------------------------------------------===//
+// XeGPU_TruncfOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TruncfOp::verify() {
+  // ODS constrains source and result to fixed vectors, so cast<> cannot fail.
+  auto sourceVecType = cast<VectorType>(getSource().getType());
+  auto resultVecType = cast<VectorType>(getResult().getType());
+  // Truncation must strictly narrow the element bit width.
+  if (sourceVecType.getElementTypeBitWidth() <=
+      resultVecType.getElementTypeBitWidth())
+    return emitOpError("input type must be wider than result type.");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_DpasMxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DpasMxOp::verify() {
+  // A missing accumulator is allowed; a present one must match the result.
+  if (Value acc = getAcc(); !acc || acc.getType() == getResultType())
+    return success();
+  return emitOpError("Expecting the acc type to be the same as result.");
+}
+
namespace mlir {
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
} // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f2011ab86e9e9..36b914fa578bf 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -901,3 +901,17 @@ func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf3
vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>
return
}
+
+// -----
+func.func @truncf_invalid_result_size(%a: vector<8x16xf16>) {
+  // expected-error@+1 {{op input type must be wider than result type}}
+  %1 = xegpu.truncf %a : vector<8x16xf16> -> vector<8x16xf32>
+  return
+}
+
+// -----
+func.func @dpas_mx_acc_result_type_mismatch(%a : vector<8x16xf8E5M2>, %b: vector<16x16xf8E5M2>, %acc: vector<8x16xbf16>) {
+  // expected-error@+1 {{Expecting the acc type to be the same as result.}}
+  %1 = xegpu.dpas_mx %a, %b, %acc : vector<8x16xf8E5M2>, vector<16x16xf8E5M2>, vector<8x16xbf16> -> vector<8x16xf32>
+  return
+}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 1e9738f44bb66..520061925f92c 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -925,4 +925,18 @@ gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_
gpu.return
}
+// CHECK-LABEL: gpu.func @truncf
+gpu.func @truncf(%src: vector<8x16xf16>) {
+  // CHECK: %{{.+}} = xegpu.truncf %{{.+}} : vector<8x16xf16> -> vector<8x16xf8E5M2>
+  %trunc = xegpu.truncf %src : vector<8x16xf16> -> vector<8x16xf8E5M2>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @dpas_mx
+gpu.func @dpas_mx(%lhs: vector<8x16xf8E5M2>, %rhs: vector<16x16xf8E5M2>, %acc: vector<8x16xbf16>) {
+  // CHECK: %{{.+}} = xegpu.dpas_mx %{{.+}}, %{{.+}}, %{{.+}} : vector<8x16xf8E5M2>, vector<16x16xf8E5M2>, vector<8x16xbf16> -> vector<8x16xbf16>
+  %res = xegpu.dpas_mx %lhs, %rhs, %acc : vector<8x16xf8E5M2>, vector<16x16xf8E5M2>, vector<8x16xbf16> -> vector<8x16xbf16>
+  gpu.return
+}
+
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/180059
More information about the Mlir-commits
mailing list