[Mlir-commits] [mlir] ffbca7e - [mlir][nvvm] Change return type of std::string of getPtx of PtxBuilder

Guray Ozen llvmlistbot at llvm.org
Wed Jul 12 06:00:00 PDT 2023


Author: Guray Ozen
Date: 2023-07-12T14:59:54+02:00
New Revision: ffbca7e9f305949c6620d3e77371cc0463ed48b1

URL: https://github.com/llvm/llvm-project/commit/ffbca7e9f305949c6620d3e77371cc0463ed48b1
DIFF: https://github.com/llvm/llvm-project/commit/ffbca7e9f305949c6620d3e77371cc0463ed48b1.diff

LOG: [mlir][nvvm] Change the return type of getPtx (used by PtxBuilder) to std::string

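getPtx used to return `const char*`, which is inflexible when the PTX string needs to be built inside the function. This change makes it return `std::string` instead. PtxBuilder now also stores the PTX as a `std::string` and replaces every `%` with `$` before creating the `llvm.inline_asm` op, since LLVM inline assembly uses `$` for operand references.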

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D155056

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ad17d1874ef879..e867114527d5c0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -121,7 +121,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
       Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
       ...
       let extraClassDefinition = [{
-        const char* $cppClass::getPtx() { return \"mbarrier.arrive.expect_tx.b64 %0, [%1], %2;\"; }
+        std::string $cppClass::getPtx() { return std::string(\"mbarrier.arrive.expect_tx.b64 %0, [%1], %2;\"); }
       }\];
     }
     ```
@@ -160,7 +160,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
       >,
     InterfaceMethod<
         /*desc=*/[{ Returns PTX code. }],
-        /*retType=*/"const char*",
+        /*retType=*/"std::string",
         /*methodName=*/"getPtx"
       >,
     InterfaceMethod<
@@ -377,7 +377,7 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
   Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
   let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
   let extraClassDefinition = [{
-    const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"; }
+    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"); }
   }];
 }
 
@@ -387,7 +387,7 @@ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.sha
   Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {    
   let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
   let extraClassDefinition = [{
-    const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"; }
+    std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"); }
   }];
 }
 
@@ -397,12 +397,12 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
   Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {
   let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
   let extraClassDefinition = [{
-    const char* $cppClass::getPtx() {
-      return "{\n\t"
+    std::string $cppClass::getPtx() {
+      return std::string("{\n\t"
               ".reg .pred P1; \n\t"
               "mbarrier.try_wait.parity.b64 P1, [%1], %2; \n\t"
               "selp.b32 %0, 1, 0, P1; \n\t"
-              "}"; 
+              "}"); 
     }
   }];
 }
@@ -413,12 +413,12 @@ def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.share
   Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> {  
   let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
   let extraClassDefinition = [{
-    const char* $cppClass::getPtx() {
-      return "{\n\t"
+    std::string $cppClass::getPtx() {
+      return std::string("{\n\t"
               ".reg .pred P1; \n\t"
               "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
               "selp.b32 %0, 1, 0, P1; \n\t"
-              "}"; 
+              "}"); 
     }
   }];
 }
@@ -567,11 +567,11 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global", [DeclareOpInterfaceMethod
     }        
   }];
   let extraClassDefinition = [{        
-    const char* $cppClass::getPtx() { 
+    std::string $cppClass::getPtx() { 
       if(getModifier() == NVVM::LoadCacheModifierKind::CG)
-        return "cp.async.cg.shared.global [%0], [%1], %2, %3;\n"; 
+        return std::string("cp.async.cg.shared.global [%0], [%1], %2, %3;\n");
       if(getModifier() == NVVM::LoadCacheModifierKind::CA)
-        return "cp.async.ca.shared.global [%0], [%1], %2, %3;\n";        
+        return std::string("cp.async.ca.shared.global [%0], [%1], %2, %3;\n");
       llvm_unreachable("unsupported cache modifier");      
     }
   }];

diff  --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 62d04e85f83a16..36c2f3ab2cfb19 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -33,6 +33,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
+#include <regex>
 #include <string>
 
 #define DEBUG_TYPE "nvvm-to-llvm"
@@ -53,7 +54,7 @@ namespace {
 class PtxBuilder {
   Operation *op;
   PatternRewriter &rewriter;
-  const char *asmStr;
+  std::string asmStr;
   SmallVector<Value> asmVals;
   std::string asmConstraints;
   bool sideEffects;
@@ -85,9 +86,10 @@ class PtxBuilder {
   }
 
 public:
-  PtxBuilder(Operation *op, PatternRewriter &rewriter, const char *ptxAsm,
+  PtxBuilder(Operation *op, PatternRewriter &rewriter, std::string ptxAsm,
              bool sideEffects = false)
-      : op(op), rewriter(rewriter), asmStr(ptxAsm), sideEffects(sideEffects) {}
+      : op(op), rewriter(rewriter), asmStr(std::move(ptxAsm)),
+        sideEffects(sideEffects) {}
 
   void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read) {
     llvm::raw_string_ostream ss(asmConstraints);
@@ -116,10 +118,13 @@ class PtxBuilder {
         asmConstraints[asmConstraints.size() - 1] == ',')
       asmConstraints.pop_back();
 
+    // asm keywords expects %, but inline assembly uses $. Replace all % with $
+    std::replace(asmStr.begin(), asmStr.end(), '%', '$');
+
     return rewriter.create<LLVM::InlineAsmOp>(
         op->getLoc(), resultType,
         /*operands=*/asmVals,
-        /*asm_string=*/asmStr,
+        /*asm_string=*/llvm::StringRef(asmStr),
         /*constraints=*/asmConstraints.data(),
         /*has_side_effects=*/sideEffects,
         /*is_align_stack=*/false,

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5e00cea3c13d90..ceb59b9632533d 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -2,28 +2,28 @@
 
 // CHECK-LABEL : @init_mbarrier_arrive_expect_tx
 llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i64 {
-  //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=l,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i64            
+  //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 $0, [$1], $2;", "=l,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i64            
   %res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
   llvm.return %res : i64
 }
 
 // CHECK-LABEL : @init_mbarrier_arrive_expect_tx_generic
 llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32)-> i64 {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=l,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i64
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 $0, [$1], $2;", "=l,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i64
   %res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64
   llvm.return %res : i64
 }
 
 // CHECK-LABEL : @init_mbarrier_try_wait.parity.shared
 llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 {
-  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32
   %res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32
   llvm.return %res : i32
 }
 
 // CHECK-LABEL : @init_mbarrier_try_wait.parity
 llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %token : i32) -> i32{
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32
   %res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32
   llvm.return %res : i32
 }
@@ -39,9 +39,9 @@ func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
 
 // CHECK-LABEL : @async_cp_zfill
 func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
-  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [%0], [%1], %2, %3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
   nvvm.cp.async.shared.global %dst, %src, 16, cache =  cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
-  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [%0], [%1], %2, %3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
   nvvm.cp.async.shared.global %dst, %src, 4, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
   return
 }


        


More information about the Mlir-commits mailing list