[Mlir-commits] [mlir] [mlir][nvvm] Move `BasicPtxBuilder` Interface to Its Own File (NFC) (PR #68095)

Guray Ozen llvmlistbot at llvm.org
Tue Oct 3 05:18:45 PDT 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/68095

The `BasicPtxBuilder` interface plays a crucial role in generating PTX assembly from NVVM Ops. Previously, it was located in `NVVMOps.td` and `NVVMToLLVM.cpp`. To improve code readability, this PR moves it into its own dedicated file. Additionally, it adds comprehensive documentation for the classes and the interface.

>From 1937f5746ce1cc0dc7cf79f7a84372c78a7d8237 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 3 Oct 2023 14:17:46 +0200
Subject: [PATCH] [mlir][nvvm] Move `BasicPtxBuilder` Interface to Its Own File
 (NFC)

The `BasicPtxBuilder` interface plays a crucial role in generating PTX assembly from NVVM Ops. Previously, it was located in `NVVMOps.td` and `NVVMToLLVM.cpp`. To improve code readability, this PR moves it into its own dedicated file. Additionally, it adds comprehensive documentation for the classes and the interface.
---
 mlir/include/mlir/Conversion/Passes.td        |   6 +-
 .../Dialect/LLVMIR/BasicPtxBuilderInterface.h |  84 ++++++++++
 .../LLVMIR/BasicPtxBuilderInterface.td        | 139 +++++++++++++++++
 .../mlir/Dialect/LLVMIR/CMakeLists.txt        |   8 +-
 .../include/mlir/Dialect/LLVMIR/NVVMDialect.h |   1 +
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 130 +---------------
 mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 145 +----------------
 mlir/lib/Dialect/LLVMIR/CMakeLists.txt        |   2 +
 .../LLVMIR/IR/BasicPtxBuilderInterface.cpp    | 146 ++++++++++++++++++
 9 files changed, 388 insertions(+), 273 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
 create mode 100644 mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
 create mode 100644 mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 38b05c792d405ad..afaeb16f4c2e27c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -792,10 +792,10 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
 //===----------------------------------------------------------------------===//
 
 def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
-  let summary = "Convert NVVM dialect to LLVM dialect";
+  let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
   let description = [{
-    This pass generates inline assembly for the NVVM ops which is not 
-    implemented in LLVM core.
+    This pass generates PTX instructions using inline assembly for NVVM
+    operations that implement `BasicPtxBuilderInterface`.
   }];
   let dependentDialects = [
     "NVVM::NVVMDialect",
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
new file mode 100644
index 000000000000000..677cb802c05628b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -0,0 +1,84 @@
+//===- BasicPtxBuilderInterface.h - PTX builder interface -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
+// automatically. It is used by NVVM to LLVM pass.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
+#define NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include <cstdint>
+#include <sys/types.h>
+#include <utility>
+
+namespace mlir {
+namespace NVVM {
+/// Register read/write modifier used to build the constraint string for PTX
+/// inline assembly. See
+/// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
+enum class PTXRegisterMod : uint32_t {
+  /// Read register with no modifier.
+  Read = 0,
+  /// Write register with the '=' modifier.
+  Write = 2,
+  /// Read-write register with the '+' modifier. PtxBuilder::insertValue
+  /// asserts on it for scalars; pass the same value as Write and Read instead.
+  ReadWrite = 1,
+};
+} // namespace NVVM
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h.inc"
+
+namespace mlir {
+
+namespace NVVM {
+
+/// Lowers an op implementing BasicPtxBuilderInterface to `llvm.inline_asm`:
+/// register its values with `insertValue`, then call `build`.
+class PtxBuilder {
+  // The interface op that is used to build the PTX.
+  BasicPtxBuilderInterface interfaceOp;
+  // Rewriter used to create new operations and to replace/erase the op.
+  PatternRewriter &rewriter;
+  // Operands passed to the inline assembly op.
+  SmallVector<Value> ptxOperands;
+  // Comma-separated register constraint string (e.g. "=r,l,r").
+  std::string registerConstraints;
+  // Set by insertValue for Write/ReadWrite values; not read by build().
+  bool hasResult = false;
+
+public:
+  /// Single constructor that only initializes members.
+  PtxBuilder(Operation *op, PatternRewriter &rewriter)
+      : interfaceOp(op), rewriter(rewriter) {}
+
+  /// Adds `v` as an operand; `itype` selects its constraint modifier.
+  void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
+
+  /// Builds the inline assembly Op and returns it. `insertValue` must be
+  /// called for every operand before building the PTX.
+  LLVM::InlineAsmOp build();
+
+  /// Shortcut to build the inline assembly Op and replace the original op
+  /// with it, or erase the original op if their result counts differ.
+  void buildAndReplaceOp();
+};
+
+} // namespace NVVM
+} // namespace mlir
+
+#endif // NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
new file mode 100644
index 000000000000000..6f27c8eb47175e6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
@@ -0,0 +1,139 @@
+//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops 
+// automatically. It is used by NVVM to LLVM pass. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BASICPTXBUILDER_OP_INTERFACE
+#define BASICPTXBUILDER_OP_INTERFACE
+
+include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Basic PTX Builder Interface
+//===----------------------------------------------------------------------===//
+
+def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
+  let description = [{
+    This interface is used to generate inline assembly with PTX for basic
+    operations. It is used by the `convert-nvvm-to-llvm` pass to lower
+    NVVM Ops that implement this interface to PTX (Parallel Thread Execution)
+    using inline assembly Ops. The interface methods provide the PTX string,
+    its operands, and the flags that drive this lowering.
+
+    Here's an example of an Op with the `BasicPtxBuilderOpInterface`:
+    ```tablegen
+      def NVVM_SpecialOp : NVVM_Op<"special.op",
+          [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+        Results<(outs LLVM_Type:$res)>,
+        Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
+        ...
+        let extraClassDefinition = [{
+          std::string $cppClass::getPtx() {
+            return std::string("special.op %0, %1, %2;");
+          }
+        } ];
+    ```
+
+    Such an Op appears in the IR as:
+    ```mlir
+      %0 = nvvm.special.op %1, %2 : !llvm.ptr, i32 -> i32
+    ```
+
+    The `convert-nvvm-to-llvm` pass lowers it to the inline assembly below.
+    The order of arguments is retained, `%` is replaced by `$`, results get
+    the write (`=`) modifier, and register constraints follow the value types:
+    ```mlir
+      %0 = llvm.inline_asm
+                has_side_effects
+                asm_dialect = att
+                "special.op $0, $1, $2;", "=r,l,r" %1, %2
+                : (!llvm.ptr, i32) -> i32
+    ```
+  }];
+  let cppNamespace = "::mlir::NVVM";
+  let methods = [
+    InterfaceMethod<
+        /*desc=*/[{ Returns the PTX assembly string, referring to the inline
+                    assembly operands as `%0`, `%1`, ... }],
+        /*retType=*/"std::string",
+        /*methodName=*/"getPtx"
+      >,
+    InterfaceMethod<
+        /*desc=*/[{
+          Returns true if the operation is supported by an LLVM intrinsic, in
+          which case the PTX builder does not lower it. Useful for operations
+          where only some cases have LLVM intrinsic support.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"hasIntrinsic",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return false;"
+      >,
+    InterfaceMethod<
+        /*desc=*/[{Returns whether the operation has memory side effects; this
+                   sets `has_side_effects` on the generated inline assembly.}],
+        /*retType=*/"bool",
+        /*methodName=*/"hasSideEffect",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return true;"
+      >,
+    InterfaceMethod<
+        /*desc=*/[{Helper that creates an i32 `LLVM::ConstantOp` holding `val`
+                   at the op's location.}],
+        /*retType=*/"::mlir::Value",
+        /*methodName=*/"makeConstantI32",
+        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
+        /*methodBody=*/"",
+        /*defaultImpl=*/ [{
+            mlir::Operation* op = $_op;
+            return rewriter.create<LLVM::ConstantOp>(
+              op->getLoc(), rewriter.getIntegerType(32), val);
+        }]
+     >,
+     InterfaceMethod<
+         /*desc=*/[{
+            Collects the values passed to the PTX code, with their register
+            modifiers, in this order: 1) results (Write), 2) operands (Read),
+            3) integer attributes, materialized as i32 constants (Read).
+          }],
+         /*retType=*/"void",
+         /*methodName=*/"getAsmValues",
+         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
+         "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
+         /*methodBody=*/"",
+         /*defaultImpl=*/ [{
+           mlir::Operation* op = $_op;
+
+           // Step 1. Add results
+           for (auto val : op->getResults())
+            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
+
+           // Step 2. Add operands
+           for (auto val : op->getOperands())
+            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+           // Step 3. Add attributes
+           for (auto attr : op->getAttrs()) {
+            if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
+             ::mlir::Value val = makeConstantI32(rewriter, intAttr.getInt());
+             asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+             }
+           }
+         }]
+       >
+  ];
+}
+
+#endif // BASICPTXBUILDER_OP_INTERFACE
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index aaf64e63321f204..64de028c7fe4061 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -46,14 +46,18 @@ mlir_tablegen(LLVMIntrinsicFromLLVMIRConversions.inc -gen-intr-from-llvmir-conve
 mlir_tablegen(LLVMConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics)
 add_public_tablegen_target(MLIRLLVMIntrinsicConversionsIncGen)
 
+set(LLVM_TARGET_DEFINITIONS BasicPtxBuilderInterface.td)
+mlir_tablegen(BasicPtxBuilderInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
+add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)
+
 add_mlir_dialect(NVVMOps nvvm)
 add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
 set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
 mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
 mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
-mlir_tablegen(NVVMOpsInterface.h.inc -gen-op-interface-decls)
-mlir_tablegen(NVVMOpsInterface.cpp.inc -gen-op-interface-defs)
 mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
 mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
 add_public_tablegen_target(MLIRNVVMConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 1644d0029380cec..d5ffa64fefa2609 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -15,6 +15,7 @@
 #define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0d4d734edd2b69b..5e7c168d45b2b45 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -17,6 +17,7 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
 
 def LLVM_i8Ptr_global : LLVM_IntPtrBase<8, 1>;
 def LLVM_i8Ptr_shared : LLVM_IntPtrBase<8, 3>;
@@ -82,135 +83,6 @@ class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
   let mnemonic = attrMnemonic;
 }
 
-//===----------------------------------------------------------------------===//
-// Basic PTX Builder Interface
-//===----------------------------------------------------------------------===//
-
-// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
-def Read : I32EnumAttrCase<"Read", 0, "read">;
-def Write : I32EnumAttrCase<"Write", 2, "write">;
-def ReadWrite : I32EnumAttrCase<"ReadWrite", 1, "readwrite">;
-
-def PTXRegisterMod : I32EnumAttr<"PTXRegisterMod", 
-  "Register read/write modifier to build cosntraint string for PTX inline",
-  [Read, Write, ReadWrite]> {
-  let cppNamespace = "::mlir::NVVM";
-}
-
-
-def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
-  let description = [{
-    Interface to generate inline assembly with PTX for basic operations. 
-
-    Interface is used in `convert-nvvm-to-llvm` pass that lowers Ops supports 
-    this interface to inline assembly Op. Interface has several methods and 
-    they are used for this lowering. 
-
-    `getPtx` method returns PTX code. 
-
-    `hasSideEffect` is used to set whether the op has any side effect on the 
-    memory.
-
-    `hasIntrinsic` returns whether the operation has intrinsic support in LLVM. 
-    This is useful for the Ops that don't have intrinsic support for each case.
-
-    `getAsmValues` returns arguments to pass PTX code. The order of arguments 
-    is started from the results and they are used as write, followed by the 
-    operands and attributes.
-
-    Example:
-    If we have following Op definition that returns PTX code by `getPtx`. 
-    
-    ```tablegen
-      def NVVM_MyOp : NVVM_Op<"myop",
-          [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
-        Results<(outs LLVM_Type:$res)>,
-        Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
-        ...
-        let extraClassDefinition = [{
-          std::string $cppClass::getPtx() { 
-            return std::string("my.ptx.code %0, %1, %2;"); 
-          }
-      } ];
-    ```
-
-    The NVVM Op will look like below:
-    ```mlir
-      %0 = my.ptx.code %1, %2 : !llvm.ptr, i32 -> i32
-    ```
-
-    The `convert-nvvm-to-llvm` Pass generates the PTX code below. The order of 
-    arguments are kept the same. The read and write modifiers are set based on
-    the input and result types.
-    ```mlir
-      %0 = llvm.inline_asm has_side_effects asm_dialect = att "my.ptx.code %0, %1, %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
-    ```
-
-  }];
-  let methods = [
-    InterfaceMethod<
-        /*desc=*/[{
-          Returns whether the operation has intrinsic support in LLVM.
-        }],
-        /*retType=*/"bool",
-        /*methodName=*/"hasIntrinsic",
-        /*args=*/(ins),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/"return false;"
-      >,
-    InterfaceMethod<
-        /*desc=*/[{ Return whether the operation has memory side effects. }],
-        /*retType=*/"bool",
-        /*methodName=*/"hasSideEffect",
-        /*args=*/(ins),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/"return true;"
-      >,
-    InterfaceMethod<
-        /*desc=*/[{ Returns PTX code. }],
-        /*retType=*/"std::string",
-        /*methodName=*/"getPtx"
-      >,
-    InterfaceMethod<
-        /*desc=*/[{Generate constant value.}],
-        /*retType=*/"::mlir::Value",
-        /*methodName=*/"makeConstantI32",
-        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
-        /*methodBody=*/"",
-        /*defaultImpl=*/ [{
-            mlir::Operation* op = $_op;
-            return rewriter.create<LLVM::ConstantOp>(
-              op->getLoc(), rewriter.getIntegerType(32), val);
-        }]
-     >,
-     InterfaceMethod<
-         /*desc=*/[{ 
-            Returns arguments to pass PTX code.
-            The order of arguments is started from the results and they are 
-            used as write, followed by the operands and attributes.
-          }],
-         /*retType=*/"void",
-         /*methodName=*/"getAsmValues",
-         /*args=*/(ins "::mlir::RewriterBase &":$rewriter, 
-         "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
-         /*methodBody=*/"",
-         /*defaultImpl=*/ [{         
-           mlir::Operation* op = $_op;
-           for (auto val : op->getResults()) 
-            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
-           for (auto val : op->getOperands()) 
-            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
-           for (auto attr : op->getAttrs()) {
-            if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
-             Value val = makeConstantI32(rewriter, intAttr.getInt());
-             asmValues.push_back({val ,mlir::NVVM::PTXRegisterMod::Read});
-             }
-           }
-         }]
-       >
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 2d7a441e950045c..26f710cbd1d3502 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -49,149 +49,16 @@ namespace mlir {
 using namespace mlir;
 using namespace NVVM;
 
-#include "mlir/Dialect/LLVMIR/NVVMOpsInterface.cpp.inc"
 namespace {
-
-class PtxBuilder {
-  NVVM::BasicPtxBuilderInterface op;
-  PatternRewriter &rewriter;
-  std::string asmStr;
-  SmallVector<Value> asmVals;
-  std::string asmConstraints;
-  bool sideEffects;
-  bool hasResult = false;
-
-  // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
-  char getRegisterType(Type type) {
-    if (type.isInteger(16))
-      return 'h';
-    if (type.isInteger(32))
-      return 'r';
-    if (type.isInteger(64))
-      return 'l';
-    if (type.isF32())
-      return 'f';
-    if (type.isF64())
-      return 'd';
-    if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
-      // Shared address spaces is addressed with 32-bit pointers.
-      if (ptr.getAddressSpace() == NVVM::kSharedMemorySpace) {
-        return 'r';
-      }
-      return 'l';
-    }
-    op->emitError() << "Register type could not deduced from MLIR type: "
-                    << type;
-    return ' ';
-  }
-
-  char getRegisterType(Value v) {
-    if (v.getDefiningOp<LLVM::ConstantOp>())
-      return 'n';
-    return getRegisterType(v.getType());
-  }
-
-public:
-  PtxBuilder(Operation *op, PatternRewriter &rewriter, std::string ptxAsm,
-             bool sideEffects = false)
-      : op(op), rewriter(rewriter), asmStr(std::move(ptxAsm)),
-        sideEffects(sideEffects) {}
-
-  void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read) {
-    LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
-    auto getModifier = [&]() -> const char * {
-      if (itype == PTXRegisterMod::ReadWrite) {
-        assert(false && "Read-Write modifier is not supported. Try setting the "
-                        "same value as Write and Read seperately.");
-        return "+";
-      }
-      if (itype == PTXRegisterMod::Write) {
-        return "=";
-      }
-      return "";
-    };
-    auto addValue = [&](Value v) {
-      if (itype == PTXRegisterMod::Read) {
-        asmVals.push_back(v);
-        return;
-      }
-      if (itype == PTXRegisterMod::ReadWrite)
-        asmVals.push_back(v);
-      hasResult = true;
-    };
-
-    llvm::raw_string_ostream ss(asmConstraints);
-    // Handle Structs
-    if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
-      if (itype == PTXRegisterMod::Write) {
-        addValue(v);
-      }
-      for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
-        if (itype != PTXRegisterMod::Write) {
-          Value extractValue =
-              rewriter.create<LLVM::ExtractValueOp>(op->getLoc(), v, idx);
-          addValue(extractValue);
-        }
-        if (itype == PTXRegisterMod::ReadWrite) {
-          ss << idx << ",";
-        } else {
-          ss << getModifier() << getRegisterType(t) << ",";
-        }
-        ss.flush();
-      }
-      return;
-    }
-    // Handle Scalars
-    addValue(v);
-    ss << getModifier() << getRegisterType(v) << ",";
-    ss.flush();
-  }
-
-  LLVM::InlineAsmOp build() {
-    auto asmDialectAttr =
-        LLVM::AsmDialectAttr::get(op->getContext(), LLVM::AsmDialect::AD_ATT);
-
-    auto resultTypes = op->getResultTypes();
-
-    // Remove the last comma from the constraints string.
-    if (!asmConstraints.empty() &&
-        asmConstraints[asmConstraints.size() - 1] == ',')
-      asmConstraints.pop_back();
-
-    // asm keywords expects %, but inline assembly uses $. Replace all % with $
-    std::replace(asmStr.begin(), asmStr.end(), '%', '$');
-
-    return rewriter.create<LLVM::InlineAsmOp>(
-        op->getLoc(),
-        /*result types=*/resultTypes,
-        /*operands=*/asmVals,
-        /*asm_string=*/llvm::StringRef(asmStr),
-        /*constraints=*/asmConstraints.data(),
-        /*has_side_effects=*/sideEffects,
-        /*is_align_stack=*/false,
-        /*asm_dialect=*/asmDialectAttr,
-        /*operand_attrs=*/ArrayAttr());
-  }
-
-  void buildAndReplaceOp() {
-    LLVM::InlineAsmOp inlineAsmOp = build();
-    LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
-    if (inlineAsmOp->getNumResults() == op->getNumResults())
-      rewriter.replaceOp(op, inlineAsmOp);
-    else
-      rewriter.eraseOp(op);
-  }
-};
-
 struct PtxLowering
-    : public OpInterfaceRewritePattern<NVVM::BasicPtxBuilderInterface> {
+    : public OpInterfaceRewritePattern<BasicPtxBuilderInterface> {
   using OpInterfaceRewritePattern<
-      NVVM::BasicPtxBuilderInterface>::OpInterfaceRewritePattern;
+      BasicPtxBuilderInterface>::OpInterfaceRewritePattern;
 
   PtxLowering(MLIRContext *context, PatternBenefit benefit = 2)
       : OpInterfaceRewritePattern(context, benefit) {}
 
-  LogicalResult matchAndRewrite(NVVM::BasicPtxBuilderInterface op,
+  LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
                                 PatternRewriter &rewriter) const override {
     if (op.hasIntrinsic()) {
       LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
@@ -200,11 +67,11 @@ struct PtxLowering
 
     SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
     LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
-    PtxBuilder generator(op, rewriter, op.getPtx(), op.hasSideEffect());
+    PtxBuilder generator(op, rewriter);
 
     op.getAsmValues(rewriter, asmValues);
     for (auto &[asmValue, modifier] : asmValues) {
-      LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << modifier);
+      LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
       generator.insertValue(asmValue, modifier);
     }
 
@@ -232,7 +99,7 @@ struct ConvertNVVMToLLVMPass
   }
 };
 
-/// Implement the interface to convert NNVM to LLVM.
+/// Implement the interface to convert NVVM to LLVM.
 struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
   void loadDependentDialects(MLIRContext *context) const final {
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 230ffec900bb984..b00259677697a50 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
 
 add_mlir_dialect_library(MLIRNVVMDialect
   IR/NVVMDialect.cpp
+  IR/BasicPtxBuilderInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
@@ -50,6 +51,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
   MLIRGPUCompilationAttrInterfacesIncGen
   MLIRNVVMOpsIncGen
   MLIRNVVMConversionsIncGen
+  MLIRBasicPtxBuilderInterfaceIncGen
   intrinsics_gen
 
   LINK_COMPONENTS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
new file mode 100644
index 000000000000000..8bcd091f830e2ad
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -0,0 +1,146 @@
+//===- BasicPtxBuilderInterface.cpp - PTX builder interface -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
+// automatically. It is used by NVVM to LLVM pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "ptx-builder"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
+//===----------------------------------------------------------------------===//
+// BasicPtxBuilderInterface
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
+
+using namespace mlir;
+using namespace NVVM;
+
+// Returns the PTX inline-asm register constraint letter for `type`, or '?'
+// when the type maps to no single register (e.g. structs).
+static char getRegisterType(Type type) {
+  if (auto ptrType = dyn_cast<LLVM::LLVMPointerType>(type)) {
+    // Shared-memory addresses are 32-bit ('r'); other pointers use 'l'.
+    bool isShared = ptrType.getAddressSpace() == NVVM::kSharedMemorySpace;
+    return isShared ? 'r' : 'l';
+  }
+  if (type.isF32())
+    return 'f';
+  if (type.isF64())
+    return 'd';
+  if (type.isInteger(16))
+    return 'h';
+  if (type.isInteger(32))
+    return 'r';
+  if (type.isInteger(64))
+    return 'l';
+  // Any other type, e.g. a struct nested inside a struct operand.
+  return '?';
+}
+
+// Values from LLVM::ConstantOp use 'n'; others use their type's letter.
+static char getRegisterType(Value v) {
+  bool fromConstant = static_cast<bool>(v.getDefiningOp<LLVM::ConstantOp>());
+  return fromConstant ? 'n' : getRegisterType(v.getType());
+}
+
+void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
+  // Print the modifier's numeric value; `&itype` would print its address.
+  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << static_cast<unsigned>(itype)
+                    << "\n");
+  // Constraint prefix for `itype`: "" for Read and "=" for Write.
+  auto getModifier = [&]() -> const char * {
+    if (itype == PTXRegisterMod::ReadWrite) {
+      assert(false && "Read-Write modifier is not supported. Try setting the "
+                      "same value as Write and Read separately.");
+      return "+";
+    }
+    return itype == PTXRegisterMod::Write ? "=" : "";
+  };
+  // Read/ReadWrite values become asm operands; Write/ReadWrite set hasResult.
+  auto addValue = [&](Value val) {
+    if (itype == PTXRegisterMod::Read) {
+      ptxOperands.push_back(val);
+      return;
+    }
+    if (itype == PTXRegisterMod::ReadWrite)
+      ptxOperands.push_back(val);
+    hasResult = true;
+  };
+
+  llvm::raw_string_ostream ss(registerConstraints);
+  // Structs: one constraint per element. Read/ReadWrite structs are unpacked
+  // into per-element operands with LLVM::ExtractValueOp.
+  if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
+    if (itype == PTXRegisterMod::Write)
+      addValue(v);
+    for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
+      if (itype != PTXRegisterMod::Write) {
+        Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
+            interfaceOp->getLoc(), v, idx);
+        addValue(extractValue);
+      }
+      if (itype == PTXRegisterMod::ReadWrite)
+        ss << idx << ",";
+      else
+        ss << getModifier() << getRegisterType(t) << ",";
+      ss.flush();
+    }
+    return;
+  }
+  // Scalars: one constraint; non-Write values are also added as operands.
+  addValue(v);
+  ss << getModifier() << getRegisterType(v) << ",";
+  ss.flush();
+}
+
+// Finalizes the constraint string and PTX text, then creates the
+// `llvm.inline_asm` op at the interface op's location.
+LLVM::InlineAsmOp PtxBuilder::build() {
+  // Drop the trailing comma left behind by the last insertValue() call.
+  if (!registerConstraints.empty() && registerConstraints.back() == ',')
+    registerConstraints.pop_back();
+
+  // getPtx() strings refer to operands with '%' (tablegen does not accept
+  // '$'), whereas LLVM inline assembly expects '$'.
+  std::string ptxInstruction = interfaceOp.getPtx();
+  std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');
+
+  MLIRContext *ctx = interfaceOp->getContext();
+  auto asmDialectAttr =
+      LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT);
+  Location loc = interfaceOp->getLoc();
+
+  return rewriter.create<LLVM::InlineAsmOp>(
+      loc,
+      /*result types=*/interfaceOp->getResultTypes(),
+      /*operands=*/ptxOperands,
+      /*asm_string=*/llvm::StringRef(ptxInstruction),
+      /*constraints=*/registerConstraints.data(),
+      /*has_side_effects=*/interfaceOp.hasSideEffect(),
+      /*is_align_stack=*/false,
+      /*asm_dialect=*/asmDialectAttr,
+      /*operand_attrs=*/ArrayAttr());
+}
+
+// Emit the inline asm; replace the op if result counts match, else erase it.
+void PtxBuilder::buildAndReplaceOp() {
+  LLVM::InlineAsmOp asmOp = build();
+  LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << asmOp << "\n");
+  if (asmOp->getNumResults() != interfaceOp->getNumResults())
+    rewriter.eraseOp(interfaceOp);
+  else
+    rewriter.replaceOp(interfaceOp, asmOp);
+}



More information about the Mlir-commits mailing list