[Mlir-commits] [mlir] [mlir][nvgpu] Fix wgmma store offset (PR #154581)
lonely eagle
llvmlistbot at llvm.org
Thu Aug 21 08:47:19 PDT 2025
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/154581
>From d62a27ea20e4e4242e9c53d10df3aead1d7e4428 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Wed, 20 Aug 2025 17:28:24 +0000
Subject: [PATCH 1/2] Fix wgmma store offset
---
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index c6c5ab356f256..fffcb2aedafee 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1623,11 +1623,10 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value matriDValue = adaptor.getMatrixD();
auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
- auto structType = cast<LLVM::LLVMStructType>(matrixD);
Value innerStructValue =
LLVM::ExtractValueOp::create(b, matriDValue, idx);
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
- offset += structType.getBody().size();
+ offset += kWgmmaSizeM;
}
rewriter.eraseOp(op);
return success();
>From 04ef66b67847bfa8b158217061bf64c04d61c46f Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Thu, 21 Aug 2025 15:47:05 +0000
Subject: [PATCH 2/2] Add test for wgmma store row offset
---
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 103 ++++++++++++++++++
1 file changed, 103 insertions(+)
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index c4cf4f7337d81..cc820b87f5dc1 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1089,6 +1089,109 @@ func.func @warpgroup_mma_store(
return
}
+// CHECK-LABEL: func.func @warpgroup_mma_store_row_offset(
+func.func @warpgroup_mma_store_row_offset(  // Regression test: a 128-row accumulator converts to two inner structs (one per 64-row fragment); the second must be stored at row offset 64 (see S55/S56), not at offset 4 (the inner struct's element count).
+  %result : !nvgpu.warpgroup.accumulator<fragmented = vector<128x8xf32>>,
+  %matrixD: memref<128x8xf32,3>) {
+  nvgpu.warpgroup.mma.store %result, %matrixD :  // lowers to one store sequence per inner struct of the converted accumulator
+    !nvgpu.warpgroup.accumulator< fragmented = vector<128x8xf32>>
+    to memref<128x8xf32,3>
+  // CHECK-SAME: %[[ARG0:.*]]: !nvgpu.warpgroup.accumulator<fragmented = vector<128x8xf32>>,
+  // CHECK-SAME: %[[ARG1:.*]]: memref<128x8xf32, 3>) {
+  // CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x8xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32)>, struct<(f32, f32, f32, f32)>)>
+  // CHECK: %[[S1:.*]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32)>, struct<(f32, f32, f32, f32)>)>
+  // CHECK: %[[S2:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[S3:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[S4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[S5:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[S6:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[S7:.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: %[[S8:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+  // CHECK: %[[S9:.*]] = llvm.urem %[[S8]], %[[S7]] : i32
+  // CHECK: %[[S10:.*]] = llvm.udiv %[[S8]], %[[S7]] : i32
+  // CHECK: %[[S11:.*]] = llvm.udiv %[[S9]], %[[S4]] : i32
+  // CHECK: %[[S12:.*]] = llvm.urem %[[S9]], %[[S4]] : i32
+  // CHECK: %[[S13:.*]] = llvm.mul %[[S12]], %[[S3]] : i32
+  // CHECK: %[[S14:.*]] = llvm.mul %[[S10]], %[[S6]] : i32
+  // CHECK: %[[S15:.*]] = llvm.add %[[S11]], %[[S14]] : i32
+  // CHECK: %[[S16:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[S17:.*]] = llvm.mul %[[S16]], %[[S5]] : i32
+  // CHECK: %[[S18:.*]] = llvm.add %[[S15]], %[[S17]] : i32
+  // CHECK: %[[S19:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[S20:.*]] = llvm.mul %[[S19]], %[[S5]] : i32
+  // CHECK: %[[S21:.*]] = llvm.add %[[S13]], %[[S20]] : i32
+  // CHECK: %[[S22:.*]] = arith.index_cast %[[S18]] : i32 to index
+  // CHECK: %[[S23:.*]] = arith.index_cast %[[S21]] : i32 to index
+  // CHECK: %[[S24:.*]] = llvm.add %[[S21]], %[[S2]] : i32
+  // CHECK: %[[S25:.*]] = arith.index_cast %[[S24]] : i32 to index
+  // CHECK: %[[S26:.*]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: %[[S27:.*]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: memref.store %[[S26]], %[[ARG1]]{{\[}}%[[S22]], %[[S23]]] : memref<128x8xf32, 3>
+  // CHECK: memref.store %[[S27]], %[[ARG1]]{{\[}}%[[S22]], %[[S25]]] : memref<128x8xf32, 3>
+  // CHECK: %[[S28:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[S29:.*]] = llvm.mul %[[S28]], %[[S5]] : i32
+  // CHECK: %[[S30:.*]] = llvm.add %[[S15]], %[[S29]] : i32
+  // CHECK: %[[S31:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[S32:.*]] = llvm.mul %[[S31]], %[[S5]] : i32
+  // CHECK: %[[S33:.*]] = llvm.add %[[S13]], %[[S32]] : i32
+  // CHECK: %[[S34:.*]] = arith.index_cast %[[S30]] : i32 to index
+  // CHECK: %[[S35:.*]] = arith.index_cast %[[S33]] : i32 to index
+  // CHECK: %[[S36:.*]] = llvm.add %[[S33]], %[[S2]] : i32
+  // CHECK: %[[S37:.*]] = arith.index_cast %[[S36]] : i32 to index
+  // CHECK: %[[S38:.*]] = llvm.extractvalue %[[S1]][2] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: %[[S39:.*]] = llvm.extractvalue %[[S1]][3] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: memref.store %[[S38]], %[[ARG1]]{{\[}}%[[S34]], %[[S35]]] : memref<128x8xf32, 3>
+  // CHECK: memref.store %[[S39]], %[[ARG1]]{{\[}}%[[S34]], %[[S37]]] : memref<128x8xf32, 3>
+  // CHECK: %[[S40:.*]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(struct<(f32, f32, f32, f32)>, struct<(f32, f32, f32, f32)>)>
+  // CHECK: %[[S41:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[S42:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[S43:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[S44:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[S45:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[S46:.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: %[[S47:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+  // CHECK: %[[S48:.*]] = llvm.urem %[[S47]], %[[S46]] : i32
+  // CHECK: %[[S49:.*]] = llvm.udiv %[[S47]], %[[S46]] : i32
+  // CHECK: %[[S50:.*]] = llvm.udiv %[[S48]], %[[S43]] : i32
+  // CHECK: %[[S51:.*]] = llvm.urem %[[S48]], %[[S43]] : i32
+  // CHECK: %[[S52:.*]] = llvm.mul %[[S51]], %[[S42]] : i32
+  // CHECK: %[[S53:.*]] = llvm.mul %[[S49]], %[[S45]] : i32
+  // CHECK: %[[S54:.*]] = llvm.add %[[S50]], %[[S53]] : i32
+  // CHECK: %[[S55:.*]] = llvm.mlir.constant(64 : i32) : i32
+  // CHECK: %[[S56:.*]] = llvm.add %[[S54]], %[[S55]] : i32
+  // CHECK: %[[S57:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[S58:.*]] = llvm.mul %[[S57]], %[[S44]] : i32
+  // CHECK: %[[S59:.*]] = llvm.add %[[S56]], %[[S58]] : i32
+  // CHECK: %[[S60:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[S61:.*]] = llvm.mul %[[S60]], %[[S44]] : i32
+  // CHECK: %[[S62:.*]] = llvm.add %[[S52]], %[[S61]] : i32
+  // CHECK: %[[S63:.*]] = arith.index_cast %[[S59]] : i32 to index
+  // CHECK: %[[S64:.*]] = arith.index_cast %[[S62]] : i32 to index
+  // CHECK: %[[S65:.*]] = llvm.add %[[S62]], %[[S41]] : i32
+  // CHECK: %[[S66:.*]] = arith.index_cast %[[S65]] : i32 to index
+  // CHECK: %[[S67:.*]] = llvm.extractvalue %[[S40]][0] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: %[[S68:.*]] = llvm.extractvalue %[[S40]][1] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: memref.store %[[S67]], %[[ARG1]]{{\[}}%[[S63]], %[[S64]]] : memref<128x8xf32, 3>
+  // CHECK: memref.store %[[S68]], %[[ARG1]]{{\[}}%[[S63]], %[[S66]]] : memref<128x8xf32, 3>
+  // CHECK: %[[S69:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[S70:.*]] = llvm.mul %[[S69]], %[[S44]] : i32
+  // CHECK: %[[S71:.*]] = llvm.add %[[S56]], %[[S70]] : i32
+  // CHECK: %[[S72:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[S73:.*]] = llvm.mul %[[S72]], %[[S44]] : i32
+  // CHECK: %[[S74:.*]] = llvm.add %[[S52]], %[[S73]] : i32
+  // CHECK: %[[S75:.*]] = arith.index_cast %[[S71]] : i32 to index
+  // CHECK: %[[S76:.*]] = arith.index_cast %[[S74]] : i32 to index
+  // CHECK: %[[S77:.*]] = llvm.add %[[S74]], %[[S41]] : i32
+  // CHECK: %[[S78:.*]] = arith.index_cast %[[S77]] : i32 to index
+  // CHECK: %[[S79:.*]] = llvm.extractvalue %[[S40]][2] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: %[[S80:.*]] = llvm.extractvalue %[[S40]][3] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: memref.store %[[S79]], %[[ARG1]]{{\[}}%[[S75]], %[[S76]]] : memref<128x8xf32, 3>
+  // CHECK: memref.store %[[S80]], %[[ARG1]]{{\[}}%[[S75]], %[[S78]]] : memref<128x8xf32, 3>
+  // CHECK: return
+  // CHECK: }
+  return
+}
+
// CHECK-LABEL: @warpgroup_mma_store_multiple
func.func @warpgroup_mma_store_multiple(
%shmem_m64n8k : memref<64x8xf32>,
More information about the Mlir-commits mailing list