[Mlir-commits] [mlir] 85b2327 - [mlir][nvvm] Fix the PTX lowering of wgmma.mma_async (#76150)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 22 05:46:37 PST 2023


Author: Adam Paszke
Date: 2023-12-22T14:46:34+01:00
New Revision: 85b23271928c48f87cd950b55a434fc11a212306

URL: https://github.com/llvm/llvm-project/commit/85b23271928c48f87cd950b55a434fc11a212306
DIFF: https://github.com/llvm/llvm-project/commit/85b23271928c48f87cd950b55a434fc11a212306.diff

LOG: [mlir][nvvm] Fix the PTX lowering of wgmma.mma_async (#76150)

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 4f5d71e10f68c1..a4de89d928e1be 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1003,7 +1003,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
         {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
          mlir::NVVM::PTXRegisterMod::Read});
     asmValues.push_back(
-        {makeConstantI32(rewriter, static_cast<int>(getLayoutB())),
+        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
          mlir::NVVM::PTXRegisterMod::Read});
   }
 }

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 43de50f3dc8de0..74186138c3a985 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -297,7 +297,7 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
   // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
   // CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32
   // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[A5:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[A5:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
   // CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
   // CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  


        


More information about the Mlir-commits mailing list