[Mlir-commits] [mlir] [mlir][AMDGPU] Add int4 intrinsics, mixed-type fp8 to handle gfx12 (PR #128963)

Daniel Hernandez-Juarez llvmlistbot at llvm.org
Thu Feb 27 09:05:14 PST 2025


================
@@ -631,10 +649,33 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
-  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
+  if (chipset.majorVersion == 11) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  }
+  if (chipset.majorVersion >= 12) {
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
+      bool isWave64 = destVectorType.getNumElements() == 4;
+      // This is the ambiguous case. 8 inputs to the wave64 version means that
+      // we want the 16x16x32 version, but for wave32 they mean the short form.
+      bool has8Inputs = sourceVectorType.getNumElements() == 8;
+      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
----------------
dhernandez0 wrote:

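nit: this condition can be simplified to `if (isWave64 == has8Inputs)`, since `(a && b) || (!a && !b)` is equivalent to `a == b` for booleans.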

https://github.com/llvm/llvm-project/pull/128963


More information about the Mlir-commits mailing list