[Mlir-commits] [mlir] [MLIR][XeVM] Add truncf and mma_mx op. (PR #180055)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 5 14:50:13 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Sang Ik Lee (silee2)
<details>
<summary>Changes</summary>
The truncf op converts 16-bit floats to 8-bit or 4-bit floats. The mma_mx op performs a cooperative matrix multiply-accumulate on 8-bit or 4-bit float types with 8-bit scale values.
---
Full diff: https://github.com/llvm/llvm-project/pull/180055.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td (+80-5)
- (modified) mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp (+8)
- (modified) mlir/test/Dialect/LLVMIR/xevm.mlir (+23)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 388efaaa25117..7d16827405061 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -519,11 +519,15 @@ def XeVM_ET_U4 : I32EnumAttrCase<"U4", 13, "u4">;
def XeVM_ET_TF32 : I32EnumAttrCase<"TF32", 14, "tf32">;
def XeVM_ET_F32 : I32EnumAttrCase<"F32", 15, "f32">;
def XeVM_ET_S32 : I32EnumAttrCase<"S32", 16, "s32">;
-
-def XeVM_ElemTypeAttr : I32EnumAttr<"ElemType", "XeVM element type",
- [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8,
- XeVM_ET_U8, XeVM_ET_S4, XeVM_ET_U4,
- XeVM_ET_TF32, XeVM_ET_F32, XeVM_ET_S32]> {
+def XeVM_ET_E2M1 : I32EnumAttrCase<"E2M1", 17, "e2m1">;
+def XeVM_ET_BF8 : I32EnumAttrCase<"BF8", 18, "bf8">;
+def XeVM_ET_F8 : I32EnumAttrCase<"F8", 19, "f8">;
+
+def XeVM_ElemTypeAttr
+ : I32EnumAttr<"ElemType", "XeVM element type",
+ [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8, XeVM_ET_U8,
+ XeVM_ET_S4, XeVM_ET_U4, XeVM_ET_TF32, XeVM_ET_F32,
+ XeVM_ET_S32, XeVM_ET_E2M1, XeVM_ET_BF8, XeVM_ET_F8]> {
let cppNamespace = "::mlir::xevm";
}
@@ -592,6 +596,77 @@ def XeVM_MMAOp
let hasVerifier = 1;
}
+def XeVM_ConversionTypesAttr
+ : XeVM_Attr<"ConversionTypes", "conversion_types"> {
+ let description = [{
+ This attribute is used to specify source and destination types for conversion operations.
+ }];
+ let parameters = (ins "xevm::ElemType":$src_type, "xevm::ElemType":$dst_type);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeVM_TruncfOp
+ : XeVM_Op<"truncf">,
+ Results<(outs AnyTypeOf<[I8]>:$res)>,
+ Arguments<(ins AnyTypeOf<[I16]>:$arg, XeVM_ConversionTypesAttr:$types)> {
+ let summary = "Floating point truncation from f16/bf16 to f8/bf8/f4";
+ let description = [{
+ The `xevm.truncf` operation truncates a floating point value from
+ f16/bf16 to f8/bf8/f4 format.
+ }];
+
+ let assemblyFormat = [{
+ $arg ` ` `{` `types` `=` $types `}` attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+def XeVM_MMAMxOp
+ : XeVM_Op<"mma_mx">,
+ Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>,
+ Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
+ FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
+ FixedVectorOfRankAndType<[1], [I8]>:$scale_a,
+ FixedVectorOfRankAndType<[1], [I8]>:$scale_b,
+ Optional<FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>>:$c,
+ XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> {
+
+ let summary = "Subgroup matrix multiply-add with MxN shape and MX scaling";
+
+ let description = [{
+ The `xevm.mma_mx` is similar to `xevm.mma` has scale operands for A and B matrices.
+ It is a cooperative operation where all threads/lanes in a subgroup participates
+ and carries out matrix multiplication plus accumulation:
+
+ D = C + A x B
+
+ where the A, B, C input matrices and the result D have shapes:
+ - D : MxN
+ - C : MxN
+ - A : MxK
+ - B : KxN
+
+ Parameters:
+ * `a` - vector of matrix A elements.
+ * `b` - vector of matrix B elements.
+ * `scale_a` - vector of scaling factors for matrix A.
+ * `scale_b` - vector of scaling factors for matrix B.
+ * `c` - (optional) vector of matrix C elements.
+ * `shape` - the shape of the matrices, specified as `M`, `N`, and `K` values.
+ * `types` - the data types of the matrices, specified as `D`, `A`, `B`, and optionally `C`.
+
+ }];
+
+ let assemblyFormat = [{
+ $a `,` $b `,` $scale_a `,` $scale_b (`,` $c^)? ` `
+ `{`
+ `shape` `=` $shape `,`
+ `types` `=` $types
+ `}` attr-dict `:` functional-type(operands, results)
+ }];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// XeVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 04e8836c00359..f202f0b82e935 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -349,6 +349,14 @@ LogicalResult MMAOp::verify() {
return success();
}
+LogicalResult MMAMxOp::verify() {
+ if (getC()) {
+ if (getResult().getType() != getC().getType())
+ return emitOpError("type of C operand must match result type");
+ }
+ return success();
+}
+
LogicalResult
XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
StringRef triple, StringRef chip, DictionaryAttr flags,
diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir
index 66fb2949a270f..7f1eae2580d39 100644
--- a/mlir/test/Dialect/LLVMIR/xevm.mlir
+++ b/mlir/test/Dialect/LLVMIR/xevm.mlir
@@ -92,6 +92,29 @@ func.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loade
return %c_result : vector<8xf32>
}
+// -----
+// CHECK-LABEL: func.func @mma_mx(
+// CHECK-SAME: %[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>, %[[ARG3:.*]]: vector<2xi8>, %[[ARG4:.*]]: vector<2xi8>)
+func.func @mma_mx(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>, %scale_a: vector<2xi8>, %scale_b: vector<2xi8>) -> vector<8xf32> {
+ // CHECK: %[[VAR0:.*]] = xevm.mma_mx %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG0]]
+ // CHECK-SAME: {shape = <m = 8, n = 16, k = 64>, types = <d = f32, a = e2m1, b = e2m1, c = f32>}
+ // CHECK-SAME:: (vector<8xi16>, vector<8xi32>, vector<2xi8>, vector<2xi8>, vector<8xf32>) -> vector<8xf32>
+ %c_result = xevm.mma_mx %loaded_a, %loaded_b_casted, %scale_a, %scale_b, %loaded_c_casted { shape=<m=8, n=16, k=64>,
+ types=<d=f32, a=e2m1, b=e2m1, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<2xi8>, vector<2xi8>, vector<8xf32>) -> vector<8xf32>
+ return %c_result : vector<8xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @truncf
+func.func @truncf() -> i8 {
+ %0 = arith.constant 1.0 : bf16
+ // CHECK: %[[VAR1:.*]] = arith.bitcast %{{.+}} : bf16 to i16
+ %1 = arith.bitcast %0 : bf16 to i16
+ // CHECK: xevm.truncf %[[VAR1]] {types = <src_type = bf16, dst_type = e2m1>} : (i16) -> i8
+ %2 = xevm.truncf %1 { types=<src_type=bf16, dst_type=e2m1> } : (i16) -> i8
+ return %2 : i8
+}
+
// -----
// CHECK-LABEL: func.func @memfence()
func.func @memfence() {
``````````
</details>
https://github.com/llvm/llvm-project/pull/180055
More information about the Mlir-commits
mailing list