[llvm] [NFC][AMDGPU] Refactor the multiclass for WMMA_F8F6F4 instructions (PR #172245)

via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 14 19:53:01 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Shilei Tian (shiltian)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/172245.diff


1 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+36-15) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 2dfa905848a34..410e56d83331b 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -1814,21 +1814,42 @@ def F32_FP8BF8X128_SWMMAC_w32    : VOP3PWMMA_Profile<[v8f32, v8i32,  v16i32, v8f
 def F16_FP8BF8X128_SWMMAC_w32    : VOP3PWMMA_Profile<[v8f16, v8i32,  v16i32, v8f16], 1, 32, 0, 1, 1, 0, 0, 0, 1>;
 def I32_IU8X128_SWMMAC_w32       : VOP3PWMMA_Profile<[v8i32, v8i32,  v16i32, v8i32], 1, 32, 1, 0, 1, 0, 0, 0, 1>;
 
-multiclass WMMA_F8F6F4_Profiles<bit HasMatrixScale, bit Scale16, bit HasMatrixReuse> {
-  def _f8_f8_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f8_f6_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f8_f4_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v8i32,  v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f6_f8_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f6_f6_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f6_f4_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v8i32,  v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f4_f8_w32 : VOP3PWMMA_Profile<[v8f32, v8i32,  v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f4_f6_w32 : VOP3PWMMA_Profile<[v8f32, v8i32,  v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-  def _f4_f4_w32 : VOP3PWMMA_Profile<[v8f32, v8i32,  v8i32,  v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
-}
-
-defm F32_16X16X128_F8F6F4         : WMMA_F8F6F4_Profiles<0, 0, 0>;
-defm F32_16X16X128_F8F6F4_SCALE   : WMMA_F8F6F4_Profiles<1, 0, 1>;
-defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles<1, 1, 1>;
+// Helper class to compute the destination vector type of WMMA_F8F6F4 instructions from element type and M x N; bits = (M * N / 32) * element bits.
+class getWMMAF8F6F4DstVTy<ValueType DstEltTy, int M, int N> {
+  defvar Size = !mul(!div(!mul(M, N), 32), DstEltTy.Size);
+  // Match element type as well as size: keyed on Size alone, a 1024-bit f32 result became v64f16 and a 256-bit f16 one v8f32. Other combinations are a !cond error.
+  ValueType ret = !cond(!and(!eq(DstEltTy, f32), !eq(Size, 256))  : v8f32,
+                        !and(!eq(DstEltTy, f16), !eq(Size, 1024)) : v64f16);
+}
+
+// Helper class mapping an A/B operand format ("f8", "f6" or "f4") and a D1 x D2 tile to an i32 vector of (D1 * D2 / 32) * format bits (presumably one wave32 lane's share).
+class getWMMAF8F6F4ABVTy<string Fmt, int D1, int D2> {
+  defvar EltsPerLane = !div(!mul(D1, D2), 32);
+  defvar FmtBits = !cond(!eq(Fmt, "f8") : 8,
+                         !eq(Fmt, "f6") : 6,
+                         !eq(Fmt, "f4") : 4);
+  defvar LaneBits = !mul(EltsPerLane, FmtBits); // (D1 * D2 / 32) * format bits
+  ValueType ret = !cond(!eq(LaneBits, 256)  : v8i32,
+                        !eq(LaneBits, 384)  : v12i32,
+                        !eq(LaneBits, 512)  : v16i32,
+                        !eq(LaneBits, 1024) : v32i32);
+}
+
+multiclass WMMA_F8F6F4_Profiles<ValueType DstEltTy, int M, int N, int K,
+                                bit HasMatrixScale, bit Scale16, bit HasMatrixReuse> {
+  // Defines one VOP3PWMMA_Profile per (A format, B format) pair, suffixed _<A>_<B>_w32.
+  defvar DstTy = getWMMAF8F6F4DstVTy<DstEltTy, M, N>.ret;
+  foreach AFmt = ["f8", "f6", "f4"] in
+    foreach BFmt = ["f8", "f6", "f4"] in {
+      defvar ATy = getWMMAF8F6F4ABVTy<AFmt, M, K>.ret; // A tile: M x K
+      defvar BTy = getWMMAF8F6F4ABVTy<BFmt, K, N>.ret; // B tile: K x N
+      def _#AFmt#_#BFmt#_w32 : VOP3PWMMA_Profile<[DstTy, ATy, BTy, DstTy], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
+    }
+}
+// f32-result 16x16x128 F8F6F4 profile sets: no matrix scale/reuse; matrix scale + reuse; matrix scale + reuse with Scale16 set.
+defm F32_16X16X128_F8F6F4         : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128, /*HasMatrixScale=*/0, /*Scale16=*/0, /*HasMatrixReuse=*/0>;
+defm F32_16X16X128_F8F6F4_SCALE   : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128, /*HasMatrixScale=*/1, /*Scale16=*/0, /*HasMatrixReuse=*/1>;
+defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128, /*HasMatrixScale=*/1, /*Scale16=*/1, /*HasMatrixReuse=*/1>;
 
 class VOP_WMMA_LD_SCALE<ValueType vt, RegisterOperand RC> : VOP3P_Profile<VOPProfile<[untyped, vt, vt, untyped]>> {
   let HasMatrixScale = 1;

``````````

</details>


https://github.com/llvm/llvm-project/pull/172245


More information about the llvm-commits mailing list