[Mlir-commits] [mlir] dd080c7 - [mlir][nvvm] Add NVVMToLLVM Pass

Guray Ozen llvmlistbot at llvm.org
Tue Jul 11 03:14:29 PDT 2023


Author: Guray Ozen
Date: 2023-07-11T12:14:24+02:00
New Revision: dd080c7579e0c0e1d41dafb5dbea01343d6e0dc1

URL: https://github.com/llvm/llvm-project/commit/dd080c7579e0c0e1d41dafb5dbea01343d6e0dc1
DIFF: https://github.com/llvm/llvm-project/commit/dd080c7579e0c0e1d41dafb5dbea01343d6e0dc1.diff

LOG: [mlir][nvvm] Add NVVMToLLVM Pass

It introduces an NVVMToLLVM Pass and a `BasicPtxBuilderOpInterface` interface. The Pass performs pattern matching on all the NVVM Ops that implement the BasicPtxBuilderOpInterface interface to generate LLVM Inline Assembly Ops.

The BasicPtxBuilderOpInterface interface is utilized in the convert-nvvm-to-llvm pass, which lowers Ops that support this interface to inline assembly Ops. The interface provides several methods that are used for this lowering.

The `getPtx` method returns the PTX code. The `hasSideEffect` method determines whether the op has side effects on memory. The `hasIntrinsic` method indicates whether the operation has intrinsic support in LLVM; this is particularly useful for Ops that have intrinsic support only for some cases. The `getAsmValues` method returns the arguments to be passed to the PTX code: first the results, which are used as write (output) operands, followed by the operands and then the attributes.

Example:

If we have the following Op definition that returns PTX code through getPtx:
```tablegen
def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
  Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
  ...
  let extraClassDefinition = [{
    const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"; }
  }];
}
```

The NVVM Op will look like below:
```mlir
  %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i32
```

The `convert-nvvm-to-llvm` Pass lowers it to the following LLVM inline assembly Op, keeping the order of arguments the same. The read/write modifiers are set based on whether each value is a result or an operand, and the register constraints are chosen from the value types.
```mlir
  %0 = llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
```

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D154060

Added: 
    mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
    mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt
    mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
new file mode 100644
index 00000000000000..00c33dfd776548
--- /dev/null
+++ b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
@@ -0,0 +1,24 @@
+//===- NVVMToLLVM.h - Convert NVVM to LLVM dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVMPASS_H_
+#define MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+
+// Forward declarations for names used by the generated pass declarations
+// included below.
+// NOTE(review): LLVMTypeConverter and RewritePatternSet are not referenced in
+// this header; `Pass` is presumably needed by the generated create function.
+class Pass;
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+// Generated declarations for the `convert-nvvm-to-llvm` pass.
+#define GEN_PASS_DECL_CONVERTNVVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVMPASS_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ae2fc8464821c2..b15a60cfd005fb 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -40,6 +40,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 287386411f935f..fd648d838d29b3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -709,6 +709,21 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// NVVMToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
+  let summary = "Convert NVVM dialect to LLVM dialect";
+  let description = [{
+    This pass generates inline assembly for the NVVM ops which are not
+    implemented in LLVM core. Each op implementing
+    `BasicPtxBuilderOpInterface` that reports no LLVM intrinsic support is
+    replaced by an `llvm.inline_asm` op carrying the op's PTX string.
+  }];
+  // The lowering creates ops from both dialects (llvm.inline_asm,
+  // llvm.mlir.constant), so both must be loaded before the pass runs.
+  let dependentDialects = [
+    "LLVM::LLVMDialect",
+    "NVVM::NVVMDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPUToNVVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 8a2710464bc32a..2f65c1a3d6bcde 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -52,6 +52,8 @@ set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
 mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
 mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
+# Op-interface declarations/definitions (e.g. BasicPtxBuilderOpInterface)
+# generated from NVVMOps.td.
+mlir_tablegen(NVVMOpsInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(NVVMOpsInterface.cpp.inc -gen-op-interface-defs)
 mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
 mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
 add_public_tablegen_target(MLIRNVVMConversionsIncGen)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 2376bcad418534..1644d0029380ce 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -26,6 +26,8 @@
 namespace mlir {
 namespace NVVM {
 
+#include "mlir/Dialect/LLVMIR/NVVMOpsInterface.h.inc"
+
 /// NVVM memory space identifiers.
 enum NVVMMemorySpace {
   /// Global memory space identifier.

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7669ff1957afe1..01294225a64d29 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -81,6 +81,128 @@ class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
   let mnemonic = attrMnemonic;
 }
 
+//===----------------------------------------------------------------------===//
+// Basic PTX Builder Interface
+//===----------------------------------------------------------------------===//
+
+// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
+// Register modifiers used when building the inline-asm constraint string:
+// Read values get no prefix, Write values get '=', ReadWrite values get '+'.
+// The numeric values are part of the generated enum and are kept as-is.
+def Read : I32EnumAttrCase<"Read", 0, "read">;
+def Write : I32EnumAttrCase<"Write", 2, "write">;
+def ReadWrite : I32EnumAttrCase<"ReadWrite", 1, "readwrite">;
+
+def PTXRegisterMod : I32EnumAttr<"PTXRegisterMod",
+  "Register read/write modifier to build constraint string for PTX inline assembly",
+  [Read, Write, ReadWrite]> {
+  let cppNamespace = "::mlir::NVVM";
+}
+
+
+def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
+  let description = [{
+    Interface to generate inline assembly with PTX for basic operations.
+
+    The interface is used by the `convert-nvvm-to-llvm` pass, which lowers Ops
+    that support this interface to inline assembly Ops. The interface has
+    several methods that are used for this lowering.
+
+    `getPtx` returns the PTX code.
+
+    `hasSideEffect` returns whether the op has any side effect on memory.
+
+    `hasIntrinsic` returns whether the operation has intrinsic support in
+    LLVM. This is useful for Ops that have intrinsic support only for some
+    cases.
+
+    `getAsmValues` returns the arguments to pass to the PTX code. The
+    arguments start with the results, which are used as write operands,
+    followed by the operands and then the attributes.
+
+    Example:
+
+    Given the following Op definition, which returns its PTX code through
+    `getPtx`:
+
+    ```tablegen
+    def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
+                        [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+      Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
+      ...
+      let extraClassDefinition = [{
+        const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"; }
+      }\];
+    }
+    ```
+
+    The NVVM Op looks like this:
+    ```mlir
+      %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i32
+    ```
+
+    The `convert-nvvm-to-llvm` pass lowers it to the inline assembly Op below,
+    keeping the order of the arguments the same. The read/write modifiers
+    are set from whether each value is a result or an operand, and the
+    register constraints are chosen from the value types.
+    ```mlir
+      %0 = llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+    ```
+  }];
+  let methods = [
+    InterfaceMethod<
+        /*desc=*/[{
+          Returns whether the operation has intrinsic support in LLVM. The
+          `convert-nvvm-to-llvm` pass does not lower Ops that return true.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"hasIntrinsic",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return false;"
+      >,
+    InterfaceMethod<
+        /*desc=*/[{
+          Returns whether the operation has memory side effects. The value is
+          used as the `has_side_effects` flag of the generated inline asm.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"hasSideEffect",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return true;"
+      >,
+    InterfaceMethod<
+        /*desc=*/[{
+          Returns the PTX code. `%N` placeholders refer to the values returned
+          by `getAsmValues`, in order.
+        }],
+        /*retType=*/"const char*",
+        /*methodName=*/"getPtx"
+      >,
+    InterfaceMethod<
+        /*desc=*/[{
+          Creates an `llvm.mlir.constant` of type i32 holding `val` at the
+          location of this op and returns its result.
+        }],
+        /*retType=*/"::mlir::Value",
+        /*methodName=*/"makeConstantI32",
+        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "unsigned":$val),
+        /*methodBody=*/"",
+        /*defaultImpl=*/[{
+          mlir::Operation *op = $_op;
+          return rewriter.create<LLVM::ConstantOp>(
+              op->getLoc(), rewriter.getIntegerType(32), val);
+        }]
+      >,
+    InterfaceMethod<
+        /*desc=*/[{
+          Appends the arguments to pass to the PTX code to `asmValues`, each
+          paired with its register modifier. The arguments start with the
+          results (Write), followed by the operands (Read) and then the
+          integer attributes materialized via `makeConstantI32` (Read).
+        }],
+        /*retType=*/"void",
+        /*methodName=*/"getAsmValues",
+        /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
+          "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&":$asmValues),
+        /*methodBody=*/"",
+        /*defaultImpl=*/[{
+          mlir::Operation *op = $_op;
+          // Results are outputs written by the PTX code.
+          for (auto val : op->getResults())
+            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
+          // Operands are inputs read by the PTX code.
+          for (auto val : op->getOperands())
+            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+          // Integer attributes become i32 constants read by the PTX code.
+          for (auto attr : op->getAttrs()) {
+            if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
+              Value val = makeConstantI32(rewriter, intAttr.getInt());
+              asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+            }
+          }
+        }]
+      >
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
@@ -249,6 +371,58 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
+// mbarrier.arrive.expect_tx on a barrier in any address space.
+// PTX placeholders: %0 = $res, %1 = $addr, %2 = $txcount.
+def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
+  let extraClassDefinition = [{
+    const char* $cppClass::getPtx() {
+      return "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;";
+    }
+  }];
+  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+}
+
+// Shared-address-space variant of mbarrier.arrive.expect_tx.
+// PTX placeholders: %0 = $res, %1 = $addr, %2 = $txcount.
+def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared",
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
+  let extraClassDefinition = [{
+    const char* $cppClass::getPtx() {
+      return "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;";
+    }
+  }];
+  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+}
+
+// mbarrier.try_wait.parity on a barrier in any address space. The wait
+// result lands in predicate P1, which selp turns into 1/0 in %0.
+// PTX placeholders: %0 = $res, %1 = $addr, %2 = $token.
+def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {
+  let extraClassDefinition = [{
+    const char* $cppClass::getPtx() {
+      return "{\n\t.reg .pred P1; \n\t"
+             "mbarrier.try_wait.parity.b64 P1, [%1], %2; \n\t"
+             "selp.b32 %0, 1, 0, P1; \n\t}";
+    }
+  }];
+  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+}
+
+// Shared-address-space variant of mbarrier.try_wait.parity.
+// PTX placeholders: %0 = $res, %1 = $addr, %2 = $token.
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared",
+                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> {
+  let extraClassDefinition = [{
+    const char* $cppClass::getPtx() {
+      return "{\n\t.reg .pred P1; \n\t"
+             "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
+             "selp.b32 %0, 1, 0, P1; \n\t}";
+    }
+  }];
+  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+}
+
 def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 33efa59a6c0664..a1c58f53e59862 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -30,6 +30,7 @@ add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
 add_subdirectory(NVGPUToNVVM)
+add_subdirectory(NVVMToLLVM)
 add_subdirectory(OpenACCToSCF)
 add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)

diff  --git a/mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000000..2afff1a4e5f169
--- /dev/null
+++ b/mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Conversion library providing the `convert-nvvm-to-llvm` pass
+# (ConvertNVVMToLLVMPass in NVVMToLLVM.cpp). It depends on
+# MLIRConversionPassIncGen because the source includes the generated
+# Passes.h.inc.
+add_mlir_conversion_library(MLIRNVVMToLLVM
+  NVVMToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/NVVMToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRGPUDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRNVVMDialect
+  MLIRNVGPUDialect
+  MLIRPass
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
new file mode 100644
index 00000000000000..4564e325b0d456
--- /dev/null
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -0,0 +1,187 @@
+//===- NVVMToLLVM.cpp - NVVM to LLVM dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the translation of NVVM ops that are not supported in
+// LLVM core into LLVM inline assembly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <string>
+
+#define DEBUG_TYPE "nvvm-to-llvm"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace NVVM;
+
+#include "mlir/Dialect/LLVMIR/NVVMOpsInterface.cpp.inc"
+namespace {
+
+/// Accumulates the operands and the constraint string of a single
+/// `llvm.inline_asm` op built from a PTX string, then emits that op in place
+/// of the op being lowered.
+class PtxBuilder {
+  // Op being lowered; provides the location, context and results to replace.
+  Operation *op;
+  PatternRewriter &rewriter;
+  // PTX text whose `%N` placeholders refer to the inserted values in order.
+  const char *asmStr;
+  // Inline-asm operands: the Read and ReadWrite values, in insertion order.
+  SmallVector<Value> asmVals;
+  // Comma-separated constraints, one entry per inserted value (e.g. "=r,l,r").
+  std::string asmConstraints;
+  bool sideEffects;
+  // Type of the first Write/ReadWrite value; used as the inline-asm result
+  // type. Null while no output value has been inserted.
+  Type resultType;
+
+  // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
+  /// Returns the PTX constraint letter for `v`: 'n' (immediate) for values
+  /// defined by an `LLVM::ConstantOp`, otherwise a letter derived from the
+  /// value's type.
+  static char getRegisterType(Value v) {
+    if (v.getDefiningOp<LLVM::ConstantOp>())
+      return 'n';
+    Type type = v.getType();
+    if (type.isInteger(16))
+      return 'h';
+    if (type.isInteger(32))
+      return 'r';
+    if (type.isInteger(64))
+      return 'l';
+    if (type.isF32())
+      return 'f';
+    if (type.isF64())
+      return 'd';
+    if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
+      // Shared address space is addressed with 32-bit pointers.
+      return ptr.getAddressSpace() == NVVM::kSharedMemorySpace ? 'r' : 'l';
+    }
+    llvm_unreachable("Register type is not handled yet");
+  }
+
+public:
+  PtxBuilder(Operation *op, PatternRewriter &rewriter, const char *ptxAsm,
+             bool sideEffects = false)
+      : op(op), rewriter(rewriter), asmStr(ptxAsm), sideEffects(sideEffects) {}
+
+  /// Appends `v` with register modifier `itype`. Read values get no prefix,
+  /// ReadWrite values get '+', and Write values get '='. Read and ReadWrite
+  /// values become inline-asm operands. Write values only add a constraint,
+  /// since they are produced by the inline asm itself.
+  // NOTE(review): LLVM IR inline-asm constraint strings may not accept '+';
+  // ReadWrite may need a tied input ("=r" plus a matching-index input)
+  // instead. Confirm before relying on ReadWrite.
+  void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read) {
+    llvm::raw_string_ostream ss(asmConstraints);
+    if (itype == PTXRegisterMod::Read) {
+      asmVals.push_back(v);
+    } else if (itype == PTXRegisterMod::ReadWrite) {
+      asmVals.push_back(v);
+      ss << "+";
+      if (!resultType)
+        resultType = v.getType();
+    } else if (itype == PTXRegisterMod::Write) {
+      ss << "=";
+      if (!resultType)
+        resultType = v.getType();
+    }
+    ss << getRegisterType(v) << ",";
+    ss.flush();
+  }
+
+  /// Creates the `llvm.inline_asm` op from the values inserted so far. Its
+  /// result type is the type of the first output value, or void if there is
+  /// none. Only a single output is supported.
+  LLVM::InlineAsmOp build() {
+    MLIRContext *ctx = op->getContext();
+    auto asmDialectAttr =
+        LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT);
+    Type asmResultType = resultType ? resultType : LLVM::LLVMVoidType::get(ctx);
+
+    // Drop the trailing comma left by insertValue. The string is empty when
+    // no value was inserted, so it must not be indexed unconditionally.
+    if (!asmConstraints.empty() && asmConstraints.back() == ',')
+      asmConstraints.pop_back();
+
+    return rewriter.create<LLVM::InlineAsmOp>(
+        op->getLoc(), asmResultType,
+        /*operands=*/asmVals,
+        /*asm_string=*/asmStr,
+        /*constraints=*/asmConstraints.data(),
+        /*has_side_effects=*/sideEffects,
+        /*is_align_stack=*/false,
+        /*asm_dialect=*/asmDialectAttr,
+        /*operand_attrs=*/ArrayAttr());
+  }
+
+  /// Builds the inline asm, then replaces `op` with it when the result counts
+  /// match, and erases `op` otherwise.
+  // NOTE(review): the erase path presumably requires `op`'s results to be
+  // unused. Confirm that callers guarantee this.
+  void buildAndReplaceOp() {
+    LLVM::InlineAsmOp inlineAsmOp = build();
+    LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
+    if (inlineAsmOp->getNumResults() == op->getNumResults())
+      rewriter.replaceOp(op, inlineAsmOp);
+    else
+      rewriter.eraseOp(op);
+  }
+};
+
+/// Rewrites an op implementing `BasicPtxBuilderInterface` into an
+/// `llvm.inline_asm` op built from the op's PTX string and the values
+/// returned by its `getAsmValues`. Ops reporting LLVM intrinsic support are
+/// not matched.
+struct PtxLowering
+    : public OpInterfaceRewritePattern<NVVM::BasicPtxBuilderInterface> {
+  using OpInterfaceRewritePattern<
+      NVVM::BasicPtxBuilderInterface>::OpInterfaceRewritePattern;
+
+  PtxLowering(MLIRContext *context, PatternBenefit benefit = 2)
+      : OpInterfaceRewritePattern(context, benefit) {}
+
+  LogicalResult matchAndRewrite(NVVM::BasicPtxBuilderInterface op,
+                                PatternRewriter &rewriter) const override {
+    // Ops with intrinsic support are not lowered by this pattern.
+    if (op.hasIntrinsic()) {
+      LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
+      return failure();
+    }
+
+    PtxBuilder builder(op, rewriter, op.getPtx(), op.hasSideEffect());
+
+    // Collect (value, modifier) pairs; the default implementation may create
+    // constant ops for integer attributes.
+    SmallVector<std::pair<Value, PTXRegisterMod>> valuesWithMods;
+    op.getAsmValues(rewriter, valuesWithMods);
+
+    for (const auto &entry : valuesWithMods) {
+      Value asmValue = entry.first;
+      PTXRegisterMod modifier = entry.second;
+      LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << modifier);
+      builder.insertValue(asmValue, modifier);
+    }
+
+    builder.buildAndReplaceOp();
+    return success();
+  }
+};
+
+/// Runs `PtxLowering` as a partial dialect conversion in which the LLVM
+/// dialect is legal.
+struct ConvertNVVMToLLVMPass
+    : public impl::ConvertNVVMToLLVMPassBase<ConvertNVVMToLLVMPass> {
+  using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The lowering creates LLVM ops (inline asm and constants).
+    registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    RewritePatternSet patterns(ctx);
+    patterns.add<PtxLowering>(ctx);
+
+    ConversionTarget target(*ctx);
+    target.addLegalDialect<LLVM::LLVMDialect>();
+
+    LogicalResult status =
+        applyPartialConversion(getOperation(), target, std::move(patterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
+} // namespace

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
new file mode 100644
index 00000000000000..3d863efd44e769
--- /dev/null
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s
+
+// Note: FileCheck only recognizes directives written as `CHECK:` /
+// `CHECK-LABEL:` with no space before the colon.
+
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i32{
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=r,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i32
+  %res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i32
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32)-> i32 {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+  %res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i32
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @init_mbarrier_try_wait.parity.shared
+llvm.func @init_mbarrier_try_wait.parity.shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i32
+  %res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @init_mbarrier_try_wait.parity
+llvm.func @init_mbarrier_try_wait.parity(%barrier : !llvm.ptr, %token : i32) -> i32{
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+  %res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32
+  llvm.return %res : i32
+}


        


More information about the Mlir-commits mailing list