[Mlir-commits] [mlir] [mlir][memref][transform] Add new alloca_to_global op. (PR #66511)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Wed Sep 20 10:32:39 PDT 2023
================
@@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
}
+//===----------------------------------------------------------------------===//
+// AllocaToGlobalOp
+//===----------------------------------------------------------------------===//
+
+/// Returns a symbol name derived from `prefix` that is not yet defined in
+/// `module`. This is `prefix` itself if it is free; otherwise it is the first
+/// free name of the form `<prefix>_<N>`, trying N = 0, 1, 2, ... in order.
+///
+/// Note: every call restarts the search at N = 0 and performs one
+/// `lookupSymbol` per candidate. Creating k symbols with the same prefix
+/// therefore costs O(k^2) lookups in total.
+/// TODO: consider letting `SymbolTable` resolve name conflicts instead.
+static llvm::SmallString<64> getUniqueSymbol(StringRef prefix,
+                                             ModuleOp module) {
+  llvm::SmallString<64> candidateNameStorage;
+  // `candidateName` either aliases `prefix` or points into
+  // `candidateNameStorage`. It is copied into the returned SmallString, so
+  // nothing dangles once this function returns.
+  StringRef candidateName = prefix;
+  for (int uniqueNumber = 0; module.lookupSymbol(candidateName);
+       ++uniqueNumber) {
+    candidateNameStorage.clear();
+    candidateName = (prefix + Twine("_") + Twine(uniqueNumber))
+                        .toStringRef(candidateNameStorage);
+  }
+  return candidateName;
+}
+
+DiagnosedSilenceableFailure
+transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto allocaOps = state.getPayloadOps(getAlloca());
+
+ SmallVector<memref::GlobalOp> globalOps;
+ SmallVector<memref::GetGlobalOp> getGlobalOps;
+
+ // Get `builtin.module`.
+ auto moduleOps = state.getPayloadOps(getModule());
+ if (!llvm::hasSingleElement(moduleOps)) {
+ return emitDefiniteFailure()
+ << Twine("expected exactly one 'module' payload, but found ") +
+ std::to_string(llvm::range_size(moduleOps));
+ }
+ ModuleOp module = cast<ModuleOp>(*moduleOps.begin());
+
+ // Transform `memref.alloca`s.
+ for (auto *op : allocaOps) {
+ auto alloca = cast<memref::AllocaOp>(op);
+ MLIRContext *ctx = rewriter.getContext();
+ Location loc = alloca->getLoc();
+
+ memref::GlobalOp globalOp;
+ {
+ // Insert a `memref.global` at the beginning of the module.
+ if (module != alloca->getParentOfType<ModuleOp>()) {
+ return emitDefiniteFailure()
+ << "expected 'alloca' payload to be inside 'module' payload";
+ }
+ IRRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&module.getBodyRegion().front());
+ Type resultType = alloca.getResult().getType();
+ llvm::SmallString<64> symName = getUniqueSymbol("alloca", module);
+ // XXX: Add a better builder for this.
----------------
ftynse wrote:
Nit: XXX->TODO
https://github.com/llvm/llvm-project/pull/66511
More information about the Mlir-commits
mailing list