[Mlir-commits] [mlir] [mlir][nvgpu] Fix `transposeB` in `nvgpu.warpgroup.mma` (PR #79271)

Guray Ozen llvmlistbot at llvm.org
Wed Jan 24 03:34:27 PST 2024


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/79271

From 19865938501e42493ae8efbe11ef1ceba872c4b4 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 24 Jan 2024 11:44:28 +0100
Subject: [PATCH 1/2] [mlir][nvgpu] Fix `transposeB` in `nvgpu.warpgroup.mma`

PR #76150 fixed the meaning of `transposeB` in the NVVM dialect, which had initially been implemented with the opposite meaning.

This PR updates the lowering of `nvgpu.warpgroup.mma` to the NVVM dialect accordingly.

This will fix two integration tests:
gemm_f32_f16_f16_128x128x128.mlir
gemm_pred_f32_f16_f16_128x128x128.mlir
---
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 43d05b872a4fbc8..5080956a4589828 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1407,7 +1407,7 @@ struct NVGPUWarpgroupMmaOpLowering
       NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
       NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
       NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
-      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
+      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
 
       auto overflow = NVVM::MMAIntOverflowAttr::get(
           op->getContext(), NVVM::MMAIntOverflow::wrapped);

From 7de3c17872b0db77c12b0e5f1e46f911a213e8c0 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 24 Jan 2024 12:33:56 +0100
Subject: [PATCH 2/2] fix test

---
 mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0ac7331e1f69872..e97ac73c6f0374a 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -297,7 +297,7 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
   // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
   // CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32
   // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[A5:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[A5:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
   // CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
   // CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  