[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)

Muzammiluddin Syed llvmlistbot at llvm.org
Fri Nov 28 12:17:23 PST 2025


================
@@ -442,6 +442,111 @@ LogicalResult WMMAOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledWMMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScaledWMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  // Validate output type is F32
+  if (!destType.getElementType().isF32())
+    return emitOpError("destination must have f32 element type");
+
+  // Validate source element types are small floats (fp4/fp6/fp8)
+  Type aElemType = sourceAType.getElementType();
+  Type bElemType = sourceBType.getElementType();
+
+  bool aIsSmallFloat =
+      aElemType.isFloat(4) || aElemType.isFloat(6) || aElemType.isFloat(8);
+  bool bIsSmallFloat =
+      bElemType.isFloat(4) || bElemType.isFloat(6) || bElemType.isFloat(8);
+
+  if (!aIsSmallFloat || !bIsSmallFloat)
+    return emitOpError("source operands must have small float element types "
+                       "(fp4/fp6/fp8)");
+
+  // Validate vector lengths based on dimensions
+  int64_t m = getM();
+  int64_t aLen = sourceAType.getNumElements();
+  int64_t bLen = sourceBType.getNumElements();
+  int64_t expectedOutLen = (m == 16) ? 8 : 16;
+
+  if (destType.getNumElements() != expectedOutLen)
+    return emitOpError("expected output vector of length " +
+                       Twine(expectedOutLen) + " but got " +
+                       Twine(destType.getNumElements()));
+
+  if (m == 16) {
+    // For 16×16×128: both A and B must be 64 elements
+    if (aLen != 64)
+      return emitOpError(
+          "for 16x16x128, sourceA must have 64 elements but got " +
+          Twine(aLen));
+    if (bLen != 64)
+      return emitOpError(
+          "for 16x16x128, sourceB must have 64 elements but got " +
+          Twine(bLen));
+  } else { // m == 32
+    // For 32×16×128: only fp4 is supported, A is 128, B is 64
+    if (!aElemType.isFloat(4))
+      return emitOpError("32x16x128 only supports fp4 element types");
+
+    if (aLen != 128)
+      return emitOpError(
+          "for 32x16x128, sourceA must have 128 elements but got " +
+          Twine(aLen));
+    if (bLen != 64)
+      return emitOpError(
+          "for 32x16x128, sourceB must have 64 elements but got " +
+          Twine(bLen));
+  }
+
+  // Validate scale types and their compatibility with matrix element types
+  auto scaleAType = cast<VectorType>(getScaleA().getType());
+  auto scaleBType = cast<VectorType>(getScaleB().getType());
+  Type scaleAElemType = scaleAType.getElementType();
+  Type scaleBElemType = scaleBType.getElementType();
+
+  // Validate scale element types are valid f8 types
+  if (!scaleAElemType.isFloat(8) || !scaleBElemType.isFloat(8))
+    return emitOpError("scale operands must have f8 element types");
+
+  // Helper functions for scale type classification
+  auto isE8M0 = [](Type t) { return isa<Float8E8M0FNUType>(t); };
+  auto isE4M3 = [](Type t) {
+    return isa<Float8E4M3FNType, Float8E4M3FNUZType>(t);
+  };
+
+  bool aIsF4 = aElemType.isFloat(4);
+  bool bIsF4 = bElemType.isFloat(4);
+  bool aIsF8F6 = aElemType.isFloat(8) || aElemType.isFloat(6);
+  bool bIsF8F6 = bElemType.isFloat(8) || bElemType.isFloat(6);
+
+  // A/B matrices of any supported type (fp8|fp6|fp4) with E8M0 scales on both sides are valid
+  if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
+    return success();
+
+  // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M2|E4M3)
+  if (aIsF8F6 && isE8M0(scaleAElemType) && bIsF4 && (isE4M3(scaleBElemType)))
+    return success();
----------------
Muzammiluddin-Syed-ECE wrote:

Is there a reason we're not explicitly checking for the E5M2 scale B type in this combination of legal A and B matrix data? The comment above mentions `E5M2|E4M3`, but the condition only checks `isE4M3`.
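
For concreteness, a sketch of what an explicit check could look like (mirroring the existing `isE4M3` helper; whether both the builtin `Float8E5M2Type` and `Float8E5M2FNUZType` variants should be accepted is an assumption on my part):

```c++
// Sketch only: mirrors the isE4M3 helper above; accepting the FNUZ
// variant here is an assumption, not something the PR states.
auto isE5M2 = [](Type t) {
  return isa<Float8E5M2Type, Float8E5M2FNUZType>(t);
};

// Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M2|E4M3)
if (aIsF8F6 && isE8M0(scaleAElemType) && bIsF4 &&
    (isE4M3(scaleBElemType) || isE5M2(scaleBElemType)))
  return success();
```

That would make the condition agree with the `E5M2|E4M3` comment directly above it.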

https://github.com/llvm/llvm-project/pull/169854

