[Mlir-commits] [mlir] [MLIR][XeGPU] Add truncf and dpas_mx ops (PR #180059)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 5 14:54:29 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

<details>
<summary>Changes</summary>

Add the `xegpu.truncf` and `xegpu.dpas_mx` ops to support low-precision floating-point MMA computation.

---
Full diff: https://github.com/llvm/llvm-project/pull/180059.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+138) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+26) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+14) 
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+14) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 2cbec50772b98..8cbe6ee047c61 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1719,4 +1719,162 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   let hasVerifier = 1;
 }
 
+def XeGPU_TruncfOp
+    : XeGPU_Op<"truncf", [AllRanksMatch<["source", "result"]>,
+                          AllShapesMatch<["source", "result"]>, Pure]> {
+  let summary = "Truncates floating point values to a lower precision type";
+
+  let description = [{
+    The `xegpu.truncf` operation truncates floating point values element-wise
+    from a higher precision type to a lower precision type. It is intended
+    to convert `f16` and `bf16` values to microscaling float types.
+    Rounding mode defaults to round to nearest even.
+
+    `source` and `result` must be vectors of the same shape, and the result
+    element type must be strictly narrower (in bits) than the source element
+    type.
+
+    Example:
+    ```mlir
+      %res = xegpu.truncf %src : vector<16xf16> -> vector<16xf8E5M2>
+    ```
+  }];
+
+  let arguments = (ins FixedVectorOfNonZeroRankOf<[XeGPU_FloatType]>:$source);
+  let results = (outs FixedVectorOfNonZeroRankOf<[XeGPU_FloatType]>:$result);
+
+  let extraClassDeclaration = [{
+    /// Returns the type of the value being truncated.
+    Type getSourceType() {
+      return getSource().getType();
+    }
+
+    /// Returns the type of the truncated result.
+    Type getResultType() {
+      return getResult().getType();
+    }
+  }];
+
+  let assemblyFormat = [{
+    $source attr-dict `:` type($source) `->` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
+                                          AllElementTypesMatch<["a", "b"]>,
+                                          AnchorLayoutInterface]> {
+  let summary = "Performs scaled (microscaling) matrix multiply-accumulate";
+
+  let description = [{
+    DPAS MX multiplies matrix A of size `mxk` by matrix B of size `kxn`,
+    both of a low precision data type, and accumulates onto matrix C of size
+    `mxn`, producing a result matrix of the same size as C.
+
+    In lane level code, each lane from a subgroup holds a data fragment for
+    A, B, Acc and the result, which are represented as 1D vectors.
+
+    This operation serves as an anchor through which users assign a layout
+    attribute to govern computation distribution.
+
+    Arguments:
+
+    - `a`: A vector value representing the left-hand-side matrix tile (A)
+      participating in the matrix multiply.
+
+    - `b`: A vector value representing the right-hand-side matrix tile (B).
+      Its element type must match that of `a`.
+
+    - `acc`: [optional] A vector value representing the accumulator matrix
+      tile (C). When present it must have the same type as the result, and
+      the result is computed as `a * b + acc`.
+
+    - `scale_a`: [optional] An `f8E8M0FNU` scalar or 1D/2D vector value used
+      to scale `a` for the matrix multiplication.
+
+    - `scale_b`: [optional] An `f8E8M0FNU` scalar or 1D/2D vector value used
+      to scale `b` for the matrix multiplication.
+
+    - `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify
+      this operation as anchor for operands A, B, and the accumulator/result,
+      enabling users to assign layouts that govern distribution at the
+      subgroup and/or lane level. Only valid at workgroup and subgroup level.
+
+    Example:
+    ```mlir
+      %r = xegpu.dpas_mx %a, %b, %acc scale_a(%sa : f8E8M0FNU)
+             scale_b(%sb : f8E8M0FNU)
+             : vector<8x16xf8E5M2>, vector<16x16xf8E5M2>, vector<8x16xbf16>
+             -> vector<8x16xbf16>
+    ```
+  }];
+
+  let arguments = (ins XeGPU_DpasOprType:$a, XeGPU_DpasOprType:$b,
+      Optional<XeGPU_DpasResType>:$acc,
+      Optional<AnyTypeOf<[F8E8M0FNU,
+                          VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_a,
+      Optional<AnyTypeOf<[F8E8M0FNU,
+                          VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_b,
+      OptionalAttr<DistributeLayoutAttr>:$layout_a,
+      OptionalAttr<DistributeLayoutAttr>:$layout_b,
+      OptionalAttr<DistributeLayoutAttr>:$layout_cd);
+  let results = (outs XeGPU_DpasResType:$result);
+  let extraClassDeclaration = [{
+
+    xegpu::DistributeLayoutAttr getAnchorLayout() {
+      return getLayoutCd().value_or(nullptr);
+    }
+
+    void setAnchorLayout(xegpu::DistributeLayoutAttr anchorLayout) {
+      setLayoutCdAttr(anchorLayout);
+    }
+
+    VectorType getAType() {
+      return getA().getType();
+    }
+
+    VectorType getBType() {
+      return getB().getType();
+    }
+
+    /// Returns the accumulator type, or a null type if `acc` is absent.
+    VectorType getAccType() {
+      if (auto acc = getAcc())
+        return acc.getType();
+      return {};
+    }
+
+    /// Returns the type of `scale_a`, or a null type if it is absent.
+    Type getScaleAType() {
+      if (Value scale = getScaleA())
+        return scale.getType();
+      return {};
+    }
+
+    /// Returns the type of `scale_b`, or a null type if it is absent.
+    Type getScaleBType() {
+      if (Value scale = getScaleB())
+        return scale.getType();
+      return {};
+    }
+
+    VectorType getResultType() {
+      return getResult().getType();
+    }
+
+  }];
+  // Scale operand types are printed inline with their operands: the trailing
+  // type list is comma-separated, so optional scale types placed there could
+  // not be told apart from the optional `acc` type when `acc` is absent.
+  let assemblyFormat = [{
+    $a `,` $b (`,` $acc^)?
+    (`scale_a` `(` $scale_a^ `:` type($scale_a) `)`)?
+    (`scale_b` `(` $scale_b^ `:` type($scale_b) `)`)?
+    attr-dict `:` type($a) `,` type($b)
+    (`,` type($acc)^)? `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..38a591c627479 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1187,6 +1187,54 @@ LogicalResult StoreMatrixOp::verify() {
                                getLayoutAttr(), [&]() { return emitError(); });
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPU_TruncfOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that the conversion actually narrows: the source element bit width
+/// must be strictly greater than the result element bit width. Rank, shape
+/// and element kind are already enforced by the ODS type constraints.
+LogicalResult TruncfOp::verify() {
+  // ODS constrains both values to fixed-length vectors, so the asserting
+  // `cast` cannot fail here.
+  auto sourceVecType = cast<VectorType>(getSource().getType());
+  auto resultVecType = cast<VectorType>(getResult().getType());
+
+  if (sourceVecType.getElementTypeBitWidth() <=
+      resultVecType.getElementTypeBitWidth())
+    return emitOpError("input type must be wider than result type.");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_DpasMxOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that `acc`, when present, has the result type and that, for 2D
+/// tiles, A (`mxk`), B (`kxn`) and the result (`mxn`) agree on M, N and K.
+/// Other ranks (e.g. 1D lane-level fragments) are not shape-checked here.
+LogicalResult DpasMxOp::verify() {
+  VectorType resultType = getResultType();
+  if (getAcc() && getAcc().getType() != resultType)
+    return emitOpError("Expecting the acc type to be the same as result.");
+
+  VectorType aType = getAType();
+  VectorType bType = getBType();
+  if (aType.getRank() != 2 || bType.getRank() != 2 ||
+      resultType.getRank() != 2)
+    return success();
+
+  if (aType.getDimSize(1) != bType.getDimSize(0))
+    return emitOpError("K-dimension mismatch.");
+  if (aType.getDimSize(0) != resultType.getDimSize(0))
+    return emitOpError("M-dimension mismatch.");
+  if (bType.getDimSize(1) != resultType.getDimSize(1))
+    return emitOpError("N-dimension mismatch.");
+
+  return success();
+}
+
 namespace mlir {
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
 } // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f2011ab86e9e9..36b914fa578bf 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -901,3 +901,24 @@ func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf3
         vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>
   return
 }
+
+// -----
+func.func @truncf_invalid_result_size(%a: vector<8x16xf16>) {
+  // expected-error@+1 {{op input type must be wider than result type}}
+  %1 = xegpu.truncf %a : vector<8x16xf16> -> vector<8x16xf32>
+  return
+}
+
+// -----
+func.func @truncf_invalid_same_width(%a: vector<8x16xf16>) {
+  // expected-error@+1 {{op input type must be wider than result type}}
+  %1 = xegpu.truncf %a : vector<8x16xf16> -> vector<8x16xbf16>
+  return
+}
+
+// -----
+func.func @dpas_mx_acc_result_type_mismatch(%a : vector<8x16xf8E5M2>, %b: vector<16x16xf8E5M2>, %acc: vector<8x16xbf16>) {
+  // expected-error@+1 {{Expecting the acc type to be the same as result.}}
+  %1 = xegpu.dpas_mx %a, %b, %acc : vector<8x16xf8E5M2>, vector<16x16xf8E5M2>, vector<8x16xbf16> -> vector<8x16xf32>
+  return
+}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 1e9738f44bb66..520061925f92c 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -925,4 +925,20 @@ gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_
   gpu.return
 }
 
+// xegpu.truncf from f16 to f8E5M2 on a 2D vector.
+// CHECK-LABEL: gpu.func @truncf
+gpu.func @truncf(%src: vector<8x16xf16>) {
+  // CHECK: %{{.+}} = xegpu.truncf %{{.+}} : vector<8x16xf16> -> vector<8x16xf8E5M2>
+  %res = xegpu.truncf %src : vector<8x16xf16> -> vector<8x16xf8E5M2>
+  gpu.return
+}
+
+// xegpu.dpas_mx with an accumulator and without scale operands.
+// CHECK-LABEL: gpu.func @dpas_mx
+gpu.func @dpas_mx(%lhs: vector<8x16xf8E5M2>, %rhs: vector<16x16xf8E5M2>, %acc: vector<8x16xbf16>) {
+  // CHECK: %{{.+}} = xegpu.dpas_mx %{{.+}}, %{{.+}}, %{{.+}} : vector<8x16xf8E5M2>, vector<16x16xf8E5M2>, vector<8x16xbf16> -> vector<8x16xbf16>
+  %res = xegpu.dpas_mx %lhs, %rhs, %acc : vector<8x16xf8E5M2>, vector<16x16xf8E5M2>, vector<8x16xbf16> -> vector<8x16xbf16>
+  gpu.return
+}
+
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/180059


More information about the Mlir-commits mailing list