[Mlir-commits] [mlir] b762bbd - [MLIR] change NVVM.mma.sync to the most useful variant.

Tim Shen llvmlistbot at llvm.org
Tue Feb 18 17:58:15 PST 2020


Author: Tim Shen
Date: 2020-02-18T17:57:04-08:00
New Revision: b762bbd4c86806095a11dbe4d594059bd3fd5bc5

URL: https://github.com/llvm/llvm-project/commit/b762bbd4c86806095a11dbe4d594059bd3fd5bc5
DIFF: https://github.com/llvm/llvm-project/commit/b762bbd4c86806095a11dbe4d594059bd3fd5bc5.diff

LOG: [MLIR] change NVVM.mma.sync to the most useful variant.

Summary:
The .row.col variant turns out to be the popular one, contrary to my earlier
assumption that .row.row would be. Since .row.col is so prevalent (based on
my inspection of cuDNN's behavior), I'm removing the .row.row support here,
which makes the patch a little simpler.

Reviewers: ftynse

Subscribers: jholewinski, bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74655

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0b18ef75897f..c875d1cac7dd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -125,7 +125,7 @@ def NVVM_MmaOp :
   Arguments<(ins Variadic<LLVM_Type>:$args)> {
   string llvmBuilder = [{
     $res = createIntrinsicCall(
-        builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args);
+        builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
   }];
   let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
   let verifier = [{ return ::verify(*this); }];

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 210507a61fc3..dab441fc26ff 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -131,7 +131,7 @@ static LogicalResult verify(MmaOp op) {
                                              f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
                                              f32Ty, f32Ty, f32Ty} &&
       op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
-      blayout.getValue() == "row") {
+      blayout.getValue() == "col") {
     return success();
   }
   return op.emitOpError("unimplemented mma.sync variant");

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 8d44c8487436..eaee62ef8324 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -295,7 +295,7 @@ func @nvvm_invalid_mma_0(%a0 : !llvm.half, %a1 : !llvm<"<2 x half>">,
                          %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                          %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
   // expected-error at +1 {{expected operands to be 4 <halfx2>s followed by either 4 <halfx2>s or 8 floats}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm.half, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm.half, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
 }
 
@@ -307,7 +307,7 @@ func @nvvm_invalid_mma_1(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
                          %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                          %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
   // expected-error at +1 {{expected result type to be a struct of either 4 <halfx2>s or 8 floats}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, half }">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, half }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, half }">
 }
 
@@ -331,7 +331,7 @@ func @nvvm_invalid_mma_3(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
                          %c0 : !llvm<"<2 x half>">, %c1 : !llvm<"<2 x half>">,
                          %c2 : !llvm<"<2 x half>">, %c3 : !llvm<"<2 x half>">) {
   // expected-error at +1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">) -> !llvm<"{ float, float, float, float, float, float, float, float }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
 }
 
@@ -343,7 +343,7 @@ func @nvvm_invalid_mma_4(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
                          %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                          %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
   // expected-error at +1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
   llvm.return %0 : !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
 }
 

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 3858ace530d5..a55b35907db0 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -64,7 +64,7 @@ func @nvvm_mma(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
                %b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
                %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
-  // CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  // CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
 }

diff  --git a/mlir/test/Target/nvvmir.mlir b/mlir/test/Target/nvvmir.mlir
index 2e63ecd68bc3..7e8cfb6c0a38 100644
--- a/mlir/test/Target/nvvmir.mlir
+++ b/mlir/test/Target/nvvmir.mlir
@@ -68,8 +68,8 @@ llvm.func @nvvm_mma(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
                     %b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
                     %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                     %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
-  // CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.row.f32.f32
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  // CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
 }
 


        


More information about the Mlir-commits mailing list