[Mlir-commits] [mlir] ffbca7e - [mlir][nvvm] Change return type of std::string of getPtx of PtxBuilder
Guray Ozen
llvmlistbot at llvm.org
Wed Jul 12 06:00:00 PDT 2023
Author: Guray Ozen
Date: 2023-07-12T14:59:54+02:00
New Revision: ffbca7e9f305949c6620d3e77371cc0463ed48b1
URL: https://github.com/llvm/llvm-project/commit/ffbca7e9f305949c6620d3e77371cc0463ed48b1
DIFF: https://github.com/llvm/llvm-project/commit/ffbca7e9f305949c6620d3e77371cc0463ed48b1.diff
LOG: [mlir][nvvm] Change return type of std::string of getPtx of PtxBuilder
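getPtx used to return `const char*`, which is inflexible when the PTX string needs to be built inside the function. This change makes getPtx return `std::string` instead. PtxBuilder now also replaces every `%` in the PTX string with `$` before creating the inline assembly op, so the expected strings in the tests change accordingly.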
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D155056
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ad17d1874ef879..e867114527d5c0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -121,7 +121,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
...
let extraClassDefinition = [{
- const char* $cppClass::getPtx() { return \"mbarrier.arrive.expect_tx.b64 %0, [%1], %2;\"; }
+ std::string $cppClass::getPtx() { return std::string(\"mbarrier.arrive.expect_tx.b64 %0, [%1], %2;\"); }
}\];
}
```
@@ -160,7 +160,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
>,
InterfaceMethod<
/*desc=*/[{ Returns PTX code. }],
- /*retType=*/"const char*",
+ /*retType=*/"std::string",
/*methodName=*/"getPtx"
>,
InterfaceMethod<
@@ -377,7 +377,7 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
- const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"; }
+ std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"); }
}];
}
@@ -387,7 +387,7 @@ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.sha
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
- const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"; }
+ std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"); }
}];
}
@@ -397,12 +397,12 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {
let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
- const char* $cppClass::getPtx() {
- return "{\n\t"
+ std::string $cppClass::getPtx() {
+ return std::string("{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
- "}";
+ "}");
}
}];
}
@@ -413,12 +413,12 @@ def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.share
Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> {
let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
- const char* $cppClass::getPtx() {
- return "{\n\t"
+ std::string $cppClass::getPtx() {
+ return std::string("{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
- "}";
+ "}");
}
}];
}
@@ -567,11 +567,11 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global", [DeclareOpInterfaceMethod
}
}];
let extraClassDefinition = [{
- const char* $cppClass::getPtx() {
+ std::string $cppClass::getPtx() {
if(getModifier() == NVVM::LoadCacheModifierKind::CG)
- return "cp.async.cg.shared.global [%0], [%1], %2, %3;\n";
+ return std::string("cp.async.cg.shared.global [%0], [%1], %2, %3;\n");
if(getModifier() == NVVM::LoadCacheModifierKind::CA)
- return "cp.async.ca.shared.global [%0], [%1], %2, %3;\n";
+ return std::string("cp.async.ca.shared.global [%0], [%1], %2, %3;\n");
llvm_unreachable("unsupported cache modifier");
}
}];
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 62d04e85f83a16..36c2f3ab2cfb19 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -33,6 +33,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
+#include <regex>
#include <string>
#define DEBUG_TYPE "nvvm-to-llvm"
@@ -53,7 +54,7 @@ namespace {
class PtxBuilder {
Operation *op;
PatternRewriter &rewriter;
- const char *asmStr;
+ std::string asmStr;
SmallVector<Value> asmVals;
std::string asmConstraints;
bool sideEffects;
@@ -85,9 +86,10 @@ class PtxBuilder {
}
public:
- PtxBuilder(Operation *op, PatternRewriter &rewriter, const char *ptxAsm,
+ PtxBuilder(Operation *op, PatternRewriter &rewriter, std::string ptxAsm,
bool sideEffects = false)
- : op(op), rewriter(rewriter), asmStr(ptxAsm), sideEffects(sideEffects) {}
+ : op(op), rewriter(rewriter), asmStr(std::move(ptxAsm)),
+ sideEffects(sideEffects) {}
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read) {
llvm::raw_string_ostream ss(asmConstraints);
@@ -116,10 +118,13 @@ class PtxBuilder {
asmConstraints[asmConstraints.size() - 1] == ',')
asmConstraints.pop_back();
+ // asm keywords expects %, but inline assembly uses $. Replace all % with $
+ std::replace(asmStr.begin(), asmStr.end(), '%', '$');
+
return rewriter.create<LLVM::InlineAsmOp>(
op->getLoc(), resultType,
/*operands=*/asmVals,
- /*asm_string=*/asmStr,
+ /*asm_string=*/llvm::StringRef(asmStr),
/*constraints=*/asmConstraints.data(),
/*has_side_effects=*/sideEffects,
/*is_align_stack=*/false,
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5e00cea3c13d90..ceb59b9632533d 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -2,28 +2,28 @@
// CHECK-LABEL : @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i64 {
- //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=l,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i64
+ //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 $0, [$1], $2;", "=l,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i64
%res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
llvm.return %res : i64
}
// CHECK-LABEL : @init_mbarrier_arrive_expect_tx_generic
llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32)-> i64 {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=l,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i64
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 $0, [$1], $2;", "=l,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i64
%res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64
llvm.return %res : i64
}
// CHECK-LABEL : @init_mbarrier_try_wait.parity.shared
llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 {
- // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32
+ // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32
%res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32
llvm.return %res : i32
}
// CHECK-LABEL : @init_mbarrier_try_wait.parity
llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %token : i32) -> i32{
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32
%res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32
llvm.return %res : i32
}
@@ -39,9 +39,9 @@ func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
// CHECK-LABEL : @async_cp_zfill
func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
- // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [%0], [%1], %2, %3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
+ // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
- // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [%0], [%1], %2, %3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
+ // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
return
}
More information about the Mlir-commits
mailing list