[Mlir-commits] [mlir] [MLIR][NVVM] Explicit Data Type for Output in `wgmma.mma_async` (PR #78713)

Durgadoss R llvmlistbot at llvm.org
Fri Jan 19 10:15:34 PST 2024


================
@@ -821,27 +833,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
   if (!stype)
     return emitOpError() << "expected results to be struct";
-  Type outputType = stype.getBody().front();
   int outputSize = stype.getBody().size();
+  WGMMATypes typeD = getTypeD();
+  WGMMATypes typeA = getTypeA();
+  WGMMATypes typeB = getTypeB();
+
   for (Type t : stype.getBody()) {
-    if (t != outputType)
+    if (t != stype.getBody().front())
       return emitOpError()
              << "all elements in struct must be same type but there is " << t;
   }
 
-  if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) {
+  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
+      typeD != WGMMATypes::s32) {
     return emitOpError() << "does not support the given output type "
-                         << outputType;
+                         << NVVM::stringifyWGMMATypes(typeD);
   }
-  if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg ||
-                                   getScaleB() == NVVM::WGMMAScaleIn::neg)) {
-    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
+  if (typeD == WGMMATypes::s32 &&
+      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
+    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg ";
----------------
durga4github wrote:

nit: the trailing space at the end of the error message (after "neg") seems unintended.

https://github.com/llvm/llvm-project/pull/78713


More information about the Mlir-commits mailing list