[Mlir-commits] [mlir] c4ba84d - [mlir][nvgpu] Fix packing accumlator matrix (#69316)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 17 03:46:14 PDT 2023
Author: Guray Ozen
Date: 2023-10-17T12:46:10+02:00
New Revision: c4ba84d6555148fb7469fd44412a49d9d66eb4cf
URL: https://github.com/llvm/llvm-project/commit/c4ba84d6555148fb7469fd44412a49d9d66eb4cf
DIFF: https://github.com/llvm/llvm-project/commit/c4ba84d6555148fb7469fd44412a49d9d66eb4cf.diff
LOG: [mlir][nvgpu] Fix packing accumlator matrix (#69316)
PR #68728 significantly simplified the accumulator matrix type, making
it easier to work with the nvgpu dialect without worrying about the
number of required structs, as this information is abstracted away in
the nvgpu-to-nvvm transformation.
However, we forgot to pack the zero-initialized inner structs back into
the outer struct after initialization, causing the accumulator matrix to
hold undefined values, which is wrong. This PR addresses that.
Added:
Modified:
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 00baf7b3c741565..029659a2f855416 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1578,27 +1578,34 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
- LLVM::LLVMStructType structType =
+ LLVM::LLVMStructType packStructType =
getTypeConverter()
->convertType(op.getMatrixC().getType())
.cast<LLVM::LLVMStructType>();
- Type elemType = structType.getBody()
+ Type elemType = packStructType.getBody()
.front()
.cast<LLVM::LLVMStructType>()
.getBody()
.front();
Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
- Value structValue = b.create<LLVM::UndefOp>(structType);
- for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
- auto innerStructType = s.cast<LLVM::LLVMStructType>();
- int ii = idx;
- Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
- for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
- innerStructValue = b.create<LLVM::InsertValueOp>(
- innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
+ Value packStruct = b.create<LLVM::UndefOp>(packStructType);
+ SmallVector<Value> innerStructs;
+ // Unpack the structs and set all values to zero
+ for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
+ auto structType = s.cast<LLVM::LLVMStructType>();
+ Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
+ for (unsigned i = 0; i < structType.getBody().size(); ++i) {
+ structValue = b.create<LLVM::InsertValueOp>(
+ structType, structValue, zero, ArrayRef<int64_t>({i}));
}
+ innerStructs.push_back(structValue);
}
- rewriter.replaceOp(op, structValue);
+ // Pack the inner structs into a single struct
+ for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
+ packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
+ packStruct, matrix, idx);
+ }
+ rewriter.replaceOp(op, packStruct);
return success();
}
};
More information about the Mlir-commits
mailing list