[Mlir-commits] [mlir] [mlir][nvvm] Fix the PTX lowering of wgmma.mma_async (PR #76150)
Adam Paszke
llvmlistbot at llvm.org
Thu Dec 21 10:17:39 PST 2023
https://github.com/apaszke updated https://github.com/llvm/llvm-project/pull/76150
>From 3c222095bcbf9e80462ce67ca3157a269c4b86fa Mon Sep 17 00:00:00 2001
From: Adam Paszke <apaszke at google.com>
Date: Thu, 21 Dec 2023 13:05:49 +0000
Subject: [PATCH] [mlir][nvvm] Fix the PTX lowering of wgmma.mma_async
The default layouts of the A and B matrices are row-major and column-major,
respectively, so the transpose flags have opposite meanings for the two
operands. Invert the layout value emitted for B accordingly.
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 4f5d71e10f68c1..a4de89d928e1be 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1003,7 +1003,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
{makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
mlir::NVVM::PTXRegisterMod::Read});
asmValues.push_back(
- {makeConstantI32(rewriter, static_cast<int>(getLayoutB())),
+ {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
mlir::NVVM::PTXRegisterMod::Read});
}
}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 43de50f3dc8de0..74186138c3a985 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -297,7 +297,7 @@ func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
// CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[A5:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[A5:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
More information about the Mlir-commits
mailing list