[llvm] [clang] [AMDGPU] Add GFX12 WMMA and SWMMAC instructions (PR #77795)

via cfe-commits cfe-commits at lists.llvm.org
Thu Jan 11 08:28:03 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-clang-codegen

Author: Mirko Brkušanin (mbrkusanin)

<details>
<summary>Changes</summary>



---

Patch is 1002.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77795.diff


62 Files Affected:

- (modified) clang/include/clang/Basic/BuiltinsAMDGPU.def (+61) 
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+159-13) 
- (added) clang/test/CodeGenOpenCL/builtins-amdgcn-gfx12-wmma-w32.cl (+156) 
- (added) clang/test/CodeGenOpenCL/builtins-amdgcn-gfx12-wmma-w64.cl (+155) 
- (added) clang/test/CodeGenOpenCL/builtins-amdgcn-swmmac-w32.cl (+135) 
- (added) clang/test/CodeGenOpenCL/builtins-amdgcn-swmmac-w64.cl (+134) 
- (modified) clang/test/CodeGenOpenCL/builtins-amdgcn-wmma-w32.cl (+8-9) 
- (modified) clang/test/CodeGenOpenCL/builtins-amdgcn-wmma-w64.cl (+8-10) 
- (added) cross-project-tests/amdgpu/builtins-amdgcn-gfx12-wmma-w32.cl (+107) 
- (added) cross-project-tests/amdgpu/builtins-amdgcn-gfx12-wmma-w64.cl (+104) 
- (added) cross-project-tests/amdgpu/builtins-amdgcn-swmmac-w32.cl (+110) 
- (added) cross-project-tests/amdgpu/builtins-amdgcn-swmmac-w64.cl (+109) 
- (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+93-28) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUGISel.td (+27-3) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+331-1) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h (+11-1) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp (+214-1) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h (+14-1) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp (+23) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+16) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPUSearchableTables.td (+16) 
- (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+141-5) 
- (modified) llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp (+8) 
- (modified) llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp (+5-3) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+37) 
- (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h (+4) 
- (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+3) 
- (modified) llvm/lib/Target/AMDGPU/SIFoldOperands.cpp (+1) 
- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+31-1) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrFormats.td (+5) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.h (+8) 
- (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+12-1) 
- (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+5) 
- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+498-10) 
- (modified) llvm/lib/Target/AMDGPU/VOPInstructions.td (+3) 
- (modified) llvm/test/Analysis/UniformityAnalysis/AMDGPU/intrinsics.ll (+118-18) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w32-f16-f32-matrix-modifiers.ll (+504) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w32-imm.ll (+519) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w32-iu-modifiers.ll (+309) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w32-swmmac-index_key.ll (+321) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w32.ll (+370) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w64-f16-f32-matrix-modifiers.ll (+459) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w64-imm.ll (+430) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w64-iu-modifiers.ll (+274) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w64-swmmac-index_key.ll (+472) 
- (added) llvm/test/CodeGen/AMDGPU/GlobalISel/wmma-gfx12-w64.ll (+333) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w32-f16-f32-matrix-modifiers.ll (+499) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w32-imm.ll (+431) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w32-iu-modifiers.ll (+309) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w32-swmmac-index_key.ll (+321) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w32.ll (+370) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w64-f16-f32-matrix-modifiers.ll (+456) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w64-imm.ll (+373) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w64-iu-modifiers.ll (+274) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w64-swmmac-index_key.ll (+472) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-gfx12-w64.ll (+333) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-hazards-gfx12-w32.mir (+358) 
- (added) llvm/test/CodeGen/AMDGPU/wmma-hazards-gfx12-w64.mir (+359) 
- (added) llvm/test/MC/AMDGPU/gfx12_asm_wmma_w32.s (+1529) 
- (added) llvm/test/MC/AMDGPU/gfx12_asm_wmma_w64.s (+1529) 
- (added) llvm/test/MC/Disassembler/AMDGPU/gfx12_dasm_wmma_w32.txt (+1628) 
- (added) llvm/test/MC/Disassembler/AMDGPU/gfx12_dasm_wmma_w64.txt (+1628) 


``````````diff
diff --git a/clang/include/clang/Basic/BuiltinsAMDGPU.def b/clang/include/clang/Basic/BuiltinsAMDGPU.def
index e562ef04a30194..026c0af65c92bb 100644
--- a/clang/include/clang/Basic/BuiltinsAMDGPU.def
+++ b/clang/include/clang/Basic/BuiltinsAMDGPU.def
@@ -423,6 +423,67 @@ TARGET_BUILTIN(__builtin_amdgcn_s_wakeup_barrier, "vi", "n", "gfx12-insts")
 TARGET_BUILTIN(__builtin_amdgcn_s_barrier_leave, "b", "n", "gfx12-insts")
 TARGET_BUILTIN(__builtin_amdgcn_s_get_barrier_state, "Uii", "n", "gfx12-insts")
 
+//===----------------------------------------------------------------------===//
+// WMMA builtins.
+// Postfix w32 indicates the builtin requires wavefront size of 32.
+// Postfix w64 indicates the builtin requires wavefront size of 64.
+//
+// Some of these are very similar to their GFX11 counterparts, but they don't
+// require replication of the A,B matrices, so they use fewer vector elements.
+// Therefore, we add an "_gfx12" suffix to distinguish them from the existing
+// builtins.
+//===----------------------------------------------------------------------===//
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12, "V8fV8hV8hV8f", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12, "V8fV8sV8sV8f", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12, "V8hV8hV8hV8h", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12, "V8sV8sV8sV8s", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12, "V8iIbV2iIbV2iV8iIb", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12, "V8iIbiIbiV8iIb", "nc", "gfx12-insts,wavefrontsize32")
+// These are gfx12-only, but for consistency with the other WMMA variants we're
+// keeping the "_gfx12" suffix.
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12, "V8iIbV2iIbV2iV8iIb", "nc", "gfx12-insts,wavefrontsize32")
+
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12, "V4fV4hV4hV4f", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12, "V4fV4sV4sV4f", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12, "V4hV4hV4hV4h", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12, "V4sV4sV4sV4s", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
+// These are gfx12-only, but for consistency with the other WMMA variants we're
+// keeping the "_gfx12" suffix.
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
+
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32, "V8fV8hV16hV8fs", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32, "V8fV8sV16sV8fs", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32, "V8hV8hV16hV8hs", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32, "V8sV8sV16sV8ss", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32, "V8iIbV2iIbV4iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32, "V8iIbiIbV2iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32, "V8iIbV2iIbV4iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
+
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64, "V4fV4hV8hV4fs", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64, "V4fV4sV8sV4fs", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64, "V4hV4hV8hV4hs", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64, "V4sV4sV8sV4ss", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64, "V4iIbiIbV2iV4isIb", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64, "V4iIbiIbiV4isIb", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64, "V4iIbiIbV2iV4isIb", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
+TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
 
 #undef BUILTIN
 #undef TARGET_BUILTIN
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 998fcc3af58175..c588b32f698bf5 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18240,65 +18240,211 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
   case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
   case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
   case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
-  case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
+  case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
+  case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
+  case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {
 
     // These operations perform a matrix multiplication and accumulation of
     // the form:
     //             D = A * B + C
-    // The return type always matches the type of matrix C.
-    unsigned ArgForMatchingRetType;
+    // We need to specify one type for matrices AB and one for matrices CD.
+    SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
+    // Some intrinsics expect "false" as an extra bool argument.
+    bool AppendExtraBoolArg = false;
     unsigned BuiltinWMMAOp;
 
     switch (BuiltinID) {
     case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
     case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
-      ArgForMatchingRetType = 2;
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
+      ArgsForMatchingMatrixTypes = {0, 2};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
       break;
     case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
     case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
-      ArgForMatchingRetType = 2;
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
+      ArgsForMatchingMatrixTypes = {0, 2};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
       break;
+    case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
+      AppendExtraBoolArg = true;
+      LLVM_FALLTHROUGH;
     case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
     case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
-      ArgForMatchingRetType = 2;
+      ArgsForMatchingMatrixTypes = {0, 2};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
       break;
+    case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
+      AppendExtraBoolArg = true;
+      LLVM_FALLTHROUGH;
     case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
     case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
-      ArgForMatchingRetType = 2;
+      ArgsForMatchingMatrixTypes = {0, 2};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
       break;
     case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
     case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
-      ArgForMatchingRetType = 2;
+      ArgsForMatchingMatrixTypes = {0, 2};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
       break;
     case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
     case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
-      ArgForMatchingRetType = 2;
+      ArgsForMatchingMatrixTypes = {0, 2};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
       break;
     case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
     case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
-      ArgForMatchingRetType = 4;
+    case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
+      ArgsForMatchingMatrixTypes = {1, 4};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
       break;
     case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
     case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
-      ArgForMatchingRetType = 4;
+    case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
+      ArgsForMatchingMatrixTypes = {1, 4};
       BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
       break;
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
+      ArgsForMatchingMatrixTypes = {0, 2};
+      BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
+      ArgsForMatchingMatrixTypes = {0, 2};
+      BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
+      ArgsForMatchingMatrixTypes = {0, 2};
+      BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
+      ArgsForMatchingMatrixTypes = {0, 2};
+      BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
+    case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
+      ArgsForMatchingMatrixTypes = {1, 4};
+      BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
+      ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
+      ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
+      ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
+      ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
+      ArgsForMatchingMatrixTypes = {1, 3, 4, 5};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
+      ArgsForMatchingMatrixTypes = {1, 3, 4, 5};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
+      ArgsForMatchingMatrixTypes = {1, 3, 4, 5};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
+      ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
+      ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
+      ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
+      break;
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
+    case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
+      ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
+      BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
+      break;
     }
 
     SmallVector<Value *, 6> Args;
     for (int i = 0, e = E->getNumArgs(); i != e; ++i)
       Args.push_back(EmitScalarExpr(E->getArg(i)));
+    if (AppendExtraBoolArg)
+      Args.push_back(Builder.getFalse());
 
-    Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
-                                   {Args[ArgForMatchingRetType]->getType()});
+    SmallVector<llvm::Type *, 6> ArgTypes;
+    for (auto ArgIdx : ArgsForMatchingMatrixTypes)
+      ArgTypes.push_back(Args[ArgIdx]->getType());
 
+    Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
     return Builder.CreateCall(F, Args);
   }
 
diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx12-wmma-w32.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx12-wmma-w32.cl
new file mode 100644
index 00000000000000..a5d8bb34a7842d
--- /dev/null
+++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-gfx12-wmma-w32.cl
@@ -0,0 +1,156 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py
+// REQUIRES: amdgpu-registered-target
+// RUN: %clang_cc1 -triple amdgcn-unknown-unknown -target-cpu gfx1200 -target-feature +wavefrontsize32 -S -emit-llvm -o - %s | FileCheck %s --check-prefix=CHECK-GFX1200
+
+typedef int    v2i   __attribute__((ext_vector_type(2)));
+typedef float  v8f   __attribute__((ext_vector_type(8)));
+typedef half   v8h   __attribute__((ext_vector_type(8)));
+typedef short  v8s   __attribute__((ext_vector_type(8)));
+typedef int    v8i   __attribute__((ext_vector_type(8)));
+
+// Wave32
+
+//
+// amdgcn_wmma_f32_16x16x16_f16
+//
+
+// CHECK-GFX1200-LABEL: @test_amdgcn_wmma_f32_16x16x16_f16_w32(
+// CHECK-GFX1200-NEXT:  entry:
+// CHECK-GFX1200-NEXT:    [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f16.v8f32(<8 x half> [[A:%.*]], <8 x half> [[B:%.*]], <8 x float> [[C:%.*]])
+// CHECK-GFX1200-NEXT:    store <8 x float> [[TMP0]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4:![0-9]+]]
+// CHECK-GFX1200-NEXT:    ret void
+//
+void test_amdgcn_wmma_f32_16x16x16_f16_w32(global v8f* out, v8h a, v8h b, v8f c)
+{
+  *out = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a, b, c);
+}
+
+//
+// amdgcn_wmma_f...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/77795


More information about the cfe-commits mailing list