[Mlir-commits] [mlir] [mlir][memref][transform] Add new alloca_to_global op. (PR #66511)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Thu Sep 21 06:13:10 PDT 2023


Ingo Müller <ingomueller at google.com>
Message-ID:
In-Reply-To: <llvm/llvm-project/pull/66511/mlir at github.com>


================
@@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
 }
 
+//===----------------------------------------------------------------------===//
+// AllocaToGlobalOp
+//===----------------------------------------------------------------------===//
+
+/// Returns a symbol name derived from `prefix` that `module.lookupSymbol`
+/// does not resolve, i.e., that is not yet taken in `module`'s symbol table.
+///
+/// `prefix` itself is returned if it is free. Otherwise the candidates
+/// `prefix_0`, `prefix_1`, ... are tried in order and the first free one is
+/// returned.
+///
+/// The result is returned by value. `candidateName` may point into the local
+/// `candidateNameStorage` buffer, so it is copied before that buffer dies.
+static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix,
+                                             ModuleOp module) {
+  llvm::SmallString<64> candidateNameStorage;
+  StringRef candidateName(prefix);
+  int uniqueNumber = 0;
+  while (module.lookupSymbol(candidateName)) {
+    // Rebuild the candidate in the reused buffer. The Twine only references
+    // `prefix` and `uniqueNumber`, never the buffer, so clearing it first is
+    // safe.
+    candidateNameStorage.clear();
+    candidateName = (prefix + Twine("_") + Twine(uniqueNumber))
+                        .toStringRef(candidateNameStorage);
+    ++uniqueNumber;
+  }
+  return candidateName;
+}
+
+DiagnosedSilenceableFailure
+transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
+                                         transform::TransformResults &results,
+                                         transform::TransformState &state) {
+  auto allocaOps = state.getPayloadOps(getAlloca());
+
+  SmallVector<memref::GlobalOp> globalOps;
+  SmallVector<memref::GetGlobalOp> getGlobalOps;
+
+  // Get `builtin.module`.
+  auto moduleOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(moduleOps)) {
+    return emitDefiniteFailure()
+           << Twine("expected exactly one 'module' payload, but found ") +
+                  std::to_string(llvm::range_size(moduleOps));
+  }
+  ModuleOp module = cast<ModuleOp>(*moduleOps.begin());
+
+  // Transform `memref.alloca`s.
+  for (auto *op : allocaOps) {
+    auto alloca = cast<memref::AllocaOp>(op);
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = alloca->getLoc();
+
+    memref::GlobalOp globalOp;
+    {
+      // Insert a `memref.global` at the beginning of the module.
+      if (module != alloca->getParentOfType<ModuleOp>()) {
+        return emitDefiniteFailure()
+               << "expected 'alloca' payload to be inside 'module' payload";
+      }
+      IRRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&module.getBodyRegion().front());
+      Type resultType = alloca.getResult().getType();
+      llvm::SmallString<64> symName = getUniqueSymbol("alloca", module);
+      // XXX: Add a better builder for this.
+      globalOp = rewriter.create<memref::GlobalOp>(
+          loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"),
+          TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+    }
+
+    // Replace the `memref.alloca` with a `memref.get_global` accessing the
+    // global symbol inserted above.
+    rewriter.setInsertionPoint(alloca);
+    auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
+        alloca, globalOp.getType(), globalOp.getName());
+
+    globalOps.push_back(globalOp);
+    getGlobalOps.push_back(getGlobalOp);
----------------
ftynse wrote:

The previous version was doing it anyway IIRC. If we only need it for reserve, then let's drop the reserve.

https://github.com/llvm/llvm-project/pull/66511


More information about the Mlir-commits mailing list