[Mlir-commits] [mlir] 4c56494 - [mlir][nvgpu] Add NVGPU dialect (architectural specific gpu dialect)

Thomas Raoux llvmlistbot at llvm.org
Thu Apr 14 10:03:33 PDT 2022


Author: Thomas Raoux
Date: 2022-04-14T16:33:46Z
New Revision: 4c564940a14f55d2315d2676b10fea0660ea814a

URL: https://github.com/llvm/llvm-project/commit/4c564940a14f55d2315d2676b10fea0660ea814a
DIFF: https://github.com/llvm/llvm-project/commit/4c564940a14f55d2315d2676b10fea0660ea814a.diff

LOG: [mlir][nvgpu] Add NVGPU dialect (architectural specific gpu dialect)

This introduces a new dialect for vendor-specific PTX operations. This
also adds the first operation, ldmatrix, as an example. More operations
will be added in follow-up patches.
This new dialect is meant to be a bridge between the GPU and Vector
dialects and the NVVM dialect.

This is based on the RFC proposed here:
https://discourse.llvm.org/t/rfc-add-nv-gpu-dialect-hw-specific-extension-of-gpu-dialect-for-nvidia-gpus/61466/8

Differential Revision: https://reviews.llvm.org/D123266

Added: 
    mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
    mlir/include/mlir/Dialect/NVGPU/NVGPU.td
    mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h
    mlir/lib/Dialect/NVGPU/CMakeLists.txt
    mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
    mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
    mlir/test/Dialect/NVGPU/roundtrip.mlir

Modified: 
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/CMakeLists.txt
    mlir/test/mlir-opt/commandline.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 2db29357000ce..9d51442f1980f 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -16,6 +16,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
 add_subdirectory(MLProgram)
+add_subdirectory(NVGPU)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)

diff  --git a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..9901492fa2b54
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
@@ -0,0 +1,4 @@
+add_mlir_dialect(NVGPU nvgpu)  # Presumably generates the NVGPU*.h.inc/.cpp.inc headers from NVGPU.td.
+add_mlir_doc(NVGPU -gen-dialect-doc NVGPU Dialects/)  # Dialect-level docs (description + ops).
+
+set(LLVM_TARGET_DEFINITIONS NVGPU.td)  # NOTE(review): last line of this file; nothing below consumes it, so it looks unused — confirm or drop.

diff  --git a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
new file mode 100644
index 0000000000000..9ed34ace009b6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
@@ -0,0 +1,72 @@
+//===-- NVGPU.td - NVGPU dialect operation definitions -*- tablegen -*----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the NVGPU dialect.
+//
+// The NVGPU dialect provides a bridge between the target-agnostic GPU and
+// Vector dialects and the lower-level NVVM dialect, allowing PTX-specific
+// operations to use high-level MLIR concepts such as memref and 2-D vector.
+//
+// Op semantics are based on the vendor-specific PTX definitions:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVGPU
+#define NVGPU
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def NVGPU_Dialect : Dialect {
+  let name = "nvgpu";
+  let cppNamespace = "::mlir::nvgpu";
+  let description = [{
+    The `NVGPU` dialect provides a bridge between the target-agnostic GPU and
+    Vector dialects and the lower-level, LLVM IR based NVVM dialect. It allows
+    representing PTX-specific operations while using high-level MLIR concepts
+    such as memref and 2-D vector.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// NVGPU Op definitions
+//===----------------------------------------------------------------------===//
+
+class NVGPU_Op<string mnemonic, list<Trait> traits = []>
+    : Op<NVGPU_Dialect, mnemonic, traits>;  // Base class for ops in NVGPU_Dialect.
+
+def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix",
+                                [MemoryEffects<[MemRead]>]> {
+  let description = [{
+  The `nvgpu.ldmatrix` op represents loading a matrix fragment from
+  memory. The source memref and result types must be compatible with lowering
+  to the `nvvm.ldmatrix` instruction. This op is meant to represent
+  the distributed version of a `vector.transfer_read` as an intermediate
+  step when lowering `vector.transfer_read` to `nvvm.ldmatrix`.
+
+  This operation is meant to follow the semantics described here:
+  https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
+
+  Example:
+  ```mlir
+  %0 = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false} :
+    memref<?x?xf16, 3> -> vector<4x2xf16>
+  ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
+                           Variadic<Index>:$indices, BoolAttr:$transpose,
+                           I32Attr:$numTiles);
+  let results = (outs AnyVector:$res);
+  let assemblyFormat = [{
+    $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
+  }];
+}
+
+#endif // NVGPU

diff  --git a/mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h
new file mode 100644
index 0000000000000..efa14433ccf24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h
@@ -0,0 +1,26 @@
+//===- NVGPUDialect.h - MLIR Dialect for NVGPU ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the NVGPU dialect and its operations in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_NVGPU_NVGPUDIALECT_H_
+#define MLIR_DIALECT_NVGPU_NVGPUDIALECT_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/NVGPU.h.inc"
+
+#endif // MLIR_DIALECT_NVGPU_NVGPUDIALECT_H_

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index d2cc62241f0e6..e43ccc173bdf5 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -36,6 +36,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -80,6 +81,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   math::MathDialect,
                   memref::MemRefDialect,
                   ml_program::MLProgramDialect,
+                  nvgpu::NVGPUDialect,
                   scf::SCFDialect,
                   omp::OpenMPDialect,
                   pdl::PDLDialect,

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index df5d9d22aae15..78b513f08e78b 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -16,6 +16,7 @@ add_subdirectory(LLVMIR)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
 add_subdirectory(MLProgram)
+add_subdirectory(NVGPU)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)

diff  --git a/mlir/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..d65c2dfd1fa49
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRNVGPU  # NVGPU dialect IR library.
+  NVGPUDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU
+
+  DEPENDS
+  MLIRNVGPUIncGen  # NVGPUDialect.cpp/.h include generated .inc files, so they must exist first.
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces  # NVGPUDialect.h includes SideEffectInterfaces.h.
+  )

diff  --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
new file mode 100644
index 0000000000000..6c4318f4d4967
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -0,0 +1,30 @@
+//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the NVGPU dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+#include "mlir/Dialect/NVGPU/NVGPUDialect.cpp.inc"
+
+void nvgpu::NVGPUDialect::initialize() {
+  addOperations< // op classes generated from NVGPU.td (GET_OP_LIST).
+#define GET_OP_LIST
+#include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"
+      >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"

diff  --git a/mlir/test/Dialect/NVGPU/roundtrip.mlir b/mlir/test/Dialect/NVGPU/roundtrip.mlir
new file mode 100644
index 0000000000000..8a52180676445
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/roundtrip.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @ldmatrix(
+func @ldmatrix(%src: memref<?x?xf16, 3>, %row: index, %col: index) {
+//      CHECK: nvgpu.ldmatrix %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK-SAME: {numTiles = 4 : i32, transpose = false} : memref<?x?xf16, 3> -> vector<4x2xf16>
+  %frag = nvgpu.ldmatrix %src[%row, %col] {numTiles = 4 : i32, transpose = false} :
+    memref<?x?xf16, 3> -> vector<4x2xf16>
+  return
+}

diff  --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 1f606fdcc2029..371d3bb0d68f6 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -20,6 +20,7 @@
 // CHECK-NEXT: math
 // CHECK-NEXT: memref
 // CHECK-NEXT: ml_program
+// CHECK-NEXT: nvgpu
 // CHECK-NEXT: nvvm
 // CHECK-NEXT: omp
 // CHECK-NEXT: pdl

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index ee15429333664..656a089082ab3 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1996,6 +1996,69 @@ cc_library(
     ],
 )
 
+##---------------------------------------------------------------------------##
+# NVGPU dialect.
+##---------------------------------------------------------------------------##
+
+td_library(
+    name = "NVGPUTdFiles",  # TableGen sources of the NVGPU dialect.
+    srcs = ["include/mlir/Dialect/NVGPU/NVGPU.td"],
+    includes = ["include"],
+    deps = [
+        ":SideEffectInterfacesTdFiles",  # NVGPU.td includes SideEffectInterfaces.td; its OpBase.td include presumably arrives transitively — confirm.
+    ],
+)
+
+gentbl_cc_library(
+    name = "NVGPUIncGen",  # Dialect/op .inc files and op docs generated from NVGPU.td.
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            [
+                "-gen-dialect-decls",
+                "-dialect=nvgpu",
+            ],
+            "include/mlir/Dialect/NVGPU/NVGPUDialect.h.inc",  # Included by NVGPUDialect.h.
+        ),
+        (
+            [
+                "-gen-dialect-defs",
+                "-dialect=nvgpu",
+            ],
+            "include/mlir/Dialect/NVGPU/NVGPUDialect.cpp.inc",  # Included by NVGPUDialect.cpp.
+        ),
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/NVGPU/NVGPU.h.inc",  # Included by NVGPUDialect.h under GET_OP_CLASSES.
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/NVGPU/NVGPU.cpp.inc",  # Included by NVGPUDialect.cpp (GET_OP_LIST and GET_OP_CLASSES).
+        ),
+        (
+            ["-gen-op-doc"],  # NOTE(review): CMake uses -gen-dialect-doc (add_mlir_doc); this emits op-only docs — confirm intended.
+            "g3doc/Dialects/NVGPU/NVGPU.md",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/NVGPU/NVGPU.td",
+    deps = [":NVGPUTdFiles"],
+)
+
+cc_library(
+    name = "NVGPU",  # NVGPU dialect IR library; same source as the CMake MLIRNVGPU target.
+    srcs = ["lib/Dialect/NVGPU/IR/NVGPUDialect.cpp"],
+    hdrs = ["include/mlir/Dialect/NVGPU/NVGPUDialect.h"],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":NVGPUIncGen",  # Provides the generated .inc files the srcs/hdrs include.
+        ":SideEffectInterfaces",  # NVGPUDialect.h includes SideEffectInterfaces.h.
+        "//llvm:Core",
+        "//llvm:Support",
+    ],
+)
+
 td_library(
     name = "FuncTdFiles",
     srcs = [
@@ -5985,6 +6048,7 @@ cc_library(
         ":MemRefToLLVM",
         ":MemRefToSPIRV",
         ":MemRefTransforms",
+        ":NVGPU",
         ":NVVMDialect",
         ":OpenACCDialect",
         ":OpenMPDialect",


        


More information about the Mlir-commits mailing list