[Mlir-commits] [mlir] 58cb887 - [mlir][rocdl] Add xdlops intrinsics to rocdl dialect

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 11 07:13:22 PDT 2020


Author: rtayl
Date: 2020-05-11T10:08:58-04:00
New Revision: 58cb88733f01f29508aeaa4ba749016432bc544b

URL: https://github.com/llvm/llvm-project/commit/58cb88733f01f29508aeaa4ba749016432bc544b
DIFF: https://github.com/llvm/llvm-project/commit/58cb88733f01f29508aeaa4ba749016432bc544b.diff

LOG: [mlir][rocdl] Add xdlops intrinsics to rocdl dialect

Summary: This adds the xdlops (MFMA) intrinsics to the ROCDL dialect and also tests their translation to LLVM IR.

Reviewers: ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, llvm-commits

Tags: #llvm #mlir

Differential Revision: https://reviews.llvm.org/D79642

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/test/Dialect/LLVMIR/rocdl.mlir
    mlir/test/Target/rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 58ee259c5c49..7f600243f5d7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -102,4 +102,38 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
   let assemblyFormat = "attr-dict";
 }
 
+//===---------------------------------------------------------------------===//
+// Xdlops intrinsics
+
+class ROCDL_Mfma_IntrOp<string mnemonic, list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
+                  "amdgcn_" # !subst(".","_", mnemonic),
+                  [], [], traits, 1>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+  let assemblyFormat =
+    "$args attr-dict `:` functional-type($args, $res)";
+}
+
+def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32">;
+def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32">;
+def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32">;
+def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">;
+def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16">;
+def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16">;
+def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16">;
+def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16">;
+def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16">;
+def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16">;
+def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16">;
+def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16">;
+def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16">;
+def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32">;
+def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16">;
+def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8">;
+def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8">;
+def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8">;
+def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8">;
+def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">;
+
+
 #endif // ROCDLIR_OPS

diff  --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 2a6f10008f9e..4f634e187af1 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -34,3 +34,113 @@ func @rocdl.barrier() {
   rocdl.barrier
   llvm.return
 }
+
+func @rocdl.xdlops(%arg0 : !llvm.float, %arg1 : !llvm.float,
+                   %arg2 : !llvm<"<32 x float>">, %arg3 : !llvm.i32,
+                   %arg4 : !llvm<"<16 x float>">, %arg5 : !llvm<"<4 x float>">,
+                   %arg6 : !llvm<"<4 x half>">, %arg7 : !llvm<"<32 x i32>">,
+                   %arg8 : !llvm<"<16 x i32>">, %arg9 : !llvm<"<4 x i32>">,
+                   %arg10 : !llvm<"<2 x i16>">) -> !llvm<"<32 x float>"> {
+  // CHECK-LABEL: rocdl.xdlops
+  // CHECK: rocdl.mfma.f32.32x32x1f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+  %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<32 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x1f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x4f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.f32.4x4x1f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.f32.32x32x2f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r4= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.32x32x4f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+  %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<32 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x4f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.4x4x4f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.f32.32x32x8f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x16f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.i32.32x32x4i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<32 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x i32>">
+  %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<32 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x i32>">
+
+  // CHECK: rocdl.mfma.i32.16x16x4i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+  %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+  // CHECK: rocdl.mfma.i32.4x4x4i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+  %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+  // CHECK: rocdl.mfma.i32.32x32x8i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+  %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+  // CHECK: rocdl.mfma.i32.16x16x16i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+  %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+  // CHECK: rocdl.mfma.f32.32x32x2bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+  %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<32 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x2bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.4x4x2bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.f32.32x32x4bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x8bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  llvm.return %r0 : !llvm<"<32 x float>">
+}

diff  --git a/mlir/test/Target/rocdl.mlir b/mlir/test/Target/rocdl.mlir
index 773c4e3928ee..676865fe1a93 100644
--- a/mlir/test/Target/rocdl.mlir
+++ b/mlir/test/Target/rocdl.mlir
@@ -41,3 +41,113 @@ llvm.func @rocdl.barrier() {
   rocdl.barrier
   llvm.return
 }
+
+llvm.func @rocdl.xdlops(%arg0 : !llvm.float, %arg1 : !llvm.float,
+                   %arg2 : !llvm<"<32 x float>">, %arg3 : !llvm.i32,
+                   %arg4 : !llvm<"<16 x float>">, %arg5 : !llvm<"<4 x float>">,
+                   %arg6 : !llvm<"<4 x half>">, %arg7 : !llvm<"<32 x i32>">,
+                   %arg8 : !llvm<"<16 x i32>">, %arg9 : !llvm<"<4 x i32>">,
+                   %arg10 : !llvm<"<2 x i16>">) -> !llvm<"<32 x float>"> {
+  // CHECK-LABEL: rocdl.xdlops
+  // CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %{{.*}}, float %{{.*}}, <32 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<32 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x1f32(float %{{.*}}, float %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x4f32(float %{{.*}}, float %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x1f32(float %{{.*}}, float %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x2f32(float %{{.*}}, float %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r4= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <32 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<32 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <32 x i32> @llvm.amdgcn.mfma.i32.32x32x4i8(i32 %{{.*}}, i32 %{{.*}}, <32 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<32 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x i32>">
+
+  // CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.16x16x4i8(i32 %{{.*}}, i32 %{{.*}}, <16 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+  // CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.4x4x4i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+  // CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x8i8(i32 %{{.*}}, i32 %{{.*}}, <16 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+  // CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x16i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+                            (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+  // CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <32 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<32 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x4bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x8bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+                            (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+                            !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  llvm.return %r0 : !llvm<"<32 x float>">
+}


        


More information about the Mlir-commits mailing list