[Mlir-commits] [mlir] [mlir][MLProgram] Add MLProgram to MemRef bufferization pass (PR #75103)

Mehdi Amini llvmlistbot at llvm.org
Mon Dec 11 21:43:39 PST 2023


================
@@ -0,0 +1,146 @@
+//===- Bufferize.cpp - MLProgram bufferize pass ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a bufferization pass for the MLProgram dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir {
+namespace ml_program {
+#define GEN_PASS_DEF_MLPROGRAMBUFFERIZE
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
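+// Converts an `ml_program.global` into an equivalent `memref.global` at the
+// start of the enclosing module. As a rough sketch (symbol names are
+// illustrative and exact assembly syntax may differ), an op such as
+//
+//   ml_program.global private mutable @state(dense<0> : tensor<4xi32>)
+//       : tensor<4xi32>
+//
+// is rewritten to
+//
+//   memref.global "private" @state : memref<4xi32> = dense<0> : tensor<4xi32>
+//
+// where a mutable global simply omits the `constant` flag.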
+static LogicalResult bufferizeMLProgramGlobalOp(GlobalOp globalOp,
+                                                OpBuilder &builder) {
+  if (!globalOp.getValue().has_value())
+    return globalOp.emitError("global op must have a value");
+
+  auto tensorType = cast<RankedTensorType>(globalOp.getType());
+  auto memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+  builder.setInsertionPointToStart(
+      globalOp->getParentOfType<ModuleOp>().getBody());
+  builder.create<memref::GlobalOp>(
+      globalOp.getLoc(), globalOp.getSymName(),
+      /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
+      /*type=*/memrefType,
+      /*initial_value=*/globalOp.getValue().value(),
+      /*constant=*/!globalOp.getIsMutable(),
+      /*alignment=*/nullptr);
+  return success();
+}
+
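+// Converts an `ml_program.global_load` into a `memref.get_global` on the
+// bufferized global, plus an alloc-and-copy so the produced tensor is backed
+// by a fresh buffer, exposed again as a tensor via `bufferization.to_tensor`
+// with the `restrict` flag. As a rough sketch (symbol names are illustrative
+// and exact assembly syntax may differ),
+//
+//   %t = ml_program.global_load @state : tensor<4xi32>
+//
+// is rewritten to
+//
+//   %g = memref.get_global @state : memref<4xi32>
+//   %a = memref.alloc() : memref<4xi32>
+//   memref.copy %g, %a : memref<4xi32> to memref<4xi32>
+//   %t = bufferization.to_tensor %a restrict : memref<4xi32>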
+static LogicalResult bufferizeMLProgramGlobalLoadOp(GlobalLoadOp globalLoadOp,
+                                                    OpBuilder &builder) {
+  auto loc = globalLoadOp.getLoc();
+  auto tensorType = cast<RankedTensorType>(globalLoadOp.getType());
+  auto memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+  builder.setInsertionPoint(globalLoadOp);
+  Value globalVal = builder.create<memref::GetGlobalOp>(
+      loc, memrefType, globalLoadOp.getGlobalAttr().getLeafReference());
+
+  // We need a copy to guarantee that the produced tensor does not alias with
+  // any other buffer.
+  Value alloc = builder.create<memref::AllocOp>(loc, memrefType, ValueRange{});
+  builder.create<memref::CopyOp>(loc, globalVal, alloc);
+
+  globalVal = builder.create<bufferization::ToTensorOp>(loc, tensorType, alloc,
+                                                        /*restrict=*/true);
+  globalLoadOp->getResult(0).replaceAllUsesWith(globalVal);
+  return success();
+}
+
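+// Converts an `ml_program.global_store` into a copy from the stored tensor's
+// buffer into the bufferized global. As a rough sketch (symbol names are
+// illustrative and exact assembly syntax may differ),
+//
+//   ml_program.global_store @state = %t : tensor<4xi32>
+//
+// is rewritten to
+//
+//   %g = memref.get_global @state : memref<4xi32>
+//   %m = bufferization.to_memref %t : memref<4xi32>
+//   memref.copy %m, %g : memref<4xi32> to memref<4xi32>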
+static LogicalResult
+bufferizeMLProgramGlobalStoreOp(GlobalStoreOp globalStoreOp,
+                                OpBuilder &builder) {
+  auto loc = globalStoreOp.getLoc();
+  auto tensorType = cast<RankedTensorType>(globalStoreOp.getValue().getType());
+  auto memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+  builder.setInsertionPoint(globalStoreOp);
+  Value memref = builder.create<memref::GetGlobalOp>(
+      loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+  Value copyValue = builder.create<bufferization::ToMemrefOp>(
+      loc, memrefType, globalStoreOp.getValue());
+  builder.create<memref::CopyOp>(loc, copyValue, memref);
+  return success();
+}
+
+namespace {
+/// Converts MLProgram operations that work on tensor-type operands or results
+/// to work on buffers.
+class MLProgramBufferize
+    : public impl::MLProgramBufferizeBase<MLProgramBufferize> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    OpBuilder builder(module.getBodyRegion());
+    SmallVector<Operation *> toErase;
+
+    auto walkResult = module.walk([&](GlobalOp op) {
+      if (auto type = dyn_cast<RankedTensorType>(op.getType())) {
+        if (!type.hasStaticShape()) {
+          // The ml_program.global has a dynamically shaped tensor type.
+          op.emitError(
+              "unimplemented: global op bufferization with dynamic shape");
+          return WalkResult::interrupt();
+        }
+      } else {
+        // The ml_program.global has a non-tensor type.
+        op.emitError("unsupported global op type");
+        return WalkResult::interrupt();
+      }
+
+      if (failed(bufferizeMLProgramGlobalOp(op, builder))) {
+        op.emitError("bufferization for this op failed");
+        return WalkResult::interrupt();
+      }
+      toErase.push_back(op);
+      return WalkResult::advance();
+    });
+
+    if (walkResult.wasInterrupted())
+      return signalPassFailure();
+
+    module.walk([&](GlobalLoadOp op) {
+      if (failed(bufferizeMLProgramGlobalLoadOp(op, builder))) {
+        op.emitError("bufferization for this op failed");
+        return;
+      }
+      toErase.push_back(op);
+    });
+
+    module.walk([&](GlobalStoreOp op) {
+      if (failed(bufferizeMLProgramGlobalStoreOp(op, builder))) {
+        op.emitError("bufferization for this op failed");
+        return;
+      }
+      toErase.push_back(op);
+    });
+
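+    // Erase in reverse order so that the loads/stores (collected after their
+    // globals) go away before the globals whose symbols they reference.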
+    for (auto *op : llvm::reverse(toErase))
+      op->erase();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass() {
+  return std::make_unique<MLProgramBufferize>();
+}
----------------
joker-eph wrote:

(similarly: to be removed)

https://github.com/llvm/llvm-project/pull/75103

