[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)
Jakub Kuderski
llvmlistbot at llvm.org
Sun Nov 30 06:26:28 PST 2025
================
@@ -442,6 +442,111 @@ LogicalResult WMMAOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledWMMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScaledWMMAOp::verify() {
+ auto sourceAType = cast<VectorType>(getSourceA().getType());
+ auto sourceBType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ // Validate output type is F32
+ if (!destType.getElementType().isF32())
+ return emitOpError("destination must have f32 element type");
+
+ // Validate source element types are small floats (fp4/fp6/fp8)
+ Type aElemType = sourceAType.getElementType();
+ Type bElemType = sourceBType.getElementType();
+
+ bool aIsSmallFloat =
+ aElemType.isFloat(4) || aElemType.isFloat(6) || aElemType.isFloat(8);
+ bool bIsSmallFloat =
+ bElemType.isFloat(4) || bElemType.isFloat(6) || bElemType.isFloat(8);
+
+ if (!aIsSmallFloat || !bIsSmallFloat)
+ return emitOpError("source operands must have small float element types "
+ "(fp4/fp6/fp8)");
+
+ // Validate vector lengths based on dimensions
+ int64_t m = getM();
+ int64_t aLen = sourceAType.getNumElements();
+ int64_t bLen = sourceBType.getNumElements();
+ int64_t expectedOutLen = (m == 16) ? 8 : 16;
+
+ if (destType.getNumElements() != expectedOutLen)
+ return emitOpError("expected output vector of length " +
+ Twine(expectedOutLen) + " but got " +
+ Twine(destType.getNumElements()));
+
+ if (m == 16) {
+ // For 16×16×128: both A and B must be 64 elements
----------------
kuhar wrote:
```suggestion
// For 16×16×128: both A and B must be 64 elements.
```
https://github.com/llvm/llvm-project/pull/169854
More information about the Mlir-commits mailing list