[Mlir-commits] [mlir] [mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation (PR #78413)
Guray Ozen
llvmlistbot at llvm.org
Fri Jan 19 05:06:27 PST 2024
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/78413
From a476ea9ea316e28fe54815dea90ab6b098ee8358 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 17 Jan 2024 10:04:57 +0100
Subject: [PATCH 1/2] [mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation
This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destination memref and the accumulator matrix were small, the previous code computed indices that were out of range.
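For readers who want the arithmetic without wading through the diffs, here is a minimal standalone sketch (an illustration, not the lowering code itself) of the index calculation as it stands after both patches. It assumes an f32 accumulator fragment of shape 64xN distributed over a 128-thread warpgroup, so each thread owns N/2 struct members; the constant names are the sketch's own, mirroring the ones introduced in the second patch.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <initializer_list>

int main() {
  // Each store handles a pair of adjacent 32-bit registers, and each warp
  // owns two 8x8 tiles stacked one below the other.
  constexpr size_t kAdjacentRegisters = 2;
  constexpr size_t kStackedMatrices = 2;
  for (size_t n : {16, 32, 64, 128, 256}) {
    size_t structSize = n / 2; // f32 struct members per thread: 64*n / 128
    size_t storeCount = structSize / (kStackedMatrices * kAdjacentRegisters);
    size_t stores = 0;
    for (size_t i = 0; i < kStackedMatrices; ++i) {
      for (size_t j = 0; j < storeCount; ++j) {
        size_t structIndex =
            i * kAdjacentRegisters + j * kStackedMatrices * kAdjacentRegisters;
        // makeExtractAndStore reads members structIndex and structIndex + 1;
        // deriving storeCount from the struct size keeps both in range.
        assert(structIndex + 1 < structSize && "extractvalue out of range");
        stores += 2;
      }
    }
    printf("64x%-3zu -> %3zu registers/thread, %3zu memref.store ops\n", n,
           structSize, stores);
  }
  return 0;
}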
---
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 7 +++--
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 28 +++++++++++++++++++
2 files changed, 32 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5..9e4ae219eefd60 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1554,6 +1554,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+ auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
@@ -1570,11 +1571,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
if (offset)
ti = makeAdd(ti, makeConst(offset));
- for (int i = 0; i < 2; ++i) {
+ for (size_t i = 0; i < 2; ++i) {
Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
- for (int j = 0; j < 16; ++j) {
+ for (size_t j = 0; j < (structType.getBody().size() / 8); ++j) {
Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
- int sIndex = i * 2 + j * 4;
+ size_t sIndex = i * 2 + j * 4;
makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
}
}
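The hard-coded bound made sense for only one shape: 16 iterations with a stride of 4 touch struct members up to index 63, which is exactly the 64 f32 registers each thread holds for a 64x128 accumulator. Smaller fragments have fewer members, so llvm.extractvalue walked past the end of the struct. Deriving the trip count from structType.getBody().size() removes that assumption; the second patch below further corrects the divisor and gives the magic numbers names.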
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bd..ce81fd859fd02a 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1055,6 +1055,34 @@ func.func @warpgroup_mma_store(
return
}
+// CHECK-LABEL: @warpgroup_mma_store_multiplie(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x128xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: memref<64x32xf32, 3>, %[[arg4:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, %[[arg5:[a-zA-Z0-9_]+]]: memref<64x64xf32, 3>)
+func.func @warpgroup_mma_store_multiplie(
+ %result128 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ %matrixD128: memref<64x128xf32,3>,
+ %result32 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
+ %matrixD32: memref<64x32xf32,3>,
+ %result64 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
+ %matrixD64: memref<64x64xf32,3>) {
+
+ // CHECK-COUNT-32: memref.store %{{.*}}, %[[arg1]][%{{.*}}, %{{.*}}] : memref<64x128xf32, 3>
+ nvgpu.warpgroup.mma.store %result128, %matrixD128 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ to memref<64x128xf32,3>
+
+
+ // CHECK-COUNT-8: memref.store %{{.*}}, %[[arg3]][%{{.*}}, %{{.*}}] : memref<64x32xf32, 3>
+ nvgpu.warpgroup.mma.store %result32, %matrixD32 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x32xf32>>
+ to memref<64x32xf32,3>
+
+ // CHECK-COUNT-16: memref.store %{{.*}}, %[[arg5]][%{{.*}}, %{{.*}}] : memref<64x64xf32, 3>
+ nvgpu.warpgroup.mma.store %result64, %matrixD64 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x64xf32>>
+ to memref<64x64xf32,3>
+ return
+}
+
func.func @warpgroup_mma_init() {
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
From da9a4dddb6dab808c345a4badfd0de198ea02c95 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 19 Jan 2024 14:06:14 +0100
Subject: [PATCH 2/2] Add more tests
---
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 33 ++--
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 152 +++++++++++++++---
2 files changed, 149 insertions(+), 36 deletions(-)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9e4ae219eefd60..ab4dea9d5618d5 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1548,13 +1548,6 @@ struct NVGPUWarpgroupMmaStoreOpLowering
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
};
- Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
- Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
- Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
- Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
- Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
-
- auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
@@ -1567,16 +1560,34 @@ struct NVGPUWarpgroupMmaStoreOpLowering
b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
};
+ Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+ Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
+ Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
+ Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
+ Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+
Value tj = makeMul(lane4modId, c2);
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
if (offset)
ti = makeAdd(ti, makeConst(offset));
- for (size_t i = 0; i < 2; ++i) {
+
+ auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
+
+ // Number of adjacent 32-bit registers each thread holds
+ constexpr unsigned numAdjacentRegisters = 2;
+ // Number of 8x8 matrices stacked one below another per warp
+ constexpr unsigned numStackedMatrices = 2;
+
+ size_t storeCount = (structType.getBody().size() /
+ (numStackedMatrices * numAdjacentRegisters));
+
+ for (size_t i = 0; i < numStackedMatrices; ++i) {
Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
- for (size_t j = 0; j < (structType.getBody().size() / 8); ++j) {
+ for (size_t j = 0; j < storeCount; ++j) {
Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
- size_t sIndex = i * 2 + j * 4;
- makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
+ size_t structIndex = (i * numAdjacentRegisters) +
+ (j * (numStackedMatrices * numAdjacentRegisters));
+ makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
}
}
}
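A quick sanity check that grounds the CHECK-COUNT values in the updated test: each makeExtractAndStore call emits two memref.store ops (one per adjacent register), so lowering one op for a 64xN f32 accumulator produces numStackedMatrices * storeCount * 2 = 2 * ((N/2) / 4) * 2 = N/2 stores. Hence 8 stores for 64x16, 64 for 64x128, and 128 for 64x256 below.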
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index ce81fd859fd02a..b495363e228d8f 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1055,31 +1055,133 @@ func.func @warpgroup_mma_store(
return
}
-// CHECK-LABEL: @warpgroup_mma_store_multiplie(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x128xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: memref<64x32xf32, 3>, %[[arg4:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, %[[arg5:[a-zA-Z0-9_]+]]: memref<64x64xf32, 3>)
-func.func @warpgroup_mma_store_multiplie(
- %result128 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
- %matrixD128: memref<64x128xf32,3>,
- %result32 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
- %matrixD32: memref<64x32xf32,3>,
- %result64 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
- %matrixD64: memref<64x64xf32,3>) {
-
- // CHECK-COUNT-32: memref.store %{{.*}}, %[[arg1]][%{{.*}}, %{{.*}}] : memref<64x128xf32, 3>
- nvgpu.warpgroup.mma.store %result128, %matrixD128 :
- !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
- to memref<64x128xf32,3>
-
-
- // CHECK-COUNT-8: memref.store %{{.*}}, %[[arg3]][%{{.*}}, %{{.*}}] : memref<64x32xf32, 3>
- nvgpu.warpgroup.mma.store %result32, %matrixD32 :
- !nvgpu.warpgroup.accumulator< fragmented = vector<64x32xf32>>
- to memref<64x32xf32,3>
-
- // CHECK-COUNT-16: memref.store %{{.*}}, %[[arg5]][%{{.*}}, %{{.*}}] : memref<64x64xf32, 3>
- nvgpu.warpgroup.mma.store %result64, %matrixD64 :
- !nvgpu.warpgroup.accumulator< fragmented = vector<64x64xf32>>
- to memref<64x64xf32,3>
+// CHECK-LABEL: @warpgroup_mma_store_multiple
+func.func @warpgroup_mma_store_multiple(
+ %shmem_m64n8k : memref<64x8xf32>,
+ %shmem_m64n16k : memref<64x16xf32>,
+ %shmem_m64n24k : memref<64x24xf32>,
+ %shmem_m64n32k : memref<64x32xf32>,
+ %shmem_m64n40k : memref<64x40xf32>,
+ %shmem_m64n48k : memref<64x48xf32>,
+ %shmem_m64n56k : memref<64x56xf32>,
+ %shmem_m64n64k : memref<64x64xf32>,
+ %shmem_m64n72k : memref<64x72xf32>,
+ %shmem_m64n80k : memref<64x80xf32>,
+ %shmem_m64n88k : memref<64x88xf32>,
+ %shmem_m64n96k : memref<64x96xf32>,
+ %shmem_m64n104k : memref<64x104xf32>,
+ %shmem_m64n112k : memref<64x112xf32>,
+ %shmem_m64n120k : memref<64x120xf32>,
+ %shmem_m64n128k : memref<64x128xf32>,
+ %shmem_m64n136k : memref<64x136xf32>,
+ %shmem_m64n144k : memref<64x144xf32>,
+ %shmem_m64n152k : memref<64x152xf32>,
+ %shmem_m64n160k : memref<64x160xf32>,
+ %shmem_m64n168k : memref<64x168xf32>,
+ %shmem_m64n176k : memref<64x176xf32>,
+ %shmem_m64n184k : memref<64x184xf32>,
+ %shmem_m64n192k : memref<64x192xf32>,
+ %shmem_m64n200k : memref<64x200xf32>,
+ %shmem_m64n208k : memref<64x208xf32>,
+ %shmem_m64n216k : memref<64x216xf32>,
+ %shmem_m64n224k : memref<64x224xf32>,
+ %shmem_m64n232k : memref<64x232xf32>,
+ %shmem_m64n240k : memref<64x240xf32>,
+ %shmem_m64n248k : memref<64x248xf32>,
+ %shmem_m64n256k : memref<64x256xf32>,
+ %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
+ %res_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>>,
+ %res_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
+ %res_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>>,
+ %res_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>>,
+ %res_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>>,
+ %res_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
+ %res_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>>,
+ %res_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>>,
+ %res_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>>,
+ %res_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>>,
+ %res_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>>,
+ %res_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>>,
+ %res_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>>,
+ %res_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ %res_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>>,
+ %res_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>>,
+ %res_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>>,
+ %res_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>>,
+ %res_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>>,
+ %res_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>>,
+ %res_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>>,
+ %res_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>>,
+ %res_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>>,
+ %res_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>>,
+ %res_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>>,
+ %res_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>>,
+ %res_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>>,
+ %res_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>>,
+ %res_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>>,
+ %res_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>>) {
+ // CHECK-COUNT-8: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32>
+ // CHECK-COUNT-12: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x24xf32>
+ // CHECK-COUNT-16: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x32xf32>
+ // CHECK-COUNT-20: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x40xf32>
+ // CHECK-COUNT-24: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x48xf32>
+ // CHECK-COUNT-28: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x56xf32>
+ // CHECK-COUNT-32: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf32>
+ // CHECK-COUNT-36: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x72xf32>
+ // CHECK-COUNT-40: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x80xf32>
+ // CHECK-COUNT-44: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x88xf32>
+ // CHECK-COUNT-48: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x96xf32>
+ // CHECK-COUNT-52: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x104xf32>
+ // CHECK-COUNT-56: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x112xf32>
+ // CHECK-COUNT-60: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x120xf32>
+ // CHECK-COUNT-64: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x128xf32>
+ // CHECK-COUNT-68: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x136xf32>
+ // CHECK-COUNT-72: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x144xf32>
+ // CHECK-COUNT-76: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x152xf32>
+ // CHECK-COUNT-80: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x160xf32>
+ // CHECK-COUNT-84: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x168xf32>
+ // CHECK-COUNT-88: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x176xf32>
+ // CHECK-COUNT-92: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x184xf32>
+ // CHECK-COUNT-96: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x192xf32>
+ // CHECK-COUNT-100: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x200xf32>
+ // CHECK-COUNT-104: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x208xf32>
+ // CHECK-COUNT-108: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x216xf32>
+ // CHECK-COUNT-112: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x224xf32>
+ // CHECK-COUNT-116: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x232xf32>
+ // CHECK-COUNT-120: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x240xf32>
+ // CHECK-COUNT-124: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x248xf32>
+ // CHECK-COUNT-128: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x256xf32>
+ nvgpu.warpgroup.mma.store %res_m64n16k, %shmem_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32>
+ nvgpu.warpgroup.mma.store %res_m64n24k, %shmem_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>> to memref<64x24xf32>
+ nvgpu.warpgroup.mma.store %res_m64n32k, %shmem_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>> to memref<64x32xf32>
+ nvgpu.warpgroup.mma.store %res_m64n40k, %shmem_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>> to memref<64x40xf32>
+ nvgpu.warpgroup.mma.store %res_m64n48k, %shmem_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>> to memref<64x48xf32>
+ nvgpu.warpgroup.mma.store %res_m64n56k, %shmem_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>> to memref<64x56xf32>
+ nvgpu.warpgroup.mma.store %res_m64n64k, %shmem_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>> to memref<64x64xf32>
+ nvgpu.warpgroup.mma.store %res_m64n72k, %shmem_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>> to memref<64x72xf32>
+ nvgpu.warpgroup.mma.store %res_m64n80k, %shmem_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>> to memref<64x80xf32>
+ nvgpu.warpgroup.mma.store %res_m64n88k, %shmem_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>> to memref<64x88xf32>
+ nvgpu.warpgroup.mma.store %res_m64n96k, %shmem_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>> to memref<64x96xf32>
+ nvgpu.warpgroup.mma.store %res_m64n104k, %shmem_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>> to memref<64x104xf32>
+ nvgpu.warpgroup.mma.store %res_m64n112k, %shmem_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>> to memref<64x112xf32>
+ nvgpu.warpgroup.mma.store %res_m64n120k, %shmem_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>> to memref<64x120xf32>
+ nvgpu.warpgroup.mma.store %res_m64n128k, %shmem_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to memref<64x128xf32>
+ nvgpu.warpgroup.mma.store %res_m64n136k, %shmem_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>> to memref<64x136xf32>
+ nvgpu.warpgroup.mma.store %res_m64n144k, %shmem_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>> to memref<64x144xf32>
+ nvgpu.warpgroup.mma.store %res_m64n152k, %shmem_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>> to memref<64x152xf32>
+ nvgpu.warpgroup.mma.store %res_m64n160k, %shmem_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>> to memref<64x160xf32>
+ nvgpu.warpgroup.mma.store %res_m64n168k, %shmem_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>> to memref<64x168xf32>
+ nvgpu.warpgroup.mma.store %res_m64n176k, %shmem_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>> to memref<64x176xf32>
+ nvgpu.warpgroup.mma.store %res_m64n184k, %shmem_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>> to memref<64x184xf32>
+ nvgpu.warpgroup.mma.store %res_m64n192k, %shmem_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>> to memref<64x192xf32>
+ nvgpu.warpgroup.mma.store %res_m64n200k, %shmem_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>> to memref<64x200xf32>
+ nvgpu.warpgroup.mma.store %res_m64n208k, %shmem_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>> to memref<64x208xf32>
+ nvgpu.warpgroup.mma.store %res_m64n216k, %shmem_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>> to memref<64x216xf32>
+ nvgpu.warpgroup.mma.store %res_m64n224k, %shmem_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>> to memref<64x224xf32>
+ nvgpu.warpgroup.mma.store %res_m64n232k, %shmem_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>> to memref<64x232xf32>
+ nvgpu.warpgroup.mma.store %res_m64n240k, %shmem_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>> to memref<64x240xf32>
+ nvgpu.warpgroup.mma.store %res_m64n248k, %shmem_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>> to memref<64x248xf32>
+ nvgpu.warpgroup.mma.store %res_m64n256k, %shmem_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>> to memref<64x256xf32>
return
}
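The updated file keeps its existing RUN line, so the CHECK-COUNT patterns above can be exercised locally by pointing llvm-lit at mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir in a build tree with MLIR tests enabled.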