[Mlir-commits] [mlir] [MLIR][XeVM] Add truncf and mma_mx op. (PR #180055)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 5 14:50:13 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

<details>
<summary>Changes</summary>

The `truncf` op converts 16-bit floats to 8-bit or 4-bit floats. The `mma_mx` op performs a cooperative matrix multiply-accumulate on 8-bit or 4-bit float types with 8-bit scale values.

---
Full diff: https://github.com/llvm/llvm-project/pull/180055.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td (+80-5) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp (+8) 
- (modified) mlir/test/Dialect/LLVMIR/xevm.mlir (+23) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 388efaaa25117..7d16827405061 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -519,11 +519,15 @@ def XeVM_ET_U4 : I32EnumAttrCase<"U4", 13, "u4">;
 def XeVM_ET_TF32 : I32EnumAttrCase<"TF32", 14, "tf32">;
 def XeVM_ET_F32 : I32EnumAttrCase<"F32", 15, "f32">;
 def XeVM_ET_S32 : I32EnumAttrCase<"S32", 16, "s32">;
-
-def XeVM_ElemTypeAttr : I32EnumAttr<"ElemType", "XeVM element type",
-                                    [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8,
-                                     XeVM_ET_U8, XeVM_ET_S4, XeVM_ET_U4,
-                                     XeVM_ET_TF32, XeVM_ET_F32, XeVM_ET_S32]> {
+def XeVM_ET_E2M1 : I32EnumAttrCase<"E2M1", 17, "e2m1">;
+def XeVM_ET_BF8 : I32EnumAttrCase<"BF8", 18, "bf8">;
+def XeVM_ET_F8 : I32EnumAttrCase<"F8", 19, "f8">;
+// The E2M1 (f4), BF8 and F8 cases above serve xevm.truncf and xevm.mma_mx.
+def XeVM_ElemTypeAttr
+    : I32EnumAttr<"ElemType", "XeVM element type",
+                  [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8, XeVM_ET_U8,
+                   XeVM_ET_S4, XeVM_ET_U4, XeVM_ET_TF32, XeVM_ET_F32,
+                   XeVM_ET_S32, XeVM_ET_E2M1, XeVM_ET_BF8, XeVM_ET_F8]> {
   let cppNamespace = "::mlir::xevm";
 }
 
@@ -592,6 +596,77 @@ def XeVM_MMAOp
   let hasVerifier = 1;
 }
 
+def XeVM_ConversionTypesAttr
+    : XeVM_Attr<"ConversionTypes", "conversion_types"> {
+  let description = [{
+    Source and destination element types (`src_type`, `dst_type`) of a conversion op.
+  }];
+  let parameters = (ins "xevm::ElemType":$src_type, "xevm::ElemType":$dst_type);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeVM_TruncfOp
+    : XeVM_Op<"truncf">,
+      Results<(outs AnyTypeOf<[I8]>:$res)>,
+      Arguments<(ins AnyTypeOf<[I16]>:$arg, XeVM_ConversionTypesAttr:$types)> {
+  let summary = "Floating point truncation from f16/bf16 to f8/bf8/f4";
+  let description = [{
+    The `xevm.truncf` operation truncates a floating point value from
+    f16/bf16 (`types.src_type`, passed as raw bits in the i16 `arg`) to
+    f8/bf8/f4 (`types.dst_type`, `e2m1` for f4); the result is an i8.
+  }];
+  let assemblyFormat = [{
+    $arg ` ` `{` `types` `=` $types `}` attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def XeVM_MMAMxOp
+    : XeVM_Op<"mma_mx">,
+      Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
+      Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
+          FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
+          FixedVectorOfRankAndType<[1], [I8]>:$scale_a,
+          FixedVectorOfRankAndType<[1], [I8]>:$scale_b,
+          Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
+          XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
+
+  let summary = "Subgroup matrix multiply-add with MxN shape and MX scaling";
+
+  let description = [{
+    The `xevm.mma_mx` operation is similar to `xevm.mma` but also takes scale
+    operands for the A and B matrices. It is a cooperative operation in which
+    all threads/lanes in a subgroup participate in matrix multiply-accumulate:
+
+      D = C + A x B
+
+      where the A, B, C input matrices and the result D have shapes:
+        - D : MxN
+        - C : MxN
+        - A : MxK
+        - B : KxN
+
+    Parameters:
+      * `a` - vector of matrix A elements.
+      * `b` - vector of matrix B elements.
+      * `scale_a` - vector of i8 scaling factors for matrix A.
+      * `scale_b` - vector of i8 scaling factors for matrix B.
+      * `c` - (optional) vector of matrix C elements; its type must match the result type.
+      * `shape` - the shape of the matrices, specified as `M`, `N`, and `K` values.
+      * `types` - the data types of the matrices, specified as `D`, `A`, `B`, and optionally `C`.
+
+  }];
+
+  let assemblyFormat = [{
+    $a `,` $b `,` $scale_a `,` $scale_b (`,` $c^)? ` `
+    `{`
+      `shape` `=` $shape `,`
+      `types` `=` $types
+    `}` attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // XeVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 04e8836c00359..f202f0b82e935 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -349,6 +349,14 @@ LogicalResult MMAOp::verify() {
   return success();
 }
 
+LogicalResult MMAMxOp::verify() {
+  // C is optional; when it is supplied it must carry exactly the result type.
+  auto acc = getC();
+  if (acc && acc.getType() != getResult().getType())
+    return emitOpError("type of C operand must match result type");
+  return success();
+}
+
 LogicalResult
 XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
                        StringRef triple, StringRef chip, DictionaryAttr flags,
diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir
index 66fb2949a270f..7f1eae2580d39 100644
--- a/mlir/test/Dialect/LLVMIR/xevm.mlir
+++ b/mlir/test/Dialect/LLVMIR/xevm.mlir
@@ -92,6 +92,29 @@ func.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loade
   return %c_result : vector<8xf32>
 }
 
+// -----
+// CHECK-LABEL: func.func @mma_mx(
+// CHECK-SAME: %[[ACC:.*]]: vector<8xf32>, %[[A:.*]]: vector<8xi16>, %[[B:.*]]: vector<8xi32>, %[[SCALE_A:.*]]: vector<2xi8>, %[[SCALE_B:.*]]: vector<2xi8>)
+func.func @mma_mx(%acc: vector<8xf32>, %a: vector<8xi16>, %b: vector<8xi32>, %scale_a: vector<2xi8>, %scale_b: vector<2xi8>) -> vector<8xf32> {
+  // CHECK: %[[RES:.*]] = xevm.mma_mx %[[A]], %[[B]], %[[SCALE_A]], %[[SCALE_B]], %[[ACC]]
+  // CHECK-SAME: {shape = <m = 8, n = 16, k = 64>, types = <d = f32, a = e2m1, b = e2m1, c = f32>}
+  // CHECK-SAME:: (vector<8xi16>, vector<8xi32>, vector<2xi8>, vector<2xi8>, vector<8xf32>) -> vector<8xf32>
+  %d = xevm.mma_mx %a, %b, %scale_a, %scale_b, %acc { shape=<m=8, n=16, k=64>,
+    types=<d=f32, a=e2m1, b=e2m1, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<2xi8>, vector<2xi8>, vector<8xf32>) -> vector<8xf32>
+  return %d : vector<8xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @truncf
+func.func @truncf() -> i8 {
+  %one = arith.constant 1.0 : bf16
+  // CHECK: %[[BITS:.*]] = arith.bitcast %{{.+}} : bf16 to i16
+  %bits = arith.bitcast %one : bf16 to i16
+  // CHECK: xevm.truncf %[[BITS]] {types = <src_type = bf16, dst_type = e2m1>} : (i16) -> i8
+  %fp4 = xevm.truncf %bits { types=<src_type=bf16, dst_type=e2m1> } : (i16) -> i8
+  return %fp4 : i8
+}
+
 // -----
 // CHECK-LABEL: func.func @memfence()
 func.func @memfence() {

``````````

</details>


https://github.com/llvm/llvm-project/pull/180055


More information about the Mlir-commits mailing list