[Mlir-commits] [mlir] [mlir][amdgpu][rocdl] Add gfx1250 wmma ops (PR #165064)

Krzysztof Drewniak llvmlistbot at llvm.org
Fri Oct 24 21:51:20 PDT 2025


================
@@ -1019,34 +1024,124 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
         return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
     }
   }
-  if (chipset.majorVersion < 12)
+  if (isRDNA3)
     return std::nullopt;
 
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
   // gfx12+
   if (k == 16) {
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+      return std::nullopt;
+
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+
     if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
 
     return std::nullopt;
   }
   if (k == 32) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (isRDNA4) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+      return std::nullopt;
+    }
+
+    // gfx1250
+    if (elemSourceType.isF16() && elemDestType.isF32())
----------------
krzysz00 wrote:

We might want an explicit `if` check for gfx1250 here, so that the gfx1250 handling isn't simply the default (fall-through) case for every chipset that isn't RDNA4 — unless I'm reading this code wrong?

https://github.com/llvm/llvm-project/pull/165064


More information about the Mlir-commits mailing list