[Mlir-commits] [mlir] [mlir][nvvm] Move `BasicPtxBuilder` Interface to Its Own File (NFC) (PR #68095)
Guray Ozen
llvmlistbot at llvm.org
Tue Oct 3 05:18:45 PDT 2023
https://github.com/grypp created https://github.com/llvm/llvm-project/pull/68095
The `BasicPtxBuilder` interface plays a crucial role in generating PTX assembly from NVVM Ops. Previously, it was situated within `NVVMOps.td` and `NVVMToLLVM.cpp`. For the sake of code readability, this PR moves it into its own dedicated file. Additionally, it includes comprehensive documentation for the classes and interface.
>From 1937f5746ce1cc0dc7cf79f7a84372c78a7d8237 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 3 Oct 2023 14:17:46 +0200
Subject: [PATCH] [mlir][nvvm] Move `BasicPtxBuilder` Interface to Its Own File
(NFC)
The `BasicPtxBuilder` interface plays a crucial role in generating PTX assembly from NVVM Ops. Previously, it was situated within `NVVMOps.td` and `NVVMToLLVM.cpp`. For the sake of code readability, this PR moves it into its own dedicated file. Additionally, it includes comprehensive documentation for the classes and interface.
---
mlir/include/mlir/Conversion/Passes.td | 6 +-
.../Dialect/LLVMIR/BasicPtxBuilderInterface.h | 84 ++++++++++
.../LLVMIR/BasicPtxBuilderInterface.td | 139 +++++++++++++++++
.../mlir/Dialect/LLVMIR/CMakeLists.txt | 8 +-
.../include/mlir/Dialect/LLVMIR/NVVMDialect.h | 1 +
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 130 +---------------
mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 145 +----------------
mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 2 +
.../LLVMIR/IR/BasicPtxBuilderInterface.cpp | 146 ++++++++++++++++++
9 files changed, 388 insertions(+), 273 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
create mode 100644 mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
create mode 100644 mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 38b05c792d405ad..afaeb16f4c2e27c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -792,10 +792,10 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
//===----------------------------------------------------------------------===//
def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
- let summary = "Convert NVVM dialect to LLVM dialect";
+ let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
let description = [{
- This pass generates inline assembly for the NVVM ops which is not
- implemented in LLVM core.
+ This pass generates PTX instructions using inline assembly for NVVM
+ operations that implement `BasicPtxBuilderInterface`.
}];
let dependentDialects = [
"NVVM::NVVMDialect",
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
new file mode 100644
index 000000000000000..677cb802c05628b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -0,0 +1,84 @@
+//===- BasicPtxBuilderInterface.h - PTX builder interface -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
+// automatically. It is used by NVVM to LLVM pass.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
+#define NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include <sys/types.h>
+
+#include <utility>
+
+namespace mlir {
+namespace NVVM {
+/// Register read/write modifier used to build the PTX inline-asm constraints.
+/// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
+enum class PTXRegisterMod : uint32_t {
+  /// Read register with no modifier
+  Read = 0,
+  /// Write register with '=' modifier
+  Write = 2,
+  /// ReadWrite register with '+' modifier.
+  /// LLVM inline asm has no '+' constraint; pass the same value both as
+  /// Write and as Read instead.
+  ReadWrite = 1,
+};
+} // namespace NVVM
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h.inc"
+
+namespace mlir {
+
+namespace NVVM {
+
+/// A class to build PTX assembly automatically. It is used to lower ops that
+/// implement BasicPtxBuilderInterface to LLVM inline assembly.
+class PtxBuilder {
+ // The interface op that is used to build the PTX.
+ BasicPtxBuilderInterface interfaceOp;
+ // Rewriter to create new operations.
+ PatternRewriter &rewriter;
+ // Operands passed to the inline assembly (Read/ReadWrite values only).
+ SmallVector<Value> ptxOperands;
+ // Comma-separated constraint string: modifier + register letter per value.
+ std::string registerConstraints;
+ // Set by insertValue() once a Write or ReadWrite value has been added.
+ bool hasResult = false;
+
+public:
+ /// Single constructor that only initializes members.
+ PtxBuilder(Operation *op, PatternRewriter &rewriter)
+ : interfaceOp(op), rewriter(rewriter) {}
+
+ /// Records value `v` and its register constraint for the given modifier.
+ void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
+
+ /// Builds the inline assembly Op and returns it. The `insertValue` needs to
+ /// be called to pass operands before building the PTX.
+ LLVM::InlineAsmOp build();
+
+ /// Shortcut to build the inline assembly Op and replace the original op with
+ /// it (the original op is erased instead if the result counts differ).
+ void buildAndReplaceOp();
+};
+
+} // namespace NVVM
+} // namespace mlir
+
+#endif // NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
new file mode 100644
index 000000000000000..6f27c8eb47175e6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
@@ -0,0 +1,139 @@
+//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
+// automatically. It is used by NVVM to LLVM pass.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BASICPTXBUILDER_OP_INTERFACE
+#define BASICPTXBUILDER_OP_INTERFACE
+
+include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Basic PTX Builder Interface
+//===----------------------------------------------------------------------===//
+
+def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
+ let description = [{
+ This interface is used to generate inline assembly with PTX for basic
+ operations. It's utilized in the `convert-nvvm-to-llvm` pass to lower
+ NVVM Ops that implement this interface to PTX (parallel thread execution)
+ using inline assembly Ops. Interface methods play a crucial role in this
+ lowering process.
+
+ Here's an example of an Op with the `BasicPtxBuilderOpInterface`:
+ ```tablegen
+ def NVVM_SpecialOp : NVVM_Op<"special.op",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
+ ...
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ return std::string("special.op %0, %1, %2;");
+ }
+ } ];
+ ```
+
+ In the above NVVM Op example:
+ ```mlir
+ %0 = nvvm.special.op %1, %2 : !llvm.ptr, i32 -> i32
+ ```
+
+ The `convert-nvvm-to-llvm` pass generates the inline assembly like below.
+ The order of arguments is retained, and the read and write modifiers are
+ set based on the input and result types:
+ ```mlir
+ %0 = llvm.inline_asm
+ has_side_effects
+ asm_dialect = att
+ "special.op $0, $1, $2;", "=r,l,r" %arg0, %arg1
+ : (!llvm.ptr, i32) -> i32
+ ```
+ }];
+ let cppNamespace = "::mlir::NVVM";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{ Returns the PTX string; operands are referenced as %0, %1, ... }],
+ /*retType=*/"std::string",
+ /*methodName=*/"getPtx"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ This function indicates whether the operation is supported by LLVM
+ intrinsics. It's particularly useful for operations that have
+ specific cases with LLVM intrinsic support.
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"hasIntrinsic",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return false;"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{Return whether the operation has memory side effects.}],
+ /*retType=*/"bool",
+ /*methodName=*/"hasSideEffect",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return true;"
+ >,
+
+ InterfaceMethod<
+ /*desc=*/[{Helper function to generate i32 constant value.}],
+ /*retType=*/"::mlir::Value",
+ /*methodName=*/"makeConstantI32",
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
+ /*methodBody=*/"",
+ /*defaultImpl=*/ [{
+ mlir::Operation* op = $_op;
+ return rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), rewriter.getIntegerType(32), val);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ This function supplies the necessary arguments for passing PTX code,
+ following this order:
+ 1) Adds results
+ 2) Adds operands
+ 3) Adds integer attributes, materialized as i32 constants
+ }],
+ /*retType=*/"void",
+ /*methodName=*/"getAsmValues",
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
+ "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
+ /*methodBody=*/"",
+ /*defaultImpl=*/ [{
+ mlir::Operation* op = $_op;
+
+ // Step 1. Add results
+ for (auto val : op->getResults())
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
+
+ // Step 2. Add operands
+ for (auto val : op->getOperands())
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+ // Step 3. Add attributes
+ for (auto attr : op->getAttrs()) {
+ if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
+ ::mlir::Value val = makeConstantI32(rewriter, intAttr.getInt());
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+ }
+ }
+ }]
+ >
+ ];
+}
+
+#endif // BASICPTXBUILDER_OP_INTERFACE
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index aaf64e63321f204..64de028c7fe4061 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -46,14 +46,18 @@ mlir_tablegen(LLVMIntrinsicFromLLVMIRConversions.inc -gen-intr-from-llvmir-conve
mlir_tablegen(LLVMConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics)
add_public_tablegen_target(MLIRLLVMIntrinsicConversionsIncGen)
+set(LLVM_TARGET_DEFINITIONS BasicPtxBuilderInterface.td)
+mlir_tablegen(BasicPtxBuilderInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
+add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)
+
add_mlir_dialect(NVVMOps nvvm)
add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
-mlir_tablegen(NVVMOpsInterface.h.inc -gen-op-interface-decls)
-mlir_tablegen(NVVMOpsInterface.cpp.inc -gen-op-interface-defs)
mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
add_public_tablegen_target(MLIRNVVMConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 1644d0029380cec..d5ffa64fefa2609 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -15,6 +15,7 @@
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0d4d734edd2b69b..5e7c168d45b2b45 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -17,6 +17,7 @@ include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
def LLVM_i8Ptr_global : LLVM_IntPtrBase<8, 1>;
def LLVM_i8Ptr_shared : LLVM_IntPtrBase<8, 3>;
@@ -82,135 +83,6 @@ class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
let mnemonic = attrMnemonic;
}
-//===----------------------------------------------------------------------===//
-// Basic PTX Builder Interface
-//===----------------------------------------------------------------------===//
-
-// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
-def Read : I32EnumAttrCase<"Read", 0, "read">;
-def Write : I32EnumAttrCase<"Write", 2, "write">;
-def ReadWrite : I32EnumAttrCase<"ReadWrite", 1, "readwrite">;
-
-def PTXRegisterMod : I32EnumAttr<"PTXRegisterMod",
- "Register read/write modifier to build cosntraint string for PTX inline",
- [Read, Write, ReadWrite]> {
- let cppNamespace = "::mlir::NVVM";
-}
-
-
-def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
- let description = [{
- Interface to generate inline assembly with PTX for basic operations.
-
- Interface is used in `convert-nvvm-to-llvm` pass that lowers Ops supports
- this interface to inline assembly Op. Interface has several methods and
- they are used for this lowering.
-
- `getPtx` method returns PTX code.
-
- `hasSideEffect` is used to set whether the op has any side effect on the
- memory.
-
- `hasIntrinsic` returns whether the operation has intrinsic support in LLVM.
- This is useful for the Ops that don't have intrinsic support for each case.
-
- `getAsmValues` returns arguments to pass PTX code. The order of arguments
- is started from the results and they are used as write, followed by the
- operands and attributes.
-
- Example:
- If we have following Op definition that returns PTX code by `getPtx`.
-
- ```tablegen
- def NVVM_MyOp : NVVM_Op<"myop",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
- ...
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- return std::string("my.ptx.code %0, %1, %2;");
- }
- } ];
- ```
-
- The NVVM Op will look like below:
- ```mlir
- %0 = my.ptx.code %1, %2 : !llvm.ptr, i32 -> i32
- ```
-
- The `convert-nvvm-to-llvm` Pass generates the PTX code below. The order of
- arguments are kept the same. The read and write modifiers are set based on
- the input and result types.
- ```mlir
- %0 = llvm.inline_asm has_side_effects asm_dialect = att "my.ptx.code %0, %1, %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
- ```
-
- }];
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Returns whether the operation has intrinsic support in LLVM.
- }],
- /*retType=*/"bool",
- /*methodName=*/"hasIntrinsic",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/"return false;"
- >,
- InterfaceMethod<
- /*desc=*/[{ Return whether the operation has memory side effects. }],
- /*retType=*/"bool",
- /*methodName=*/"hasSideEffect",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/"return true;"
- >,
- InterfaceMethod<
- /*desc=*/[{ Returns PTX code. }],
- /*retType=*/"std::string",
- /*methodName=*/"getPtx"
- >,
- InterfaceMethod<
- /*desc=*/[{Generate constant value.}],
- /*retType=*/"::mlir::Value",
- /*methodName=*/"makeConstantI32",
- /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
- /*methodBody=*/"",
- /*defaultImpl=*/ [{
- mlir::Operation* op = $_op;
- return rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), rewriter.getIntegerType(32), val);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Returns arguments to pass PTX code.
- The order of arguments is started from the results and they are
- used as write, followed by the operands and attributes.
- }],
- /*retType=*/"void",
- /*methodName=*/"getAsmValues",
- /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
- /*methodBody=*/"",
- /*defaultImpl=*/ [{
- mlir::Operation* op = $_op;
- for (auto val : op->getResults())
- asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
- for (auto val : op->getOperands())
- asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
- for (auto attr : op->getAttrs()) {
- if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
- Value val = makeConstantI32(rewriter, intAttr.getInt());
- asmValues.push_back({val ,mlir::NVVM::PTXRegisterMod::Read});
- }
- }
- }]
- >
- ];
-}
-
//===----------------------------------------------------------------------===//
// NVVM intrinsic operations
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 2d7a441e950045c..26f710cbd1d3502 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -49,149 +49,16 @@ namespace mlir {
using namespace mlir;
using namespace NVVM;
-#include "mlir/Dialect/LLVMIR/NVVMOpsInterface.cpp.inc"
namespace {
-
-class PtxBuilder {
- NVVM::BasicPtxBuilderInterface op;
- PatternRewriter &rewriter;
- std::string asmStr;
- SmallVector<Value> asmVals;
- std::string asmConstraints;
- bool sideEffects;
- bool hasResult = false;
-
- // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
- char getRegisterType(Type type) {
- if (type.isInteger(16))
- return 'h';
- if (type.isInteger(32))
- return 'r';
- if (type.isInteger(64))
- return 'l';
- if (type.isF32())
- return 'f';
- if (type.isF64())
- return 'd';
- if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
- // Shared address spaces is addressed with 32-bit pointers.
- if (ptr.getAddressSpace() == NVVM::kSharedMemorySpace) {
- return 'r';
- }
- return 'l';
- }
- op->emitError() << "Register type could not deduced from MLIR type: "
- << type;
- return ' ';
- }
-
- char getRegisterType(Value v) {
- if (v.getDefiningOp<LLVM::ConstantOp>())
- return 'n';
- return getRegisterType(v.getType());
- }
-
-public:
- PtxBuilder(Operation *op, PatternRewriter &rewriter, std::string ptxAsm,
- bool sideEffects = false)
- : op(op), rewriter(rewriter), asmStr(std::move(ptxAsm)),
- sideEffects(sideEffects) {}
-
- void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read) {
- LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
- auto getModifier = [&]() -> const char * {
- if (itype == PTXRegisterMod::ReadWrite) {
- assert(false && "Read-Write modifier is not supported. Try setting the "
- "same value as Write and Read seperately.");
- return "+";
- }
- if (itype == PTXRegisterMod::Write) {
- return "=";
- }
- return "";
- };
- auto addValue = [&](Value v) {
- if (itype == PTXRegisterMod::Read) {
- asmVals.push_back(v);
- return;
- }
- if (itype == PTXRegisterMod::ReadWrite)
- asmVals.push_back(v);
- hasResult = true;
- };
-
- llvm::raw_string_ostream ss(asmConstraints);
- // Handle Structs
- if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
- if (itype == PTXRegisterMod::Write) {
- addValue(v);
- }
- for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
- if (itype != PTXRegisterMod::Write) {
- Value extractValue =
- rewriter.create<LLVM::ExtractValueOp>(op->getLoc(), v, idx);
- addValue(extractValue);
- }
- if (itype == PTXRegisterMod::ReadWrite) {
- ss << idx << ",";
- } else {
- ss << getModifier() << getRegisterType(t) << ",";
- }
- ss.flush();
- }
- return;
- }
- // Handle Scalars
- addValue(v);
- ss << getModifier() << getRegisterType(v) << ",";
- ss.flush();
- }
-
- LLVM::InlineAsmOp build() {
- auto asmDialectAttr =
- LLVM::AsmDialectAttr::get(op->getContext(), LLVM::AsmDialect::AD_ATT);
-
- auto resultTypes = op->getResultTypes();
-
- // Remove the last comma from the constraints string.
- if (!asmConstraints.empty() &&
- asmConstraints[asmConstraints.size() - 1] == ',')
- asmConstraints.pop_back();
-
- // asm keywords expects %, but inline assembly uses $. Replace all % with $
- std::replace(asmStr.begin(), asmStr.end(), '%', '$');
-
- return rewriter.create<LLVM::InlineAsmOp>(
- op->getLoc(),
- /*result types=*/resultTypes,
- /*operands=*/asmVals,
- /*asm_string=*/llvm::StringRef(asmStr),
- /*constraints=*/asmConstraints.data(),
- /*has_side_effects=*/sideEffects,
- /*is_align_stack=*/false,
- /*asm_dialect=*/asmDialectAttr,
- /*operand_attrs=*/ArrayAttr());
- }
-
- void buildAndReplaceOp() {
- LLVM::InlineAsmOp inlineAsmOp = build();
- LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
- if (inlineAsmOp->getNumResults() == op->getNumResults())
- rewriter.replaceOp(op, inlineAsmOp);
- else
- rewriter.eraseOp(op);
- }
-};
-
struct PtxLowering
- : public OpInterfaceRewritePattern<NVVM::BasicPtxBuilderInterface> {
+ : public OpInterfaceRewritePattern<BasicPtxBuilderInterface> {
using OpInterfaceRewritePattern<
- NVVM::BasicPtxBuilderInterface>::OpInterfaceRewritePattern;
+ BasicPtxBuilderInterface>::OpInterfaceRewritePattern;
PtxLowering(MLIRContext *context, PatternBenefit benefit = 2)
: OpInterfaceRewritePattern(context, benefit) {}
- LogicalResult matchAndRewrite(NVVM::BasicPtxBuilderInterface op,
+ LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
PatternRewriter &rewriter) const override {
if (op.hasIntrinsic()) {
LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
@@ -200,11 +67,11 @@ struct PtxLowering
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
- PtxBuilder generator(op, rewriter, op.getPtx(), op.hasSideEffect());
+ PtxBuilder generator(op, rewriter);
op.getAsmValues(rewriter, asmValues);
for (auto &[asmValue, modifier] : asmValues) {
- LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << modifier);
+ LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
generator.insertValue(asmValue, modifier);
}
@@ -232,7 +99,7 @@ struct ConvertNVVMToLLVMPass
}
};
-/// Implement the interface to convert NNVM to LLVM.
+/// Implement the interface to convert NVVM to LLVM.
struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
void loadDependentDialects(MLIRContext *context) const final {
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 230ffec900bb984..b00259677697a50 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
add_mlir_dialect_library(MLIRNVVMDialect
IR/NVVMDialect.cpp
+ IR/BasicPtxBuilderInterface.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
@@ -50,6 +51,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
MLIRGPUCompilationAttrInterfacesIncGen
MLIRNVVMOpsIncGen
MLIRNVVMConversionsIncGen
+ MLIRBasicPtxBuilderInterfaceIncGen
intrinsics_gen
LINK_COMPONENTS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
new file mode 100644
index 000000000000000..8bcd091f830e2ad
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -0,0 +1,146 @@
+//===- BasicPtxBuilderInterface.cpp - PTX builder interface -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
+// automatically. It is used by NVVM to LLVM pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "ptx-builder"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
+//===----------------------------------------------------------------------===//
+// BasicPtxBuilderInterface
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
+
+using namespace mlir;
+using namespace NVVM;
+
+/// Returns the PTX inline-asm register constraint letter for `type`, see
+/// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
+/// Types with no scalar register class (e.g. structs) yield '?'.
+static char getRegisterType(Type type) {
+  // Pointers into shared memory are 32-bit; all other pointers are 64-bit.
+  if (auto ptrTy = type.dyn_cast<LLVM::LLVMPointerType>())
+    return ptrTy.getAddressSpace() == NVVM::kSharedMemorySpace ? 'r' : 'l';
+  // Floating-point registers: .f64 -> 'd', .f32 -> 'f'.
+  if (type.isF64())
+    return 'd';
+  if (type.isF32())
+    return 'f';
+  // Integer registers: .u64 -> 'l', .u32 -> 'r', .u16 -> 'h'.
+  if (type.isInteger(64))
+    return 'l';
+  if (type.isInteger(32))
+    return 'r';
+  if (type.isInteger(16))
+    return 'h';
+  return '?';
+}
+
+static char getRegisterType(Value v) {
+  // Values produced by llvm.mlir.constant are passed as immediates ('n').
+  bool isImmediate = static_cast<bool>(v.getDefiningOp<LLVM::ConstantOp>());
+  return isImmediate ? 'n' : getRegisterType(v.getType());
+}
+
+void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
+  // Print the modifier's numeric value rather than the address of `itype`.
+  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << static_cast<unsigned>(itype)
+                    << "\n");
+  auto getModifier = [&]() -> const char * {
+    if (itype == PTXRegisterMod::ReadWrite) {
+      assert(false && "Read-Write modifier is not supported. Try setting the "
+                      "same value as Write and Read seperately.");
+      return "+";
+    }
+    if (itype == PTXRegisterMod::Write)
+      return "=";
+    return "";
+  };
+  // Read/ReadWrite values become asm operands; Write only marks a result.
+  auto addValue = [&](Value val) {
+    if (itype == PTXRegisterMod::Read) {
+      ptxOperands.push_back(val);
+      return;
+    }
+    if (itype == PTXRegisterMod::ReadWrite)
+      ptxOperands.push_back(val);
+    hasResult = true;
+  };
+
+  llvm::raw_string_ostream ss(registerConstraints);
+  // Structs get one constraint per member; anything else is a scalar.
+  if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
+    if (itype == PTXRegisterMod::Write)
+      addValue(v);
+    for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
+      if (itype != PTXRegisterMod::Write) {
+        Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
+            interfaceOp->getLoc(), v, idx);
+        addValue(extractValue);
+      }
+      // A numeric constraint ties this input to output operand `idx`.
+      if (itype == PTXRegisterMod::ReadWrite)
+        ss << idx << ",";
+      else
+        ss << getModifier() << getRegisterType(t) << ",";
+      ss.flush();
+    }
+    return;
+  }
+  addValue(v);
+  ss << getModifier() << getRegisterType(v) << ",";
+  ss.flush();
+}
+
+LLVM::InlineAsmOp PtxBuilder::build() {
+  // insertValue() appends ',' after every constraint; drop the trailing one
+  // so the constraint string is a well-formed comma-separated list.
+  if (!registerConstraints.empty() && registerConstraints.back() == ',')
+    registerConstraints.pop_back();
+
+  // Ops spell PTX operands as %N because TableGen does not accept '$', but
+  // LLVM inline assembly refers to operands as $N.
+  std::string ptxInstruction = interfaceOp.getPtx();
+  std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');
+
+  auto *ctx = interfaceOp->getContext();
+  auto asmDialectAttr =
+      LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT);
+  auto resultTypes = interfaceOp->getResultTypes();
+  bool sideEffects = interfaceOp.hasSideEffect();
+
+  return rewriter.create<LLVM::InlineAsmOp>(
+      interfaceOp->getLoc(),
+      /*result types=*/resultTypes,
+      /*operands=*/ptxOperands,
+      /*asm_string=*/llvm::StringRef(ptxInstruction),
+      /*constraints=*/registerConstraints.data(),
+      /*has_side_effects=*/sideEffects,
+      /*is_align_stack=*/false,
+      /*asm_dialect=*/asmDialectAttr,
+      /*operand_attrs=*/ArrayAttr());
+}
+
+void PtxBuilder::buildAndReplaceOp() {
+  LLVM::InlineAsmOp inlineAsmOp = build();
+  LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
+  // The asm op can stand in for the original only when their result counts
+  // agree; otherwise the original op is dropped.
+  if (inlineAsmOp->getNumResults() != interfaceOp->getNumResults())
+    return rewriter.eraseOp(interfaceOp);
+  rewriter.replaceOp(interfaceOp, inlineAsmOp);
+}
More information about the Mlir-commits
mailing list