[Mlir-commits] [mlir] [MLIR][XeVM] Add truncf and mma_mx op. (PR #180055)

Sang Ik Lee llvmlistbot at llvm.org
Wed Mar 11 09:06:10 PDT 2026


https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/180055

>From 0f9ae360e30b7a12a86e5abad812ebc78d863e5b Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Fri, 23 Jan 2026 23:22:18 +0000
Subject: [PATCH 1/3] [MLIR][XeVM] Add truncf and mma_mx op. truncf op converts
 16 bit floats to 8 bit or 4 bit floats. mma_mx op does cooperative matrix
 multiply accumulate on 8 or 4 bit float type with 8 bit scale value.

---
 mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 85 +++++++++++++++++++--
 mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp  |  8 ++
 mlir/test/Dialect/LLVMIR/xevm.mlir          | 23 ++++++
 3 files changed, 111 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 388efaaa25117..7d16827405061 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -519,11 +519,15 @@ def XeVM_ET_U4 : I32EnumAttrCase<"U4", 13, "u4">;
 def XeVM_ET_TF32 : I32EnumAttrCase<"TF32", 14, "tf32">;
 def XeVM_ET_F32 : I32EnumAttrCase<"F32", 15, "f32">;
 def XeVM_ET_S32 : I32EnumAttrCase<"S32", 16, "s32">;
-
-def XeVM_ElemTypeAttr : I32EnumAttr<"ElemType", "XeVM element type",
-                                    [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8,
-                                     XeVM_ET_U8, XeVM_ET_S4, XeVM_ET_U4,
-                                     XeVM_ET_TF32, XeVM_ET_F32, XeVM_ET_S32]> {
+def XeVM_ET_E2M1 : I32EnumAttrCase<"E2M1", 17, "e2m1">;
+def XeVM_ET_BF8 : I32EnumAttrCase<"BF8", 18, "bf8">;
+def XeVM_ET_F8 : I32EnumAttrCase<"F8", 19, "f8">;
+
+def XeVM_ElemTypeAttr
+    : I32EnumAttr<"ElemType", "XeVM element type",
+                  [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8, XeVM_ET_U8,
+                   XeVM_ET_S4, XeVM_ET_U4, XeVM_ET_TF32, XeVM_ET_F32,
+                   XeVM_ET_S32, XeVM_ET_E2M1, XeVM_ET_BF8, XeVM_ET_F8]> {
   let cppNamespace = "::mlir::xevm";
 }
 
@@ -592,6 +596,77 @@ def XeVM_MMAOp
   let hasVerifier = 1;
 }
 
+def XeVM_ConversionTypesAttr
+    : XeVM_Attr<"ConversionTypes", "conversion_types"> {
+  let description = [{
+    This attribute is used to specify source and destination types for conversion operations.
+  }];
+  let parameters = (ins "xevm::ElemType":$src_type, "xevm::ElemType":$dst_type);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeVM_TruncfOp
+    : XeVM_Op<"truncf">,
+      Results<(outs AnyTypeOf<[I8]>:$res)>,
+      Arguments<(ins AnyTypeOf<[I16]>:$arg, XeVM_ConversionTypesAttr:$types)> {
+  let summary = "Floating point truncation from f16/bf16 to f8/bf8/f4";
+  let description = [{
+    The `xevm.truncf` operation truncates a floating point value from
+    f16/bf16 to f8/bf8/f4 format.
+  }];
+
+  let assemblyFormat = [{
+    $arg ` ` `{` `types` `=` $types `}` attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def XeVM_MMAMxOp
+    : XeVM_Op<"mma_mx">,
+      Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
+      Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
+          FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
+          FixedVectorOfRankAndType<[1], [I8]>:$scale_a,
+          FixedVectorOfRankAndType<[1], [I8]>:$scale_b,
+          Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
+          XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
+
+  let summary = "Subgroup matrix multiply-add with MxN shape and MX scaling";
+
+  let description = [{
+    The `xevm.mma_mx` op is similar to `xevm.mma` but has scale operands for the A and B matrices.
+    It is a cooperative operation where all threads/lanes in a subgroup participate
+    and carry out matrix multiplication plus accumulation:
+
+      D = C + A x B
+
+      where the A, B, C input matrices and the result D have shapes:
+        - D : MxN
+        - C : MxN
+        - A : MxK
+        - B : KxN
+
+    Parameters:
+      * `a` - vector of matrix A elements.
+      * `b` - vector of matrix B elements.
+      * `scale_a` - vector of scaling factors for matrix A.
+      * `scale_b` - vector of scaling factors for matrix B.
+      * `c` - (optional) vector of matrix C elements.
+      * `shape` - the shape of the matrices, specified as `M`, `N`, and `K` values.
+      * `types` - the data types of the matrices, specified as `D`, `A`, `B`, and optionally `C`.
+
+  }];
+
+  let assemblyFormat = [{
+    $a `,` $b `,` $scale_a `,` $scale_b (`,` $c^)? ` `
+    `{`
+      `shape` `=` $shape `,`
+      `types` `=` $types
+    `}` attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // XeVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 04e8836c00359..f202f0b82e935 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -349,6 +349,14 @@ LogicalResult MMAOp::verify() {
   return success();
 }
 
+LogicalResult MMAMxOp::verify() {
+  // The accumulator C is optional; if given, it must have the result type.
+  auto acc = getC();
+  if (acc && getResult().getType() != acc.getType())
+    return emitOpError("type of C operand must match result type");
+  return success();
+}
+
 LogicalResult
 XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
                        StringRef triple, StringRef chip, DictionaryAttr flags,
diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir
index 66fb2949a270f..7f1eae2580d39 100644
--- a/mlir/test/Dialect/LLVMIR/xevm.mlir
+++ b/mlir/test/Dialect/LLVMIR/xevm.mlir
@@ -92,6 +92,29 @@ func.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loade
   return %c_result : vector<8xf32>
 }
 
+// -----
+// CHECK-LABEL: func.func @mma_mx(
+// CHECK-SAME: %[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>, %[[ARG3:.*]]: vector<2xi8>, %[[ARG4:.*]]: vector<2xi8>)
+func.func @mma_mx(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>, %scale_a: vector<2xi8>, %scale_b: vector<2xi8>) -> vector<8xf32> {
+  // CHECK: %[[VAR0:.*]] = xevm.mma_mx %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG0]]
+  // CHECK-SAME: {shape = <m = 8, n = 16, k = 64>, types = <d = f32, a = e2m1, b = e2m1, c = f32>}
+  // CHECK-SAME:: (vector<8xi16>, vector<8xi32>, vector<2xi8>, vector<2xi8>, vector<8xf32>) -> vector<8xf32>
+  %c_result = xevm.mma_mx %loaded_a, %loaded_b_casted, %scale_a, %scale_b, %loaded_c_casted { shape=<m=8, n=16, k=64>,
+    types=<d=f32, a=e2m1, b=e2m1, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<2xi8>, vector<2xi8>, vector<8xf32>) -> vector<8xf32>
+  return %c_result : vector<8xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @truncf
+func.func @truncf() -> i8 {
+  %0 = arith.constant 1.0 : bf16
+  // CHECK: %[[VAR1:.*]] = arith.bitcast %{{.+}} : bf16 to i16
+  %1 = arith.bitcast %0 : bf16 to i16
+  // CHECK: xevm.truncf %[[VAR1]] {types = <src_type = bf16, dst_type = e2m1>} : (i16) -> i8
+  %2 = xevm.truncf %1 { types=<src_type=bf16, dst_type=e2m1> } : (i16) -> i8
+  return %2 : i8
+}
+
 // -----
 // CHECK-LABEL: func.func @memfence()
 func.func @memfence() {

>From 1e85083ab31b798e0d9a8e17d263c5bd03e932e6 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 10 Mar 2026 20:38:04 +0000
Subject: [PATCH 2/3] Extend truncf to small vector operands.

---
 mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td | 26 +++++++++++++--------
 mlir/test/Dialect/LLVMIR/xevm.mlir          | 21 ++++++++++++-----
 2 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 7d16827405061..9586cad5a35d1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -596,19 +596,25 @@ def XeVM_MMAOp
   let hasVerifier = 1;
 }
 
-def XeVM_ConversionTypesAttr
-    : XeVM_Attr<"ConversionTypes", "conversion_types"> {
-  let description = [{
-    This attribute is used to specify source and destination types for conversion operations.
-  }];
-  let parameters = (ins "xevm::ElemType":$src_type, "xevm::ElemType":$dst_type);
-  let assemblyFormat = "`<` struct(params) `>`";
+def XeVM_TruncfDstElemTypes
+    : I32EnumAttr<"TruncfDstElemTypes",
+                  "Destination element type for xevm.truncf",
+                  [XeVM_ET_BF8, XeVM_ET_F8, XeVM_ET_E2M1]> {
+  let cppNamespace = "::mlir::xevm";
+}
+
+def XeVM_TruncfDstElemTypeAttr : XeVM_Attr<"TruncfDstElemType", "dst_etype"> {
+  let parameters = (ins "xevm::TruncfDstElemTypes":$etype);
+  let assemblyFormat = "`etype` `=` $etype";
 }
 
 def XeVM_TruncfOp
     : XeVM_Op<"truncf">,
-      Results<(outs AnyTypeOf<[I8]>:$res)>,
-      Arguments<(ins AnyTypeOf<[I16]>:$arg, XeVM_ConversionTypesAttr:$types)> {
+      Results<(outs AnyTypeOf<[VectorOfRankAndType<[1], [I8, I<4>]>, I8,
+                               I<4>]>:$dst)>,
+      Arguments<(ins AnyTypeOf<[VectorOfRankAndType<[1], [F16, BF16]>, F16,
+                                BF16]>:$arg,
+          XeVM_TruncfDstElemTypeAttr:$dst_etype)> {
   let summary = "Floating point truncation from f16/bf16 to f8/bf8/f4";
   let description = [{
     The `xevm.truncf` operation truncates a floating point value from
@@ -616,7 +622,7 @@ def XeVM_TruncfOp
   }];
 
   let assemblyFormat = [{
-    $arg ` ` `{` `types` `=` $types `}` attr-dict `:` functional-type(operands, results)
+    $arg ` ` `{` $dst_etype `}` attr-dict `:` functional-type(operands, results)
   }];
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir
index 7f1eae2580d39..15f93e6f697cc 100644
--- a/mlir/test/Dialect/LLVMIR/xevm.mlir
+++ b/mlir/test/Dialect/LLVMIR/xevm.mlir
@@ -105,16 +105,25 @@ func.func @mma_mx(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %lo
 }
 
 // -----
-// CHECK-LABEL: func.func @truncf
-func.func @truncf() -> i8 {
+// CHECK-LABEL: func.func @truncf_scalar
+func.func @truncf_scalar() -> i8 {
+  // CHECK: %[[VAR0:.*]] = arith.constant 1.0
   %0 = arith.constant 1.0 : bf16
-  // CHECK: %[[VAR1:.*]] = arith.bitcast %{{.+}} : bf16 to i16
-  %1 = arith.bitcast %0 : bf16 to i16
-  // CHECK: xevm.truncf %[[VAR1]] {types = <src_type = bf16, dst_type = e2m1>} : (i16) -> i8
-  %2 = xevm.truncf %1 { types=<src_type=bf16, dst_type=e2m1> } : (i16) -> i8
+  // CHECK: xevm.truncf %[[VAR0]] {etype = f8} : (bf16) -> i8
+  %2 = xevm.truncf %0 { etype=f8 } : (bf16) -> i8
   return %2 : i8
 }
 
+// -----
+// CHECK-LABEL: func.func @truncf_vector
+func.func @truncf_vector() -> vector<8xi4> {
+  // CHECK: %[[VAR0:.*]] = arith.constant
+  %0 = arith.constant dense<1.0> : vector<8xbf16>
+  // CHECK: xevm.truncf %[[VAR0]] {etype = e2m1} : (vector<8xbf16>) -> vector<8xi4>
+  %2 = xevm.truncf %0 { etype=e2m1 } : (vector<8xbf16>) -> vector<8xi4>
+  return %2 : vector<8xi4>
+}
+
 // -----
 // CHECK-LABEL: func.func @memfence()
 func.func @memfence() {

>From 467dabfaa085bb8d42967d28781a6bd95ea2a0ec Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 11 Mar 2026 15:51:09 +0000
Subject: [PATCH 3/3] TruncfOp: Add verifier.

---
 mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td |  8 ++++---
 mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp  | 16 +++++++++++++
 mlir/test/Dialect/LLVMIR/invalid.mlir       | 25 +++++++++++++++++++++
 3 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 9586cad5a35d1..73fcf7d614638 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -613,8 +613,8 @@ def XeVM_TruncfOp
       Results<(outs AnyTypeOf<[VectorOfRankAndType<[1], [I8, I<4>]>, I8,
                                I<4>]>:$dst)>,
       Arguments<(ins AnyTypeOf<[VectorOfRankAndType<[1], [F16, BF16]>, F16,
-                                BF16]>:$arg,
-          XeVM_TruncfDstElemTypeAttr:$dst_etype)> {
+                                BF16]>:$src,
+          XeVM_TruncfDstElemTypeAttr:$etype)> {
   let summary = "Floating point truncation from f16/bf16 to f8/bf8/f4";
   let description = [{
     The `xevm.truncf` operation truncates a floating point value from
@@ -622,8 +622,10 @@ def XeVM_TruncfOp
   }];
 
   let assemblyFormat = [{
-    $arg ` ` `{` $dst_etype `}` attr-dict `:` functional-type(operands, results)
+    $src ` ` `{` $etype `}` attr-dict `:` functional-type(operands, results)
   }];
+
+  let hasVerifier = 1;
 }
 
 def XeVM_MMAMxOp
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index f202f0b82e935..f40290807de9d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -357,6 +357,22 @@ LogicalResult MMAMxOp::verify() {
   return success();
 }
 
+LogicalResult TruncfOp::verify() {
+  // Verify src/dst are both vectors or both scalars, with equal element
+  // counts for vectors. dyn_cast yields a null VectorType for scalars.
+  auto srcVecTy = dyn_cast<VectorType>(getSrc().getType());
+  auto dstVecTy = dyn_cast<VectorType>(getDst().getType());
+  // Compare vector-ness so a mismatch in either direction is rejected.
+  if (static_cast<bool>(srcVecTy) != static_cast<bool>(dstVecTy))
+    return emitOpError("both src and dst should be vector types or both should "
+                       "be scalar types");
+  // Past this point dstVecTy is non-null whenever srcVecTy is.
+  if (srcVecTy && srcVecTy.getNumElements() != dstVecTy.getNumElements())
+    return emitOpError(
+        "src and dst vector types should have the same number of elements");
+  return success();
+}
+
 LogicalResult
 XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
                        StringRef triple, StringRef chip, DictionaryAttr flags,
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 5068ddc42e1e5..47aa96a6991d9 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -2019,6 +2019,31 @@ llvm.func @invalid_xevm_matrix_3(%a: !llvm.ptr<1>, %base_width_a: i32, %base_hei
 
 // -----
 
+llvm.func @invalid_xevm_truncf_1(%arg0: vector<8xf16>) {
+  // expected-error@+1 {{op both src and dst should be vector types or both}}
+  %0 = xevm.truncf %arg0 { etype = bf8 } : (vector<8xf16>) -> i8
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_truncf_1(%arg0: vector<8xf16>) {
+  // expected-error@+1 {{op src and dst vector types should have the same number of elements}}
+  %0 = xevm.truncf %arg0 { etype = bf8 } : (vector<8xf16>) -> vector<4xi8>
+  llvm.return
+}
+
+// -----
+
+llvm.func @invalid_xevm_mma_mx(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>, %scale_a: vector<2xi8>, %scale_b: vector<2xi8>) -> vector<8xf32> {
+  // expected-error@+1 {{op type of C operand must match result type}}
+  %c_result = xevm.mma_mx %loaded_a, %loaded_b_casted, %scale_a, %scale_b, %loaded_c_casted { shape=<m=8, n=16, k=64>,
+    types=<d=f32, a=e2m1, b=e2m1, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<2xi8>, vector<2xi8>, vector<4xf32>) -> vector<8xf32>
+  llvm.return %c_result : vector<8xf32>
+}
+
+// -----
+
 llvm.func external @resolve_foo() -> !llvm.ptr attributes {dso_local}
 // expected-error@+1 {{'llvm.mlir.ifunc' op resolver must be a definition}}
 llvm.mlir.ifunc external @foo : !llvm.func<void (ptr, i32)>, !llvm.ptr @resolve_foo {dso_local}



More information about the Mlir-commits mailing list