[Mlir-commits] [mlir] 948862b - [mlir][nvvm] Fix the verifier of `wgmma.mma_async` wrt transposed layouts (#97538)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 3 23:43:56 PDT 2024
Author: bangyu shen
Date: 2024-07-04T08:43:53+02:00
New Revision: 948862b24d209ddcf5a93845e1ce327d108761ce
URL: https://github.com/llvm/llvm-project/commit/948862b24d209ddcf5a93845e1ce327d108761ce
DIFF: https://github.com/llvm/llvm-project/commit/948862b24d209ddcf5a93845e1ce327d108761ce.diff
LOG: [mlir][nvvm] Fix the verifier of `wgmma.mma_async` wrt transposed layouts (#97538)
WGMMA expects the layouts of A/B to be row/col; the transposed version
is col/row. When checking that other data types cannot use a transposed
layout, the verifier should reject col-major for A and row-major for B.
Added:
Modified:
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 036a9a15af838c..4d1896551101ed 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -878,9 +878,12 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
}
// Check transpose (only available for f16/bf16)
+ // Matrices A should be stored in row-major and B in column-major.
+ // Only f16/bf16 matrices can be stored in either column-major or row-major
+ // by setting the tranpose value(imm-trans-a,imm-trans-b) in PTX code.
if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
(getLayoutA() == mlir::NVVM::MMALayout::col ||
- getLayoutB() == mlir::NVVM::MMALayout::col)) {
+ getLayoutB() == mlir::NVVM::MMALayout::row)) {
return emitOpError()
<< "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
<< " and layout_b = " << stringifyMMALayout(getLayoutB())
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 21947c242461e7..375e2951a037cd 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -397,19 +397,19 @@ func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
#nvvm.shape<m = 64, n = 8, k = 32>,
D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+ B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
: !mat16i32 -> !mat16i32
%result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 8, k = 32>,
D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+ B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
: !mat16i32 -> !mat16i32
%result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
#nvvm.shape<m = 64, n = 8, k = 32>,
D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+ B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
: !mat16i32 -> !mat16i32
return %result3 : !mat16i32
}
@@ -458,19 +458,19 @@ func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
#nvvm.shape<m = 64, n = 8, k = 32>,
D [<s32>, #nvvm.wgmma_scale_out<one>],
A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+ B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
: !mat16i32 -> !mat16i32
%result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 8, k = 32>,
D [<s32>, #nvvm.wgmma_scale_out<one>],
A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+ B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
: !mat16i32 -> !mat16i32
%result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
#nvvm.shape<m = 64, n = 8, k = 32>,
D [<s32>, #nvvm.wgmma_scale_out<one>],
A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+ B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
: !mat16i32 -> !mat16i32
return %result3 : !mat16i32
}
@@ -500,13 +500,13 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
#nvvm.shape<m = 64, n = 64, k = 8>,
D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
: !mat32f32 -> !mat32f32
%result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 64, k = 8>,
D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
: !mat32f32 -> !mat32f32
return %result2 : !mat32f32
}
@@ -533,13 +533,13 @@ func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
#nvvm.shape<m = 64, n = 64, k = 32>,
D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
: !mat32f32 -> !mat32f32
%result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 64, k = 32>,
D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
: !mat32f32 -> !mat32f32
return %result2 : !mat32f32
}
@@ -565,13 +565,13 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
#nvvm.shape<m = 64, n = 64, k = 32>,
D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
: !mat32f32 -> !mat32f32
%result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
#nvvm.shape<m = 64, n = 64, k = 32>,
D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
: !mat32f32 -> !mat32f32
return %result2 : !mat32f32
}
More information about the Mlir-commits
mailing list