[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `warpgroup.init.accumulator` Op (PR #67530)
Guray Ozen
llvmlistbot at llvm.org
Thu Oct 5 06:29:10 PDT 2023
================
@@ -727,4 +727,15 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
let hasVerifier = 1;
}
+def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulator"> {
+ let summary = "Initialize the accumulator matrix for `warpgroup.mma`";
+
+ let description = [{
+ This Op generates and initializes the accumulator matrix for the
+ `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate (mma).
+ }];
+ let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
----------------
grypp wrote:
You are right, there are trade-offs. I have three versions in mind; let me put them here so we can discuss.
1) The current version, with variadics, is below. All the Ops are variadic, so we generate a single Op for each of them.
```
// Init
%matrixC1, %matrixC2 = nvgpu.warpgroup.mma.init.accumulator ->
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
// GEMM
%matrixD1, %matrixD2 = nvgpu.warpgroup.mma %descA, %descB, %matrixC1, %matrixC2 ...
// Epilogue
nvgpu.warpgroup.mma.store [%matrixD1, %matrixD2] to %sharedMemoryBuffer
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
into memref<128x128xf32,3>
```
2) We can drop the variadic from `nvgpu.warpgroup.mma.init.accumulator`, as below:
```
// Init
%matrixC1 = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
%matrixC2 = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
// GEMM
%matrixD1, %matrixD2 = nvgpu.warpgroup.mma %descA, %descB, %matrixC1, %matrixC2 : ....
// Epilogue
nvgpu.warpgroup.mma.store [%matrixD1, %matrixD2] to %sharedMemoryBuffer
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
into memref<128x128xf32,3>
```
3) We can also make `nvgpu.warpgroup.mma.store` non-variadic:
```
// Epilogue
%s1 = memref.subview %sharedMemoryBuffer ....
nvgpu.warpgroup.mma.store %matrixD1 to %s1
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> into memref<128x128xf32,3>
%s2 = memref.subview %sharedMemoryBuffer ....
nvgpu.warpgroup.mma.store %matrixD2 to %s2
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> into memref<128x128xf32,3>
```
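All three options handle two 64x128 accumulators for a 128x128 tile. The rough Python sketch below shows where that count comes from. It assumes that a warpgroup is four warps (128 threads), that each `wgmma` produces 64 rows of C, and that tiles are split evenly into fragments. `accumulator_plan` is a hypothetical helper, not part of MLIR. The fragment count equals the number of results of the single init op in option 1, and the number of separate init ops in options 2 and 3.

```python
# Hypothetical helper (not an MLIR API): sizes of the warpgroup accumulators
# used in the examples, assuming 128 threads per warpgroup and M = 64 per wgmma.
WARPGROUP_THREADS = 4 * 32  # a warpgroup is four consecutive warps
WGMMA_M = 64                # rows of C produced by one warpgroup-level mma


def accumulator_plan(tile_m: int, tile_n: int, frag_n: int) -> tuple[int, int]:
    """Return (number of accumulator fragments, f32 values per thread per fragment)."""
    if tile_m % WGMMA_M or tile_n % frag_n:
        raise ValueError("tile must be divisible by the fragment shape")
    # Each fragment is a WGMMA_M x frag_n piece of the C tile.
    num_fragments = (tile_m // WGMMA_M) * (tile_n // frag_n)
    # A fragment's elements are spread evenly over the warpgroup's threads.
    values_per_thread = (WGMMA_M * frag_n) // WARPGROUP_THREADS
    return num_fragments, values_per_thread


# The 128x128 tile from the examples, split into vector<64x128xf32> fragments.
frags, per_thread = accumulator_plan(128, 128, 128)
print(frags, per_thread)  # 2 64
```

Under these assumptions each thread carries 2 x 64 = 128 f32 accumulator values. Whether those values travel as one variadic op or as separate ops, the register footprint is the same. The options differ only in how many ops and SSA values appear in the IR.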
I have implemented the **1st** option. I guess you are advocating for the **2nd** one. Is that right?
https://github.com/llvm/llvm-project/pull/67530