[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs (PR #65441)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 29 09:04:38 PDT 2023
================
@@ -1299,6 +1301,82 @@ struct NVGPUWarpgroupMmaOpLowering
}
};
+struct NVGPUWarpgroupMmaStoreOpLowering
+ : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+ using ConvertOpToLLVMPattern<
+ nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+ void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ int offset) const {
+ Location loc = op->getLoc();
+ Type i32 = rewriter.getI32Type();
+
+ auto makeConst = [&](int32_t index) -> Value {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, i32, rewriter.getI32IntegerAttr(index));
+ };
+ Value c4 = makeConst(4);
+ Value c32 = makeConst(kWarpSize);
+ Value c8 = makeConst(8);
+ Value c2 = makeConst(2);
+ Value c1 = makeConst(1);
+ Value c16 = makeConst(16);
+
+ auto makeMul = [&](Value lhs, Value rhs) -> Value {
+ return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+ };
+ auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+ return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+ };
+
+ Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+ Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+ Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+ Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+ Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+ auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+ TypedValue<::mlir::MemRefType> memref) {
+ Type it = rewriter.getIndexType();
+ Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+ Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+ Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+ Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+ Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+ rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+ rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+ };
+
+ Value tj = makeMul(lane4modId, c2);
----------------
qcolombet wrote:
I'm a little bit lost with the math here.
Maybe it'll be easier to digest (well, when it's not Friday ;)) once you add comments for the whole function, as already mentioned in a previous comment; but if those comments don't cover this part, a brief note about what we're building here would be nice.
https://github.com/llvm/llvm-project/pull/65441
More information about the Mlir-commits mailing list