[PATCH] D74655: [MLIR] change NVVM.mma.sync to the most useful variant.

Tim Shen via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 18 18:03:22 PST 2020


This revision was automatically updated to reflect the committed changes.
Closed by commit rGb762bbd4c868: [MLIR] change NVVM.mma.sync to the most useful variant. (authored by timshen).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D74655/new/

https://reviews.llvm.org/D74655

Files:
  mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
  mlir/test/Dialect/LLVMIR/invalid.mlir
  mlir/test/Dialect/LLVMIR/nvvm.mlir
  mlir/test/Target/nvvmir.mlir


Index: mlir/test/Target/nvvmir.mlir
===================================================================
--- mlir/test/Target/nvvmir.mlir
+++ mlir/test/Target/nvvmir.mlir
@@ -68,8 +68,8 @@
                     %b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
                     %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                     %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
-  // CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.row.f32.f32
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  // CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
 }
 
Index: mlir/test/Dialect/LLVMIR/nvvm.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -64,7 +64,7 @@
                %b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
                %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
-  // CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  // CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
 }
Index: mlir/test/Dialect/LLVMIR/invalid.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/invalid.mlir
+++ mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -295,7 +295,7 @@
                          %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                          %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
   // expected-error@+1 {{expected operands to be 4 <halfx2>s followed by either 4 <halfx2>s or 8 floats}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm.half, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm.half, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
 }
 
@@ -307,7 +307,7 @@
                          %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                          %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
   // expected-error@+1 {{expected result type to be a struct of either 4 <halfx2>s or 8 floats}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, half }">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, half }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, half }">
 }
 
@@ -331,7 +331,7 @@
                          %c0 : !llvm<"<2 x half>">, %c1 : !llvm<"<2 x half>">,
                          %c2 : !llvm<"<2 x half>">, %c3 : !llvm<"<2 x half>">) {
   // expected-error@+1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">) -> !llvm<"{ float, float, float, float, float, float, float, float }">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">) -> !llvm<"{ float, float, float, float, float, float, float, float }">
   llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
 }
 
@@ -343,7 +343,7 @@
                          %c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
                          %c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
   // expected-error@+1 {{unimplemented mma.sync variant}}
-  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
+  %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
   llvm.return %0 : !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
 }
 
Index: mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -131,7 +131,7 @@
                                              f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
                                              f32Ty, f32Ty, f32Ty} &&
       op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
-      blayout.getValue() == "row") {
+      blayout.getValue() == "col") {
     return success();
   }
   return op.emitOpError("unimplemented mma.sync variant");
Index: mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -125,7 +125,7 @@
   Arguments<(ins Variadic<LLVM_Type>:$args)> {
   string llvmBuilder = [{
     $res = createIntrinsicCall(
-        builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args);
+        builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
   }];
   let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
   let verifier = [{ return ::verify(*this); }];


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D74655.245310.patch
Type: text/x-patch
Size: 9351 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200219/43adaaa1/attachment.bin>


More information about the llvm-commits mailing list