[Mlir-commits] [mlir] [MLIR][NVGPU] Adding `nvgpu.warpgroup.mma` Op for Hopper GPUs (PR #65440)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 13 07:06:27 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-nvgpu
            
<details>
<summary>Changes</summary>
This work introduces a new operation called `warpgroup.mma` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate warpgroup-level matrix multiply and accumulate (WGMMA) operations on Hopper GPUs with sm_90a architecture.

Previously, the `nvvm.wgmma.mma_async` operation was introduced to support warpgroup-level matrix operations in the NVVM dialect. Multiple instances of `nvvm.wgmma.mma_async` have to be used to achieve the desired shape. The new `nvgpu.warpgroup.mma` operation abstracts this complexity and provides a higher-level interface for performing warpgroup-level matrix operations.

The `nvgpu.warpgroup.mma` op does the following:
1) Corresponds to multiple `wgmma` instructions.
2) Iterates over the input matrix descriptors to achieve the desired computation shape.
3) Groups and runs the `wgmma` instructions asynchronously, and eventually waits for them. This is done with `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and `wgmma.wait.group.sync.aligned`.
4) Returns the fragmented result matrices.

Here's an example usage of the `nvgpu.warpgroup.mma` operation:
```
%wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
      vector<128x128xf32>
      -> !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
         !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
```


--

Patch is 36.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65440.diff

7 Files Affected:

- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+56) 
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+2) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+163-3) 
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+131-4) 
- (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+15) 
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+61-1) 
- (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+44) 


<pre>
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index a3245bf9196eed1..90381648dac6acc 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -192,6 +192,19 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type&lt;&quot;WarpgroupMatrixDescriptor&quot;, &quot;w
   let assemblyFormat = &quot;`&lt;` struct(params) `&gt;`&quot;;
 }
 
+def NVGPU_WarpgroupAccumulator : NVGPU_Type&lt;&quot;WarpgroupAccumulator&quot;, &quot;warpgroup.accumulator&quot;, []&gt; {
+  let parameters = (ins &quot;VectorType&quot;:$fragmented);
+  let assemblyFormat = &quot;`&lt;` struct(params) `&gt;`&quot;;
+  let description = [{
+    This type represents the result matrix obtained from `nvgpu.warpgroup.mma`. 
+    The `$fragmented` type signifies the distributed or fragmented result 
+    vector that is collectively owned by all the threads in the warp-group 
+    that executed `nvgpu.warpgroup.mma`.
+    [See the details of register fragment layout for accumulator matrix D]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -664,5 +677,48 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op&lt;&quot;wgmma.generate.descriptor&quot;, []&gt; {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaOp : NVGPU_Op&lt;&quot;warpgroup.mma&quot;&gt; {
+  let description = [{
+    The `nvgpu.warpgroup.mma` op performs the warpgroup-level (4 warps) 
+    matrix-multiply-and-accumulate (mma) operation that results in 
+    `nvvm.wgmma.mma_async`. 
+    
+    The operands `descriptorA` and `descriptorB` are wgmma matrix 
+    descriptors that describe the properties of the matrices in shared memory. 
+    The results are the thread-level fragments of the warpgroup-level mma 
+    result. The shape is deduced from the descriptor types and output vector.
+
+    The Op corresponds to multiple `nvvm.wgmma.mma_async` operations that 
+    together complete the given shape. As `nvvm.wgmma.mma_async` is 
+    asynchronous, this Op groups the `nvvm.wgmma.mma_async` operations and 
+    surrounds them with `wgmma.fence.aligned` before, and 
+    `wgmma.commit.group.sync.aligned` and `wgmma.wait.group.sync.aligned` after.
+
+    Example:
+    ```mlir
+      %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2: 
+                 !nvgpu.wgmma.descriptor&lt;tensor = memref&lt;128x64xf16, 3&gt;&gt;, 
+                 !nvgpu.wgmma.descriptor&lt;tensor = memref&lt;64x128xf16, 3&gt;&gt;, 
+                 !nvgpu.warpgroup.accumulator&lt;fragmented = vector&lt;64x128xf32&gt;&gt;,
+                 !nvgpu.warpgroup.accumulator&lt;fragmented = vector&lt;64x128xf32&gt;&gt;
+                 -&gt; 
+                 !nvgpu.warpgroup.accumulator&lt;fragmented = vector&lt;64x128xf32&gt;&gt;,
+                 !nvgpu.warpgroup.accumulator&lt;fragmented = vector&lt;64x128xf32&gt;&gt;
+    ```
+  }];
+
+  let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA, 
+                       NVGPU_WarpgroupMatrixDescriptor:$descriptorB,                                               
+                       DefaultValuedOptionalAttr&lt;I32Attr, &quot;1&quot;&gt;:$waitGroup,
+                       OptionalAttr&lt;UnitAttr&gt;:$transposeA,
+                       OptionalAttr&lt;UnitAttr&gt;:$transposeB,
+                       Variadic&lt;NVGPU_WarpgroupAccumulator&gt;:$matrixC);
+  let results = (outs Variadic&lt;NVGPU_WarpgroupAccumulator&gt;:$matrixD);
+  let assemblyFormat = [{    
+    $descriptorA`,` $descriptorB`,` $matrixC attr-dict
+    `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `-&gt;` type($matrixD)
+  }];
+  let hasVerifier = 1;
+}
 
 #endif // NVGPU
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 192afcb2dba7913..96af26842dafea2 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -21,6 +21,8 @@
 
 #include &quot;mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc&quot;
 
+constexpr int kWarpSize = 32;
+
 #define GET_ATTRDEF_CLASSES
 #include &quot;mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc&quot;
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b045089244ff1a7..046727e4ea9ab83 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -17,10 +17,12 @@
 #include &quot;mlir/Dialect/LLVMIR/NVVMDialect.h&quot;
 #include &quot;mlir/Dialect/MemRef/IR/MemRef.h&quot;
 #include &quot;mlir/Dialect/NVGPU/IR/NVGPUDialect.h&quot;
+#include &quot;mlir/Dialect/SCF/Transforms/Patterns.h&quot;
 #include &quot;mlir/IR/PatternMatch.h&quot;
 #include &quot;mlir/IR/TypeUtilities.h&quot;
 #include &quot;mlir/Pass/Pass.h&quot;
 #include &quot;llvm/Support/Debug.h&quot;
+#include &quot;llvm/Support/ErrorHandling.h&quot;
 #include &quot;llvm/Support/raw_ostream.h&quot;
 
 #define DEBUG_TYPE &quot;nvgpu-to-nvvm&quot;
@@ -34,6 +36,10 @@ namespace mlir {
 
 using namespace mlir;
 
+/// Number of bits that need to be excluded when building the matrix descriptor
+/// for wgmma operations.
+constexpr int exclude4LSB = 4;
+
 /// GPU has 32 bit registers, this function truncates values when larger width
 /// is not needed.
 static Value truncToI32(ConversionPatternRewriter &amp;rewriter, Location loc,
@@ -419,6 +425,15 @@ struct ConvertNVGPUToNVVMPass
     converter.addConversion([&amp;](nvgpu::DeviceAsyncTokenType type) -&gt; Type {
       return converter.convertType(IntegerType::get(type.getContext(), 32));
     });
+    converter.addConversion([&amp;](nvgpu::WarpgroupAccumulatorType type) -&gt; Type {
+      VectorType vtype = type.getFragmented();
+      SmallVector&lt;Type&gt; structBody;
+      for (unsigned i = 0; i &lt; vtype.getDimSize(0); i++)
+        structBody.push_back(vtype.getElementType());
+      auto convertedType =
+          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
+      return converter.convertType(convertedType);
+    });
     converter.addConversion([&amp;](nvgpu::MBarrierTokenType type) -&gt; Type {
       return converter.convertType(IntegerType::get(type.getContext(), 64));
     });
@@ -438,6 +453,8 @@ struct ConvertNVGPUToNVVMPass
     target.addLegalDialect&lt;::mlir::LLVM::LLVMDialect&gt;();
     target.addLegalDialect&lt;::mlir::memref::MemRefDialect&gt;();
     target.addLegalDialect&lt;::mlir::NVVM::NVVMDialect&gt;();
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -984,10 +1001,9 @@ struct NVGPUGenerateGmmaDescriptorLowering
                                          shiftLeft(val, startBit));
     };
 
-    int ex4LSB = 4;
     int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
-    uint64_t strideDimVal = (layout &lt;&lt; 3) &gt;&gt; ex4LSB;
-    uint64_t leadDimVal = (sizeN * layout) &gt;&gt; ex4LSB;
+    uint64_t strideDimVal = (layout &lt;&lt; 3) &gt;&gt; exclude4LSB;
+    uint64_t leadDimVal = (sizeN * layout) &gt;&gt; exclude4LSB;
     uint64_t offsetVal = 0;
 
     Value strideDim = makeConst(strideDimVal);
@@ -1141,6 +1157,149 @@ struct NVGPUTmaCreateDescriptorOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern&lt;nvgpu::WarpgroupMmaOp&gt; {
+  using ConvertOpToLLVMPattern&lt;nvgpu::WarpgroupMmaOp&gt;::ConvertOpToLLVMPattern;
+
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &amp;wgmmaShapeM, int &amp;wgmmaShapeN,
+                              int &amp;wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(16)) {
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      llvm_unreachable(&quot;msg: not supported K shape&quot;);
+    }
+    LLVM_DEBUG(DBGS() &lt;&lt; &quot;Generating wgmma.mma.async shape[m = &quot; &lt;&lt; wgmmaShapeM
+                      &lt;&lt; &quot;, n = &quot; &lt;&lt; wgmmaShapeN &lt;&lt; &quot;, k = &quot; &lt;&lt; wgmmaShapeK
+                      &lt;&lt; &quot;]\n&quot;);
+    return success();
+  }
+
+  Value generateNVVMWgmmaOp(MLIRContext *ctx,
+                            ConversionPatternRewriter &amp;rewriter, Location loc,
+                            int m, int n, int k, Type resultStructType,
+                            Value inout, Value descriptorA,
+                            Value descriptorB) const {
+    TypeRange resultTypes = {resultStructType};
+    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+    // todo input type
+    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
+    auto overflow =
+        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
+    Value res = rewriter.create&lt;NVVM::WgmmaMmaAsyncOp&gt;(
+        loc, resultTypes, inout, descriptorA, descriptorB, shape, itype, itype,
+        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+    return res;
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &amp;rewriter) const override {
+    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
+    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+
+    LLVM_DEBUG(DBGS() &lt;&lt; &quot;===--- GEMM D[&quot; &lt;&lt; sizeM &lt;&lt; &quot;][&quot; &lt;&lt; sizeN &lt;&lt; &quot;] += A[&quot;
+                      &lt;&lt; sizeM &lt;&lt; &quot;][&quot; &lt;&lt; sizeK &lt;&lt; &quot;] * B[&quot; &lt;&lt; sizeK &lt;&lt; &quot;][&quot;
+                      &lt;&lt; sizeN &lt;&lt; &quot;] ---===\n&quot;);
+
+    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
+    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
+                             wgmmaShapeN, wgmmaShapeK))) {
+      return failure();
+    }
+
+    Value descriptorA = adaptor.getDescriptorA();
+    Value descriptorB = adaptor.getDescriptorB();
+
+    //  Generate wgmma group
+
+    auto loc = op-&gt;getLoc();
+    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
+    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+
+    auto makeAdd = [&amp;](Value lhs, Value rhs) -&gt; Value {
+      return rewriter.create&lt;LLVM::AddOp&gt;(loc, lhs.getType(), lhs, rhs);
+    };
+
+    auto iterateDescA = [&amp;](Value desc, int iterM, int iterN,
+                            int iterK) -&gt; Value {
+      // todo : Handle column major
+      int byte = typeTensorA.getElementTypeBitWidth() / 8;
+      int tileShapeA = typeTensorA.getDimSize(1);
+      int incrementVal =
+          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+      incrementVal = incrementVal &gt;&gt; exclude4LSB;
+      LLVM_DEBUG(DBGS() &lt;&lt; &quot;\t\t[m: &quot; &lt;&lt; iterM &lt;&lt; &quot; n: &quot; &lt;&lt; iterN &lt;&lt; &quot; k: &quot;
+                        &lt;&lt; iterK &lt;&lt; &quot;] [wgmma descriptors] Descriptor A + &quot;
+                        &lt;&lt; incrementVal &lt;&lt; &quot; | \t &quot;);
+      if (!incrementVal)
+        return desc;
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    };
+
+    auto iterateDescB = [&amp;](Value desc, int iterM, int iterN,
+                            int iterK) -&gt; Value {
+      // todo : Handle row major
+      int byte = typeTensorB.getElementTypeBitWidth() / 8;
+      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+      incrementVal = incrementVal &gt;&gt; exclude4LSB;
+      LLVM_DEBUG(DBGSE() &lt;&lt; &quot;Descriptor B + &quot; &lt;&lt; incrementVal &lt;&lt; &quot;\n&quot;);
+      if (!incrementVal)
+        return desc;
+      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+    };
+
+    rewriter.create&lt;NVVM::WgmmaFenceAlignedOp&gt;(loc);
+
+    SmallVector&lt;Value&gt; wgmmaResults;
+    for (int iterM = 0; iterM &lt; (sizeM / wgmmaShapeM); iterM++) {
+      Value matrixC = adaptor.getMatrixC()[iterM];
+      Value matrixD = op.getMatrixD()[iterM];
+      Type structType = getTypeConverter()-&gt;convertType(matrixD.getType());
+      LLVM_DEBUG(DBGS() &lt;&lt; &quot; D[&quot; &lt;&lt; (iterM * wgmmaShapeM) &lt;&lt; &quot;:&quot;
+                        &lt;&lt; (iterM * wgmmaShapeM) + wgmmaShapeM &lt;&lt; &quot;][&quot; &lt;&lt; 0
+                        &lt;&lt; &quot;:&quot; &lt;&lt; wgmmaShapeN &lt;&lt; &quot;] += \n&quot;);
+      for (int iterK = 0; iterK &lt; (sizeK / wgmmaShapeK); iterK++) {
+        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
+        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
+        LLVM_DEBUG(DBGS() &lt;&lt; &quot;\t wgmma.&quot;
+                          &lt;&lt; &quot;m&quot; &lt;&lt; wgmmaShapeM &lt;&lt; &quot;n&quot; &lt;&lt; wgmmaShapeN &lt;&lt; &quot;k&quot;
+                          &lt;&lt; wgmmaShapeK &lt;&lt; &quot;(A[&quot; &lt;&lt; (iterM * wgmmaShapeM)
+                          &lt;&lt; &quot;:&quot; &lt;&lt; (iterM * wgmmaShapeM) + wgmmaShapeM &lt;&lt; &quot;][&quot;
+                          &lt;&lt; (iterK * wgmmaShapeK) &lt;&lt; &quot;:&quot;
+                          &lt;&lt; (iterK * wgmmaShapeK + wgmmaShapeK) &lt;&lt; &quot;] * &quot;
+                          &lt;&lt; &quot; B[&quot; &lt;&lt; (iterK * wgmmaShapeK) &lt;&lt; &quot;:&quot;
+                          &lt;&lt; (iterK * wgmmaShapeK + wgmmaShapeK) &lt;&lt; &quot;][&quot; &lt;&lt; 0
+                          &lt;&lt; &quot;:&quot; &lt;&lt; wgmmaShapeN &lt;&lt; &quot;])\n&quot;);
+        matrixC = generateNVVMWgmmaOp(op-&gt;getContext(), rewriter, loc,
+                                      wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
+                                      structType, matrixC, descA, descB);
+      }
+      wgmmaResults.push_back(matrixC);
+    }
+    rewriter.create&lt;NVVM::WgmmaGroupSyncAlignedOp&gt;(loc);
+    rewriter.create&lt;NVVM::WgmmaWaitGroupSyncOp&gt;(loc, op.getWaitGroup());
+
+    ValueRange myres(wgmmaResults);
+    rewriter.replaceOp(op, myres);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &amp;converter,
@@ -1156,6 +1315,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &amp;converter,
       NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateGmmaDescriptorLowering,   // nvgpu.wgmma.generate.descriptor
+      NVGPUWarpgroupMmaOpLowering,           // nvgpu.warpgroup.mma
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering&gt;(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d832a983a132d61..d96ed69982870b4 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -22,6 +22,7 @@
 #include &quot;mlir/IR/PatternMatch.h&quot;
 #include &quot;mlir/IR/TypeUtilities.h&quot;
 #include &quot;mlir/IR/Verifier.h&quot;
+#include &quot;llvm/ADT/STLExtras.h&quot;
 #include &quot;llvm/ADT/StringExtras.h&quot;
 #include &quot;llvm/ADT/TypeSwitch.h&quot;
 
@@ -151,7 +152,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
   //  - For F32 (TF32), F16, S8, and S4 data
   //    types the fundamental tensor core operation is of shape 8-by-8-by-128b.
   //  - F64 is an exception and is of shape 8-by-8-by-256b.
-  constexpr int kThreads = 32; // 32 threads per warp
   int64_t shapeM = 8;
   int64_t shapeN = 8;
   int64_t shapeK; // set based on data type (128b for all data types except F64)
@@ -206,17 +206,17 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
 
   // verify warp-wide size for vector a
   int64_t sparseFactor = sparse ? 2 : 1;
-  if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor)
+  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
     return op-&gt;emitOpError()
            &lt;&lt; &quot;expected &quot; &lt;&lt; m * k &lt;&lt; &quot; warp-wide matrix A elements&quot;;
 
   // verify warp-wide size for vector b
-  if (bShape[0] * bShape[1] * kThreads != k * n)
+  if (bShape[0] * bShape[1] * kWarpSize != k * n)
     return op-&gt;emitOpError()
            &lt;&lt; &quot;expected &quot; &lt;&lt; k * n &lt;&lt; &quot; warp-wide matrix B elements&quot;;
 
   // verify warp-wide size for vector c
-  if (cShape[0] * cShape[1] * kThreads != m * n)
+  if (cShape[0] * cShape[1] * kWarpSize != m * n)
     return op-&gt;emitOpError()
            &lt;&lt; &quot;expected &quot; &lt;&lt; m * n &lt;&lt; &quot; warp-wide matrix C elements&quot;;
 
@@ -402,6 +402,133 @@ LogicalResult GenerateGmmaDescriptorOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WarpgroupMmaOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
+  // F32 += F16 + F16
+  // F16 += F16 + F16
+  if (typeA.isF16() &amp;&amp; typeB.isF16() &amp;&amp; (typeD.isF32() || typeD.isF16()))
+    return success();
+  // F32 += TF32 + TF32
+  if (typeA.isTF32() &amp;&amp; typeD.isF32() &amp;&amp; typeB.isTF32())
+    return success();
+  // s32 += i8 + i8
+  if (typeA.isInteger(16) &amp;&amp; typeB.isInteger(16) &amp;&amp; typeD.isInteger(32))
+    return success();
+  // s32 += i1 + i1
+  if (typeA.isInteger(1) &amp;&amp; typeB.isInteger(1) &amp;&amp; typeD.isInteger(32))
+    return success();
+  // F32 += BF16 + BF16
+  // F16 += BF16 + BF16
+  if (typeA.isBF16() &amp;&amp; typeB.isBF16() &amp;&amp; (typeD.isF32() || typeD.isF16()))
+    return success();
+  // F16 += f8 + f8
+  // F32 += f8 + f8
+  if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) &amp;&amp;
+      (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) &amp;&amp;
+      (typeD.isF32() || typeD.isF16()))
+    return success();
+
+  return failure();
+}
+
+LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
+  SmallVector&lt;int&gt; allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
+                               72,  80,  88,  96,  104, 112, 120, 128,
+                               136, 144, 152, 160, 168, 176, 184, 192,
+                               200, 208, 216, 224, 232, 240, 248, 256};
+  SmallVector&lt;int&gt; allowedNshort = {8,   16,  24,  32,  48,  64,
+                                    80,  96,  112, 128, 144, 160,
+                                    176, 192, 208, 224, 240, 256};
+  if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() ||
+      typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
+    if (llvm::any_of(allo...
<truncated>
</pre>
</details>


https://github.com/llvm/llvm-project/pull/65440


More information about the Mlir-commits mailing list