[Mlir-commits] [mlir] 58cb887 - [mlir][rocdl] Add xdlops intrinsics to rocdl dialect
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 11 07:13:22 PDT 2020
Author: rtayl
Date: 2020-05-11T10:08:58-04:00
New Revision: 58cb88733f01f29508aeaa4ba749016432bc544b
URL: https://github.com/llvm/llvm-project/commit/58cb88733f01f29508aeaa4ba749016432bc544b
DIFF: https://github.com/llvm/llvm-project/commit/58cb88733f01f29508aeaa4ba749016432bc544b.diff
LOG: [mlir][rocdl] Add xdlops intrinsics to rocdl dialect
Summary: This adds xdlops (mfma) intrinsics to the rocdl dialect and also tests their translation to LLVM IR.
Reviewers: ftynse
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, llvm-commits
Tags: #llvm #mlir
Differential Revision: https://reviews.llvm.org/D79642
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/test/Dialect/LLVMIR/rocdl.mlir
mlir/test/Target/rocdl.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 58ee259c5c49..7f600243f5d7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -102,4 +102,43 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
let assemblyFormat = "attr-dict";
}
+//===---------------------------------------------------------------------===//
+// Xdlops intrinsics
+//===---------------------------------------------------------------------===//
+
+// Base class for the xdlops (MFMA) intrinsic ops. The LLVM intrinsic enum is
+// derived from the op mnemonic by prefixing "amdgcn_" and replacing every '.'
+// with '_' (e.g. "mfma.f32.32x32x1f32" -> amdgcn_mfma_f32_32x32x1f32). All
+// operands form one variadic list of LLVM-typed values; the custom assembly
+// prints them followed by the functional type of the operands and `$res`.
+class ROCDL_Mfma_IntrOp<string mnemonic, list<OpTrait> traits = []> :
+ LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
+ "amdgcn_" # !subst(".","_", mnemonic),
+ [], [], traits, 1>,
+ Arguments<(ins Variadic<LLVM_Type>:$args)> {
+ let assemblyFormat =
+ "$args attr-dict `:` functional-type($args, $res)";
+}
+
+def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32">;
+def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32">;
+def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32">;
+def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">;
+def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16">;
+def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16">;
+def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16">;
+def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16">;
+def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16">;
+def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16">;
+def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16">;
+def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16">;
+def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16">;
+def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32">;
+def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16">;
+def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8">;
+def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8">;
+def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8">;
+def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8">;
+def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">;
+
#endif // ROCDLIR_OPS
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 2a6f10008f9e..4f634e187af1 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -34,3 +34,115 @@ func @rocdl.barrier() {
rocdl.barrier
llvm.return
}
+
+// Checks that every xdlops (MFMA) op parses and is printed back with its full
+// functional type (operand types and result type).
+func @rocdl.xdlops(%arg0 : !llvm.float, %arg1 : !llvm.float,
+ %arg2 : !llvm<"<32 x float>">, %arg3 : !llvm.i32,
+ %arg4 : !llvm<"<16 x float>">, %arg5 : !llvm<"<4 x float>">,
+ %arg6 : !llvm<"<4 x half>">, %arg7 : !llvm<"<32 x i32>">,
+ %arg8 : !llvm<"<16 x i32>">, %arg9 : !llvm<"<4 x i32>">,
+ %arg10 : !llvm<"<2 x i16>">) -> !llvm<"<32 x float>"> {
+ // CHECK-LABEL: rocdl.xdlops
+ // CHECK: rocdl.mfma.f32.32x32x1f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+ %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<32 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+ // CHECK: rocdl.mfma.f32.16x16x1f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+ %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: rocdl.mfma.f32.16x16x4f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+ %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: rocdl.mfma.f32.4x4x1f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+ %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: rocdl.mfma.f32.32x32x2f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+ %r4 = rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: rocdl.mfma.f32.32x32x4f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+ %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<32 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+ // CHECK: rocdl.mfma.f32.16x16x4f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+ %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: rocdl.mfma.f32.4x4x4f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+ %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: rocdl.mfma.f32.32x32x8f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+ %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: rocdl.mfma.f32.16x16x16f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+ %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: rocdl.mfma.i32.32x32x4i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<32 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x i32>">
+ %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<32 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x i32>">
+
+ // CHECK: rocdl.mfma.i32.16x16x4i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+ %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+ // CHECK: rocdl.mfma.i32.4x4x4i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+ %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+ // CHECK: rocdl.mfma.i32.32x32x8i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+ %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+ // CHECK: rocdl.mfma.i32.16x16x16i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+ %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+ // CHECK: rocdl.mfma.f32.32x32x2bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+ %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<32 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+ // CHECK: rocdl.mfma.f32.16x16x2bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+ %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: rocdl.mfma.f32.4x4x2bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+ %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: rocdl.mfma.f32.32x32x4bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+ %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: rocdl.mfma.f32.16x16x8bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+ %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ llvm.return %r0 : !llvm<"<32 x float>">
+}
diff --git a/mlir/test/Target/rocdl.mlir b/mlir/test/Target/rocdl.mlir
index 773c4e3928ee..676865fe1a93 100644
--- a/mlir/test/Target/rocdl.mlir
+++ b/mlir/test/Target/rocdl.mlir
@@ -41,3 +41,115 @@ llvm.func @rocdl.barrier() {
rocdl.barrier
llvm.return
}
+
+// Each xdlops (MFMA) op must translate to a call of the matching
+// llvm.amdgcn.mfma.* intrinsic with the same operand and result types.
+llvm.func @rocdl.xdlops(%arg0 : !llvm.float, %arg1 : !llvm.float,
+ %arg2 : !llvm<"<32 x float>">, %arg3 : !llvm.i32,
+ %arg4 : !llvm<"<16 x float>">, %arg5 : !llvm<"<4 x float>">,
+ %arg6 : !llvm<"<4 x half>">, %arg7 : !llvm<"<32 x i32>">,
+ %arg8 : !llvm<"<16 x i32>">, %arg9 : !llvm<"<4 x i32>">,
+ %arg10 : !llvm<"<2 x i16>">) -> !llvm<"<32 x float>"> {
+ // CHECK-LABEL: rocdl.xdlops
+ // CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %{{.*}}, float %{{.*}}, <32 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<32 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x1f32(float %{{.*}}, float %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x4f32(float %{{.*}}, float %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x1f32(float %{{.*}}, float %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x2f32(float %{{.*}}, float %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r4 = rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <32 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<32 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: call <32 x i32> @llvm.amdgcn.mfma.i32.32x32x4i8(i32 %{{.*}}, i32 %{{.*}}, <32 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<32 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x i32>">
+
+ // CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.16x16x4i8(i32 %{{.*}}, i32 %{{.*}}, <16 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+ // CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.4x4x4i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+ // CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x8i8(i32 %{{.*}}, i32 %{{.*}}, <16 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+ // CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x16i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+ (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+ // CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <32 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<32 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x4bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x8bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+ (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+ !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+ llvm.return %r0 : !llvm<"<32 x float>">
+}
More information about the Mlir-commits
mailing list