[Mlir-commits] [mlir] [mlir][memref][transform] Add new alloca_to_global op. (PR #66511)
Ingo Müller
llvmlistbot at llvm.org
Thu Sep 21 03:13:49 PDT 2023
================
@@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
}
+//===----------------------------------------------------------------------===//
+// AllocaToGlobalOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix,
+ ModuleOp module) {
+ llvm::SmallString<64> candidateNameStorage;
+ StringRef candidateName(prefix);
+ int uniqueNumber = 0;
+ while (true) {
+ if (!module.lookupSymbol(candidateName)) {
+ break;
+ }
+ candidateNameStorage.clear();
+ candidateName = (prefix + Twine("_") + Twine(uniqueNumber))
+ .toStringRef(candidateNameStorage);
+ uniqueNumber++;
+ }
+ return candidateName;
+}
+} // namespace
+
+DiagnosedSilenceableFailure
+transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto allocaOps = state.getPayloadOps(getAlloca());
+
+ SmallVector<memref::GlobalOp> globalOps;
+ SmallVector<memref::GetGlobalOp> getGlobalOps;
+
+ // Get `builtin.module`.
+ auto moduleOps = state.getPayloadOps(getModule());
+ if (!llvm::hasSingleElement(moduleOps)) {
+ return emitDefiniteFailure()
+ << Twine("expected exactly one 'module' payload, but found ") +
+ std::to_string(llvm::range_size(moduleOps));
+ }
+ ModuleOp module = cast<ModuleOp>(*moduleOps.begin());
+
+ // Transform `memref.alloca`s.
+ for (auto *op : allocaOps) {
+ auto alloca = cast<memref::AllocaOp>(op);
+ MLIRContext *ctx = rewriter.getContext();
+ Location loc = alloca->getLoc();
+
+ memref::GlobalOp globalOp;
+ {
+ // Insert a `memref.global` at the beginning of the module.
+ if (module != alloca->getParentOfType<ModuleOp>()) {
+ return emitDefiniteFailure()
+ << "expected 'alloca' payload to be inside 'module' payload";
+ }
+ IRRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&module.getBodyRegion().front());
+ Type resultType = alloca.getResult().getType();
+ llvm::SmallString<64> symName = getUniqueSymbol("alloca", module);
+ // XXX: Add a better builder for this.
+ globalOp = rewriter.create<memref::GlobalOp>(
+ loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"),
+ TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+ }
+
+ // Replace the `memref.alloca` with a `memref.get_global` accessing the
+ // global symbol inserted above.
+ rewriter.setInsertionPoint(alloca);
+ auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
+ alloca, globalOp.getType(), globalOp.getName());
+
+ globalOps.push_back(globalOp);
+ getGlobalOps.push_back(getGlobalOp);
----------------
ingomueller-net wrote:
Done in the latest revision. Note, though, that looking up the number of alloca ops isn't a constant-time operation, and calling `llvm::range_size(allocaOps)` could theoretically be more expensive than the `push_back`s.
https://github.com/llvm/llvm-project/pull/66511
More information about the Mlir-commits
mailing list