[Mlir-commits] [mlir] [mlir][memref][transform] Add new alloca_to_global op. (PR #66511)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Sep 20 10:32:39 PDT 2023


================
@@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
 }
 
+//===----------------------------------------------------------------------===//
+// AllocaToGlobalOp
+//===----------------------------------------------------------------------===//
+
+/// Returns a symbol name derived from `prefix` that is not yet defined in the
+/// symbol table of `module`. `prefix` itself is tried first. On a collision,
+/// `prefix_0`, `prefix_1`, ... are tried in order until an unused name is
+/// found.
+///
+/// The result is returned by value rather than as a StringRef because the
+/// suffixed candidates are materialized in a local buffer that does not
+/// outlive this function.
+static llvm::SmallString<64> getUniqueSymbol(StringRef prefix,
+                                             ModuleOp module) {
+  llvm::SmallString<64> candidate(prefix);
+  // `lookupSymbol` returns null when no symbol with this name exists.
+  for (int suffix = 0; module.lookupSymbol(candidate); ++suffix) {
+    // Twine::toVector appends, so reset the buffer before rebuilding it.
+    candidate.clear();
+    (prefix + Twine("_") + Twine(suffix)).toVector(candidate);
+  }
+  return candidate;
+}
+
+DiagnosedSilenceableFailure
+transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
+                                         transform::TransformResults &results,
+                                         transform::TransformState &state) {
+  auto allocaOps = state.getPayloadOps(getAlloca());
+
+  SmallVector<memref::GlobalOp> globalOps;
+  SmallVector<memref::GetGlobalOp> getGlobalOps;
+
+  // Get `builtin.module`.
+  auto moduleOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(moduleOps)) {
+    return emitDefiniteFailure()
+           << Twine("expected exactly one 'module' payload, but found ") +
+                  std::to_string(llvm::range_size(moduleOps));
+  }
+  ModuleOp module = cast<ModuleOp>(*moduleOps.begin());
+
+  // Transform `memref.alloca`s.
+  for (auto *op : allocaOps) {
+    auto alloca = cast<memref::AllocaOp>(op);
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = alloca->getLoc();
+
+    memref::GlobalOp globalOp;
+    {
+      // Insert a `memref.global` at the beginning of the module.
+      if (module != alloca->getParentOfType<ModuleOp>()) {
+        return emitDefiniteFailure()
+               << "expected 'alloca' payload to be inside 'module' payload";
+      }
+      IRRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&module.getBodyRegion().front());
+      Type resultType = alloca.getResult().getType();
+      llvm::SmallString<64> symName = getUniqueSymbol("alloca", module);
+      // XXX: Add a better builder for this.
----------------
ftynse wrote:

Nit: please use `TODO` instead of `XXX` in this comment.

https://github.com/llvm/llvm-project/pull/66511


More information about the Mlir-commits mailing list