[Mlir-commits] [mlir] 948862b - [mlir][nvvm] Fix the verifier of `wgmma.mma_async` wrt transposed layouts (#97538)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 3 23:43:56 PDT 2024


Author: bangyu shen
Date: 2024-07-04T08:43:53+02:00
New Revision: 948862b24d209ddcf5a93845e1ce327d108761ce

URL: https://github.com/llvm/llvm-project/commit/948862b24d209ddcf5a93845e1ce327d108761ce
DIFF: https://github.com/llvm/llvm-project/commit/948862b24d209ddcf5a93845e1ce327d108761ce.diff

LOG: [mlir][nvvm] Fix the verifier of `wgmma.mma_async` wrt transposed layouts  (#97538)

WGMMA expects matrix A to be row-major and matrix B to be column-major;
the transposed layouts are therefore column-major A and row-major B.
Since only f16/bf16 support transposed layouts, the verifier must reject
column-major A and row-major B for every other data type.

Added: 
    

Modified: 
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 036a9a15af838c..4d1896551101ed 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -878,9 +878,12 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   }
 
   // Check transpose (only available for f16/bf16)
+  // Matrix A should be stored in row-major order and matrix B in column-major.
+  // Only f16/bf16 matrices can be stored in either column-major or row-major
+  // order, by setting the transpose values (imm-trans-a, imm-trans-b) in PTX.
   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
       (getLayoutA() == mlir::NVVM::MMALayout::col ||
-       getLayoutB() == mlir::NVVM::MMALayout::col)) {
+       getLayoutB() == mlir::NVVM::MMALayout::row)) {
     return emitOpError()
            << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
            << " and layout_b = " << stringifyMMALayout(getLayoutB())

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 21947c242461e7..375e2951a037cd 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -397,19 +397,19 @@ func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
       A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, 
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
       A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2, 
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
       A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   return %result3 : !mat16i32
 }
@@ -458,19 +458,19 @@ func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>],
       A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>],
       A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>],
       A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   return %result3 : !mat16i32
 }
@@ -500,13 +500,13 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
       #nvvm.shape<m = 64, n = 64, k = 8>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
        : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 64, k = 8>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
       : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }
@@ -533,13 +533,13 @@ func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
       #nvvm.shape<m = 64, n = 64, k = 32>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
        : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 64, k = 32>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
       : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }
@@ -565,13 +565,13 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
       #nvvm.shape<m = 64, n = 64, k = 32>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
        : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 64, k = 32>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
       : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }


        


More information about the Mlir-commits mailing list