[Mlir-commits] [mlir] [mlir][nvgpu] Fix packing accumulator matrix (PR #69316)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 17 03:34:52 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
PR #<!-- -->68728 significantly simplified the accumulator matrix type, making it easier to work with the nvgpu dialect without worrying about the number of required structs, as this information is abstracted away in the nvgpu-to-nvvm transformation.
However, we forgot to pack the inner structs back into the outer struct after initializing them, so the accumulator matrix still held undefined values instead of zeros. This PR fixes that.
---
Full diff: https://github.com/llvm/llvm-project/pull/69316.diff
1 File Affected:
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+18-11)
``````````diff
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 2d43230938526b9..91b6a25c6dfc03b 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1576,27 +1576,34 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
- LLVM::LLVMStructType structType =
+ LLVM::LLVMStructType packStructType =
getTypeConverter()
->convertType(op.getMatrixC().getType())
.cast<LLVM::LLVMStructType>();
- Type elemType = structType.getBody()
+ Type elemType = packStructType.getBody()
.front()
.cast<LLVM::LLVMStructType>()
.getBody()
.front();
Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
- Value structValue = b.create<LLVM::UndefOp>(structType);
- for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
- auto innerStructType = s.cast<LLVM::LLVMStructType>();
- int ii = idx;
- Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
- for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
- innerStructValue = b.create<LLVM::InsertValueOp>(
- innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
+ Value packStruct = b.create<LLVM::UndefOp>(packStructType);
+ SmallVector<Value> innerStructs;
+ // Unpack the structs and set all values to zero
+ for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
+ auto structType = s.cast<LLVM::LLVMStructType>();
+ Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
+ for (unsigned i = 0; i < structType.getBody().size(); ++i) {
+ structValue = b.create<LLVM::InsertValueOp>(
+ structType, structValue, zero, ArrayRef<int64_t>({i}));
}
+ innerStructs.push_back(structValue);
}
- rewriter.replaceOp(op, structValue);
+ // Pack the inner structs into a single struct
+ for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
+ packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
+ packStruct, matrix, idx);
+ }
+ rewriter.replaceOp(op, packStruct);
return success();
}
};
``````````
</details>
https://github.com/llvm/llvm-project/pull/69316
More information about the Mlir-commits
mailing list