[llvm] e27f921 - [NVPTX][NFC] Simplify printing initialization of aggregates

Igor Kudrin via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 18 04:09:50 PDT 2022


Author: Igor Kudrin
Date: 2022-07-18T04:08:59-07:00
New Revision: e27f9214c0465fab2d5efa3cf462019836f5eb08

URL: https://github.com/llvm/llvm-project/commit/e27f9214c0465fab2d5efa3cf462019836f5eb08
DIFF: https://github.com/llvm/llvm-project/commit/e27f9214c0465fab2d5efa3cf462019836f5eb08.diff

LOG: [NVPTX][NFC] Simplify printing initialization of aggregates

This simplifies NVPTXAsmPrinter::AggBuffer and its usage.
It is also a preparation for D127504.

Differential Revision: https://reviews.llvm.org/D129773

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
    llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 9977d8ba03009..422e47b3d97d7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1168,31 +1168,22 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           GVar->hasInitializer()) {
         const Constant *Initializer = GVar->getInitializer();
         if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
-          AggBuffer aggBuffer(ElementSize, O, *this);
+          AggBuffer aggBuffer(ElementSize, *this);
           bufferAggregateConstant(Initializer, &aggBuffer);
-          if (aggBuffer.numSymbols) {
-            if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit()) {
-              O << " .u64 ";
-              getSymbol(GVar)->print(O, MAI);
-              O << "[";
-              O << ElementSize / 8;
-            } else {
-              O << " .u32 ";
-              getSymbol(GVar)->print(O, MAI);
-              O << "[";
-              O << ElementSize / 4;
-            }
-            O << "]";
+          if (aggBuffer.numSymbols()) {
+            unsigned int ptrSize = MAI->getCodePointerSize();
+            O << " .u" << ptrSize * 8 << " ";
+            getSymbol(GVar)->print(O, MAI);
+            O << "[" << ElementSize / ptrSize << "] = {";
+            aggBuffer.printWords(O);
+            O << "}";
           } else {
             O << " .b8 ";
             getSymbol(GVar)->print(O, MAI);
-            O << "[";
-            O << ElementSize;
-            O << "]";
+            O << "[" << ElementSize << "] = {";
+            aggBuffer.printBytes(O);
+            O << "}";
           }
-          O << " = {";
-          aggBuffer.print();
-          O << "}";
         } else {
           O << " .b8 ";
           getSymbol(GVar)->print(O, MAI);
@@ -1219,6 +1210,57 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   O << ";\n";
 }
 
+// Print symbol nSym: a global by name, a ConstantExpr as a lowered MCExpr.
+void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
+  const Value *Stripped = Symbols[nSym];
+  const Value *Original = SymbolsBeforeStripping[nSym];
+  const auto *GV = dyn_cast<GlobalValue>(Stripped);
+  if (!GV) {
+    const auto *CE = dyn_cast<ConstantExpr>(Original);
+    if (!CE)
+      llvm_unreachable("symbol type unknown");
+    AP.printMCExpr(*AP.lowerConstantForGV(CE, false), os);
+    return;
+  }
+  // generic() applies only to non-function globals originally in addrspace 0.
+  const auto *PTy = dyn_cast<PointerType>(Original->getType());
+  bool Wrap = EmitGeneric && PTy && PTy->getAddressSpace() == 0 &&
+              !isa<Function>(Stripped);
+  MCSymbol *Name = AP.getSymbol(GV);
+  os << (Wrap ? "generic(" : "");
+  Name->print(os, AP.MAI);
+  os << (Wrap ? ")" : "");
+}
+
+// Emit the buffer as comma-separated decimal byte values (the .b8 form).
+// An empty buffer prints nothing.
+void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
+  const char *separator = "";
+  for (unsigned int idx = 0; idx != size; ++idx, separator = ", ")
+    os << separator << static_cast<unsigned int>(buffer[idx]);
+}
+
+// Print pointer-sized words, emitting recorded symbols at their offsets.
+void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
+  unsigned int ptrSize = AP.MAI->getCodePointerSize();
+  assert(size % ptrSize == 0 && "buffer is not a whole number of words");
+  unsigned int nSym = 0;
+  for (unsigned int pos = 0; pos < size; pos += ptrSize) {
+    if (pos)
+      os << ", ";
+    if (nSym < symbolPosInBuffer.size() && symbolPosInBuffer[nSym] == pos) {
+      printSymbol(nSym++, os);
+      continue;
+    }
+    // Read bytes in little-endian (PTX) order: type-punning is UB and host-endian.
+    uint64_t word = 0;
+    for (unsigned int i = 0; i < ptrSize; ++i)
+      word |= uint64_t(static_cast<uint8_t>(buffer[pos + i])) << (8 * i);
+    os << word;
+  }
+  assert(nSym == symbolPosInBuffer.size() && "symbol not on a word boundary");
+}
+
 void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
   if (localDecls.find(f) == localDecls.end())
     return;

diff  --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
index cd61e99a103a6..ebf9e29b5a627 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -78,7 +78,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
     // Once we have this AggBuffer setup, we can choose how to print
     // it out.
   public:
-    unsigned numSymbols;   // number of symbol addresses
+    // number of symbol addresses
+    unsigned numSymbols() const { return Symbols.size(); }
 
   private:
     const unsigned size;   // size of the buffer in bytes
@@ -94,15 +95,13 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
     // SymbolsBeforeStripping[i].
     SmallVector<const Value *, 4> SymbolsBeforeStripping;
     unsigned curpos;
-    raw_ostream &O;
     NVPTXAsmPrinter &AP;
     bool EmitGeneric;
 
   public:
-    AggBuffer(unsigned size, raw_ostream &O, NVPTXAsmPrinter &AP)
-        : size(size), buffer(size), O(O), AP(AP) {
+    AggBuffer(unsigned size, NVPTXAsmPrinter &AP)
+        : size(size), buffer(size), AP(AP) {
       curpos = 0;
-      numSymbols = 0;
       EmitGeneric = AP.EmitGeneric;
     }
 
@@ -135,63 +134,13 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
       symbolPosInBuffer.push_back(curpos);
       Symbols.push_back(GVar);
       SymbolsBeforeStripping.push_back(GVarBeforeStripping);
-      numSymbols++;
     }
 
-    void print() {
-      if (numSymbols == 0) {
-        // print out in bytes
-        for (unsigned i = 0; i < size; i++) {
-          if (i)
-            O << ", ";
-          O << (unsigned int) buffer[i];
-        }
-      } else {
-        // print out in 4-bytes or 8-bytes
-        unsigned int pos = 0;
-        unsigned int nSym = 0;
-        unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
-        unsigned int nBytes = 4;
-        if (static_cast<const NVPTXTargetMachine &>(AP.TM).is64Bit())
-          nBytes = 8;
-        for (pos = 0; pos < size; pos += nBytes) {
-          if (pos)
-            O << ", ";
-          if (pos == nextSymbolPos) {
-            const Value *v = Symbols[nSym];
-            const Value *v0 = SymbolsBeforeStripping[nSym];
-            if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
-              MCSymbol *Name = AP.getSymbol(GVar);
-              PointerType *PTy = dyn_cast<PointerType>(v0->getType());
-              bool IsNonGenericPointer = false; // Is v0 a non-generic pointer?
-              if (PTy && PTy->getAddressSpace() != 0) {
-                IsNonGenericPointer = true;
-              }
-              if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) {
-                O << "generic(";
-                Name->print(O, AP.MAI);
-                O << ")";
-              } else {
-                Name->print(O, AP.MAI);
-              }
-            } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
-              const MCExpr *Expr =
-                AP.lowerConstantForGV(cast<Constant>(CExpr), false);
-              AP.printMCExpr(*Expr, O);
-            } else
-              llvm_unreachable("symbol type unknown");
-            nSym++;
-            if (nSym >= numSymbols)
-              nextSymbolPos = size + 1;
-            else
-              nextSymbolPos = symbolPosInBuffer[nSym];
-          } else if (nBytes == 4)
-            O << *(unsigned int *)(&buffer[pos]);
-          else
-            O << *(unsigned long long *)(&buffer[pos]);
-        }
-      }
-    }
+    void printBytes(raw_ostream &os);
+    void printWords(raw_ostream &os);
+
+  private:
+    void printSymbol(unsigned nSym, raw_ostream &os);
   };
 
   friend class AggBuffer;


        


More information about the llvm-commits mailing list