[Mlir-commits] [mlir] [mlir][nvgpu] Fix packing accumulator matrix (PR #69316)

Guray Ozen llvmlistbot at llvm.org
Tue Oct 17 03:33:39 PDT 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/69316

PR #68728 significantly simplified the accumulator matrix type, making it easier to work with the nvgpu dialect without worrying about the number of required structs, as this information is abstracted away in the nvgpu-to-nvvm transformation.

However, we forgot to pack the zero-initialized inner structs back into the outer struct after initialization, so the accumulator matrix still held undefined values, which is wrong. This PR addresses that.

>From c31b40d2aae827e548c343dc782e0407e401eb91 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 17 Oct 2023 12:28:58 +0200
Subject: [PATCH] [mlir][nvgpu] Fix packing accumulator matrix

#68728 significantly simplified the accumulator matrix, but we forgot to pack the struct after initialization. This PR fixes that.
---
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 29 ++++++++++++-------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 2d43230938526b9..91b6a25c6dfc03b 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1576,27 +1576,34 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
   matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    LLVM::LLVMStructType structType =
+    LLVM::LLVMStructType packStructType =
         getTypeConverter()
             ->convertType(op.getMatrixC().getType())
             .cast<LLVM::LLVMStructType>();
-    Type elemType = structType.getBody()
+    Type elemType = packStructType.getBody()
                         .front()
                         .cast<LLVM::LLVMStructType>()
                         .getBody()
                         .front();
     Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
-    Value structValue = b.create<LLVM::UndefOp>(structType);
-    for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
-      auto innerStructType = s.cast<LLVM::LLVMStructType>();
-      int ii = idx;
-      Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
-      for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
-        innerStructValue = b.create<LLVM::InsertValueOp>(
-            innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
+    Value packStruct = b.create<LLVM::UndefOp>(packStructType);
+    SmallVector<Value> innerStructs;
+    // Unpack the structs and set all values to zero
+    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
+      auto structType = s.cast<LLVM::LLVMStructType>();
+      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
+      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
+        structValue = b.create<LLVM::InsertValueOp>(
+            structType, structValue, zero, ArrayRef<int64_t>({i}));
       }
+      innerStructs.push_back(structValue);
     }
-    rewriter.replaceOp(op, structValue);
+    // Pack the inner structs into a single struct
+    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
+      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
+                                                 packStruct, matrix, idx);
+    }
+    rewriter.replaceOp(op, packStruct);
     return success();
   }
 };



More information about the Mlir-commits mailing list