[llvm] [WebAssembly] Allocate MCSymbolWasm data on MCContext (PR #85866)

Tim Neumann via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 1 23:21:02 PDT 2024


https://github.com/TimNN updated https://github.com/llvm/llvm-project/pull/85866

>From ddbd7f74be7c580fac8672cf5950c5e5edbf9ed1 Mon Sep 17 00:00:00 2001
From: Tim Neumann <timnn at google.com>
Date: Tue, 19 Mar 2024 21:23:02 +0100
Subject: [PATCH] [WebAssembly] Allocate MCSymbolWasm data on MCContext

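Fixes #85578, a use-after-free caused by some `MCSymbolWasm` data (function/tag signatures and import/export names) being freed too early. That data was owned by `WebAssemblyAsmParser` and `WebAssemblyAsmPrinter` instances, which can be destroyed while the `MCSymbolWasm` objects (owned by `MCContext`) still reference it, e.g. for module-level inline assembly. This patch allocates the data on the `MCContext` instead, so it lives as long as the symbols that point to it.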
---
 llvm/include/llvm/MC/MCContext.h              | 17 +++++++
 llvm/lib/MC/MCContext.cpp                     |  5 +++
 .../AsmParser/WebAssemblyAsmParser.cpp        | 45 ++++++-------------
 .../WebAssembly/WebAssemblyAsmPrinter.cpp     | 29 +++++-------
 .../WebAssembly/WebAssemblyAsmPrinter.h       | 12 -----
 .../WebAssembly/WebAssemblyMCInstLower.cpp    | 15 +++----
 .../WebAssemblyMachineFunctionInfo.cpp        |  6 +--
 .../WebAssemblyMachineFunctionInfo.h          |  6 +--
 llvm/test/MC/WebAssembly/module-asm.ll        | 25 +++++++++++
 9 files changed, 86 insertions(+), 74 deletions(-)
 create mode 100644 llvm/test/MC/WebAssembly/module-asm.ll

diff --git a/llvm/include/llvm/MC/MCContext.h b/llvm/include/llvm/MC/MCContext.h
index 3f585d4d2efaf5..ef9833cdf2b070 100644
--- a/llvm/include/llvm/MC/MCContext.h
+++ b/llvm/include/llvm/MC/MCContext.h
@@ -26,6 +26,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/MD5.h"
+#include "llvm/Support/StringSaver.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <cassert>
@@ -70,6 +71,10 @@ class SMLoc;
 class SourceMgr;
 enum class EmitDwarfUnwindType;
 
+namespace wasm {
+struct WasmSignature;
+}
+
 /// Context object for machine code objects.  This class owns all of the
 /// sections that it creates.
 ///
@@ -139,6 +144,8 @@ class MCContext {
   SpecificBumpPtrAllocator<MCSectionXCOFF> XCOFFAllocator;
   SpecificBumpPtrAllocator<MCInst> MCInstAllocator;
 
+  SpecificBumpPtrAllocator<wasm::WasmSignature> WasmSignatureAllocator;
+
   /// Bindings of names to symbols.
   SymbolTable Symbols;
 
@@ -538,6 +545,10 @@ class MCContext {
   /// inline assembly.
   void registerInlineAsmLabel(MCSymbol *Sym);
 
+  /// Allocates and returns a new `WasmSignature` instance (with empty parameter
+  /// and return type lists).
+  wasm::WasmSignature *createWasmSignature();
+
   /// @}
 
   /// \name Section Management
@@ -850,6 +861,12 @@ class MCContext {
 
   void deallocate(void *Ptr) {}
 
+  /// Allocates a copy of the given string on the allocator managed by this
+  /// context and returns the result.
+  StringRef allocateString(StringRef s) {
+    return StringSaver(Allocator).save(s);
+  }
+
   bool hadError() { return HadError; }
   void diagnose(const SMDiagnostic &SMD);
   void reportError(SMLoc L, const Twine &Msg);
diff --git a/llvm/lib/MC/MCContext.cpp b/llvm/lib/MC/MCContext.cpp
index ba5cefaf18c1fd..3aee96fdf57fc2 100644
--- a/llvm/lib/MC/MCContext.cpp
+++ b/llvm/lib/MC/MCContext.cpp
@@ -147,6 +147,7 @@ void MCContext::reset() {
   XCOFFAllocator.DestroyAll();
   MCInstAllocator.DestroyAll();
   SPIRVAllocator.DestroyAll();
+  WasmSignatureAllocator.DestroyAll();
 
   MCSubtargetAllocator.DestroyAll();
   InlineAsmUsedLabelNames.clear();
@@ -375,6 +376,10 @@ void MCContext::registerInlineAsmLabel(MCSymbol *Sym) {
   InlineAsmUsedLabelNames[Sym->getName()] = Sym;
 }
 
+wasm::WasmSignature *MCContext::createWasmSignature() {
+  return new (WasmSignatureAllocator.Allocate()) wasm::WasmSignature;
+}
+
 MCSymbolXCOFF *
 MCContext::createXCOFFSymbolImpl(const StringMapEntry<bool> *Name,
                                  bool IsTemporary) {
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
index 020c0d6229d22d..8e2063121e00b1 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
@@ -196,10 +196,6 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
   MCAsmParser &Parser;
   MCAsmLexer &Lexer;
 
-  // Much like WebAssemblyAsmPrinter in the backend, we have to own these.
-  std::vector<std::unique_ptr<wasm::WasmSignature>> Signatures;
-  std::vector<std::unique_ptr<std::string>> Names;
-
   // Order of labels, directives and instructions in a .s file have no
   // syntactical enforcement. This class is a callback from the actual parser,
   // and yet we have to be feeding data to the streamer in a very particular
@@ -287,16 +283,6 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
     return Parser.Error(Loc.isValid() ? Loc : Lexer.getTok().getLoc(), Msg);
   }
 
-  void addSignature(std::unique_ptr<wasm::WasmSignature> &&Sig) {
-    Signatures.push_back(std::move(Sig));
-  }
-
-  StringRef storeName(StringRef Name) {
-    std::unique_ptr<std::string> N = std::make_unique<std::string>(Name);
-    Names.push_back(std::move(N));
-    return *Names.back();
-  }
-
   std::pair<StringRef, StringRef> nestingString(NestingType NT) {
     switch (NT) {
     case Function:
@@ -640,21 +626,20 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
       // represent as a signature, such that we can re-build this signature,
       // attach it to an anonymous symbol, which is what WasmObjectWriter
       // expects to be able to recreate the actual unique-ified type indices.
+      auto &Ctx = getContext();
       auto Loc = Parser.getTok();
-      auto Signature = std::make_unique<wasm::WasmSignature>();
-      if (parseSignature(Signature.get()))
+      auto Signature = Ctx.createWasmSignature();
+      if (parseSignature(Signature))
         return true;
       // Got signature as block type, don't need more
-      TC.setLastSig(*Signature.get());
+      TC.setLastSig(*Signature);
       if (ExpectBlockType)
-        NestingStack.back().Sig = *Signature.get();
+        NestingStack.back().Sig = *Signature;
       ExpectBlockType = false;
-      auto &Ctx = getContext();
       // The "true" here will cause this to be a nameless symbol.
       MCSymbol *Sym = Ctx.createTempSymbol("typeindex", true);
       auto *WasmSym = cast<MCSymbolWasm>(Sym);
-      WasmSym->setSignature(Signature.get());
-      addSignature(std::move(Signature));
+      WasmSym->setSignature(Signature);
       WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
       const MCExpr *Expr = MCSymbolRefExpr::create(
           WasmSym, MCSymbolRefExpr::VK_WASM_TYPEINDEX, Ctx);
@@ -887,12 +872,11 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
         CurrentState = FunctionStart;
         LastFunctionLabel = WasmSym;
       }
-      auto Signature = std::make_unique<wasm::WasmSignature>();
-      if (parseSignature(Signature.get()))
+      auto Signature = Ctx.createWasmSignature();
+      if (parseSignature(Signature))
         return ParseStatus::Failure;
       TC.funcDecl(*Signature);
-      WasmSym->setSignature(Signature.get());
-      addSignature(std::move(Signature));
+      WasmSym->setSignature(Signature);
       WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
       TOut.emitFunctionType(WasmSym);
       // TODO: backend also calls TOut.emitIndIdx, but that is not implemented.
@@ -909,7 +893,7 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
       if (ExportName.empty())
         return ParseStatus::Failure;
       auto WasmSym = cast<MCSymbolWasm>(Ctx.getOrCreateSymbol(SymName));
-      WasmSym->setExportName(storeName(ExportName));
+      WasmSym->setExportName(Ctx.allocateString(ExportName));
       TOut.emitExportName(WasmSym, ExportName);
       return expect(AsmToken::EndOfStatement, "EOL");
     }
@@ -924,7 +908,7 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
       if (ImportModule.empty())
         return ParseStatus::Failure;
       auto WasmSym = cast<MCSymbolWasm>(Ctx.getOrCreateSymbol(SymName));
-      WasmSym->setImportModule(storeName(ImportModule));
+      WasmSym->setImportModule(Ctx.allocateString(ImportModule));
       TOut.emitImportModule(WasmSym, ImportModule);
       return expect(AsmToken::EndOfStatement, "EOL");
     }
@@ -939,7 +923,7 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
       if (ImportName.empty())
         return ParseStatus::Failure;
       auto WasmSym = cast<MCSymbolWasm>(Ctx.getOrCreateSymbol(SymName));
-      WasmSym->setImportName(storeName(ImportName));
+      WasmSym->setImportName(Ctx.allocateString(ImportName));
       TOut.emitImportName(WasmSym, ImportName);
       return expect(AsmToken::EndOfStatement, "EOL");
     }
@@ -949,11 +933,10 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
       if (SymName.empty())
         return ParseStatus::Failure;
       auto WasmSym = cast<MCSymbolWasm>(Ctx.getOrCreateSymbol(SymName));
-      auto Signature = std::make_unique<wasm::WasmSignature>();
+      auto Signature = Ctx.createWasmSignature();
       if (parseRegTypeList(Signature->Params))
         return ParseStatus::Failure;
-      WasmSym->setSignature(Signature.get());
-      addSignature(std::move(Signature));
+      WasmSym->setSignature(Signature);
       WasmSym->setType(wasm::WASM_SYMBOL_TYPE_TAG);
       TOut.emitTagType(WasmSym);
       // TODO: backend also calls TOut.emitIndIdx, but that is not implemented.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
index 03897b551ecaee..3524abba8990aa 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
@@ -266,10 +266,10 @@ MCSymbol *WebAssemblyAsmPrinter::getOrCreateWasmSymbol(StringRef Name) {
     WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
     WebAssembly::getLibcallSignature(Subtarget, Name, Returns, Params);
   }
-  auto Signature = std::make_unique<wasm::WasmSignature>(std::move(Returns),
-                                                         std::move(Params));
-  WasmSym->setSignature(Signature.get());
-  addSignature(std::move(Signature));
+  auto Signature = OutContext.createWasmSignature();
+  Signature->Returns = std::move(Returns);
+  Signature->Params = std::move(Params);
+  WasmSym->setSignature(Signature);
 
   return WasmSym;
 }
@@ -338,11 +338,11 @@ void WebAssemblyAsmPrinter::emitDecls(const Module &M) {
     // and thus also contain a signature, but we need to get the signature
     // anyway here in case it is an invoke that has not yet been created. We
     // will discard it later if it turns out not to be necessary.
-    auto Signature = signatureFromMVTs(Results, Params);
+    auto Signature = signatureFromMVTs(OutContext, Results, Params);
     bool InvokeDetected = false;
     auto *Sym = getMCSymbolForFunction(
         &F, WebAssembly::WasmEnableEmEH || WebAssembly::WasmEnableEmSjLj,
-        Signature.get(), InvokeDetected);
+        Signature, InvokeDetected);
 
     // Multiple functions can be mapped to the same invoke symbol. For
     // example, two IR functions '__invoke_void_i8*' and '__invoke_void_i32'
@@ -353,11 +353,7 @@ void WebAssemblyAsmPrinter::emitDecls(const Module &M) {
 
     Sym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
     if (!Sym->getSignature()) {
-      Sym->setSignature(Signature.get());
-      addSignature(std::move(Signature));
-    } else {
-      // This symbol has already been created and had a signature. Discard it.
-      Signature.reset();
+      Sym->setSignature(Signature);
     }
 
     getTargetStreamer()->emitFunctionType(Sym);
@@ -365,7 +361,7 @@ void WebAssemblyAsmPrinter::emitDecls(const Module &M) {
     if (F.hasFnAttribute("wasm-import-module")) {
       StringRef Name =
           F.getFnAttribute("wasm-import-module").getValueAsString();
-      Sym->setImportModule(storeName(Name));
+      Sym->setImportModule(OutContext.allocateString(Name));
       getTargetStreamer()->emitImportModule(Sym, Name);
     }
     if (F.hasFnAttribute("wasm-import-name")) {
@@ -375,14 +371,14 @@ void WebAssemblyAsmPrinter::emitDecls(const Module &M) {
           InvokeDetected
               ? Sym->getName()
               : F.getFnAttribute("wasm-import-name").getValueAsString();
-      Sym->setImportName(storeName(Name));
+      Sym->setImportName(OutContext.allocateString(Name));
       getTargetStreamer()->emitImportName(Sym, Name);
     }
 
     if (F.hasFnAttribute("wasm-export-name")) {
       auto *Sym = cast<MCSymbolWasm>(getSymbol(&F));
       StringRef Name = F.getFnAttribute("wasm-export-name").getValueAsString();
-      Sym->setExportName(storeName(Name));
+      Sym->setExportName(OutContext.allocateString(Name));
       getTargetStreamer()->emitExportName(Sym, Name);
     }
   }
@@ -618,10 +614,9 @@ void WebAssemblyAsmPrinter::emitFunctionBodyStart() {
   SmallVector<MVT, 4> ParamVTs;
   computeSignatureVTs(F.getFunctionType(), &F, F, TM, ParamVTs, ResultVTs);
 
-  auto Signature = signatureFromMVTs(ResultVTs, ParamVTs);
+  auto Signature = signatureFromMVTs(OutContext, ResultVTs, ParamVTs);
   auto *WasmSym = cast<MCSymbolWasm>(CurrentFnSym);
-  WasmSym->setSignature(Signature.get());
-  addSignature(std::move(Signature));
+  WasmSym->setSignature(Signature);
   WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
 
   getTargetStreamer()->emitFunctionType(WasmSym);
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h
index c30e0155c81e72..6a544abe6ce830 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h
@@ -22,17 +22,8 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter {
   const WebAssemblySubtarget *Subtarget;
   const MachineRegisterInfo *MRI;
   WebAssemblyFunctionInfo *MFI;
-  // TODO: Do the uniquing of Signatures here instead of ObjectFileWriter?
-  std::vector<std::unique_ptr<wasm::WasmSignature>> Signatures;
-  std::vector<std::unique_ptr<std::string>> Names;
   bool signaturesEmitted = false;
 
-  StringRef storeName(StringRef Name) {
-    std::unique_ptr<std::string> N = std::make_unique<std::string>(Name);
-    Names.push_back(std::move(N));
-    return *Names.back();
-  }
-
 public:
   explicit WebAssemblyAsmPrinter(TargetMachine &TM,
                                  std::unique_ptr<MCStreamer> Streamer)
@@ -44,9 +35,6 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter {
   }
 
   const WebAssemblySubtarget &getSubtarget() const { return *Subtarget; }
-  void addSignature(std::unique_ptr<wasm::WasmSignature> &&Sig) {
-    Signatures.push_back(std::move(Sig));
-  }
 
   //===------------------------------------------------------------------===//
   // MachineFunctionPass Implementation.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index f6e24f7aaece01..431dc7f33ac89f 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -73,14 +73,13 @@ WebAssemblyMCInstLower::GetGlobalAddressSymbol(const MachineOperand &MO) const {
   SmallVector<MVT, 4> ParamMVTs;
   const auto *const F = dyn_cast<Function>(Global);
   computeSignatureVTs(FuncTy, F, CurrentFunc, TM, ParamMVTs, ResultMVTs);
-  auto Signature = signatureFromMVTs(ResultMVTs, ParamMVTs);
+  auto Signature = signatureFromMVTs(Ctx, ResultMVTs, ParamMVTs);
 
   bool InvokeDetected = false;
   auto *WasmSym = Printer.getMCSymbolForFunction(
       F, WebAssembly::WasmEnableEmEH || WebAssembly::WasmEnableEmSjLj,
-      Signature.get(), InvokeDetected);
-  WasmSym->setSignature(Signature.get());
-  Printer.addSignature(std::move(Signature));
+      Signature, InvokeDetected);
+  WasmSym->setSignature(Signature);
   WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
   return WasmSym;
 }
@@ -142,12 +141,12 @@ MCOperand WebAssemblyMCInstLower::lowerSymbolOperand(const MachineOperand &MO,
 MCOperand WebAssemblyMCInstLower::lowerTypeIndexOperand(
     SmallVectorImpl<wasm::ValType> &&Returns,
     SmallVectorImpl<wasm::ValType> &&Params) const {
-  auto Signature = std::make_unique<wasm::WasmSignature>(std::move(Returns),
-                                                         std::move(Params));
+  auto Signature = Ctx.createWasmSignature();
+  Signature->Returns = std::move(Returns);
+  Signature->Params = std::move(Params);
   MCSymbol *Sym = Printer.createTempSymbol("typeindex");
   auto *WasmSym = cast<MCSymbolWasm>(Sym);
-  WasmSym->setSignature(Signature.get());
-  Printer.addSignature(std::move(Signature));
+  WasmSym->setSignature(Signature);
   WasmSym->setType(wasm::WASM_SYMBOL_TYPE_FUNCTION);
   const MCExpr *Expr =
       MCSymbolRefExpr::create(WasmSym, MCSymbolRefExpr::VK_WASM_TYPEINDEX, Ctx);
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.cpp
index d17394eede7725..6f4e7d876c693e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.cpp
@@ -111,10 +111,10 @@ void llvm::valTypesFromMVTs(ArrayRef<MVT> In,
     Out.push_back(WebAssembly::toValType(Ty));
 }
 
-std::unique_ptr<wasm::WasmSignature>
-llvm::signatureFromMVTs(const SmallVectorImpl<MVT> &Results,
+wasm::WasmSignature *
+llvm::signatureFromMVTs(MCContext &Ctx, const SmallVectorImpl<MVT> &Results,
                         const SmallVectorImpl<MVT> &Params) {
-  auto Sig = std::make_unique<wasm::WasmSignature>();
+  auto Sig = Ctx.createWasmSignature();
   valTypesFromMVTs(Results, Sig->Returns);
   valTypesFromMVTs(Params, Sig->Params);
   return Sig;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h
index 37059188a7614e..6c9824bbd5d917 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h
@@ -171,9 +171,9 @@ void computeSignatureVTs(const FunctionType *Ty, const Function *TargetFunc,
 
 void valTypesFromMVTs(ArrayRef<MVT> In, SmallVectorImpl<wasm::ValType> &Out);
 
-std::unique_ptr<wasm::WasmSignature>
-signatureFromMVTs(const SmallVectorImpl<MVT> &Results,
-                  const SmallVectorImpl<MVT> &Params);
+wasm::WasmSignature *signatureFromMVTs(MCContext &Ctx,
+                                       const SmallVectorImpl<MVT> &Results,
+                                       const SmallVectorImpl<MVT> &Params);
 
 namespace yaml {
 
diff --git a/llvm/test/MC/WebAssembly/module-asm.ll b/llvm/test/MC/WebAssembly/module-asm.ll
new file mode 100644
index 00000000000000..d451bec0020900
--- /dev/null
+++ b/llvm/test/MC/WebAssembly/module-asm.ll
@@ -0,0 +1,25 @@
+; Ensure that symbols from module ASM are properly exported.
+;
+; Regression test for https://github.com/llvm/llvm-project/issues/85578.
+
+; RUN: llc -mtriple=wasm32-unknown-unknown -filetype=obj %s -o - | obj2yaml | FileCheck %s
+
+module asm "test_func:"
+module asm "    .globl test_func"
+module asm "    .functype test_func (i32) -> (i32)"
+module asm "    .export_name test_func, test_export"
+module asm "    end_function"
+
+; CHECK:       - Type:            TYPE
+; CHECK-NEXT:      Signatures:
+; CHECK-NEXT:        - Index:           0
+; CHECK-NEXT:          ParamTypes:
+; CHECK-NEXT:            - I32
+; CHECK-NEXT:          ReturnTypes:
+; CHECK-NEXT:            - I32
+
+; CHECK:        - Type:            EXPORT
+; CHECK-NEXT:     Exports:
+; CHECK-NEXT:       - Name:            test_export
+; CHECK-NEXT:         Kind:            FUNCTION
+; CHECK-NEXT:         Index:           0



More information about the llvm-commits mailing list