[Mlir-commits] [mlir] [mlir][MLProgram] Add MLProgram to MemRef bufferization pass (PR #75103)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 13:32:08 PST 2023


https://github.com/ryan-holt-1 created https://github.com/llvm/llvm-project/pull/75103

There is currently no lowering out of `ml_program` in the LLVM repository. This change adds a lowering to `memref` so that `ml_program` code can be lowered all the way to LLVM. This lowering was taken from the [reference backend in torch-mlir](https://github.com/llvm/torch-mlir/commit/f41695360019bde71d52ca7548944d5488779e12).

I first tried implementing the `BufferizableOpInterface` for the `ml_program` ops instead of adding a new pass, but that did not work because `OneShotBufferize` does not visit global ops outside of a function.

>From 922ed871e1c2054137720a972c1a56e4d7c9a328 Mon Sep 17 00:00:00 2001
From: ryanholt <ryanholt at mathworks.com>
Date: Mon, 11 Dec 2023 15:46:28 -0500
Subject: [PATCH] [mlir][MLProgram] Add MLProgram to MemRef bufferization pass

There is currently no lowering out of MLProgram in the LLVM repository.
This change adds a lowering to MemRef so that MLProgram code can be
lowered all the way to LLVM. This lowering was taken from the reference
backend in torch-mlir:
https://github.com/llvm/torch-mlir/commit/f41695360019bde71d52ca7548944d5488779e12
I first tried implementing the BufferizableOpInterface instead of adding
a new pass, but that did not work because OneShotBufferize does not
visit global ops outside of a function.
---
 .../Dialect/MLProgram/Transforms/Passes.h     |   2 +
 .../Dialect/MLProgram/Transforms/Passes.td    |   8 +
 .../MLProgram/Transforms/Bufferize.cpp        | 146 ++++++++++++++++++
 .../MLProgram/Transforms/CMakeLists.txt       |   1 +
 mlir/test/Dialect/MLProgram/bufferize.mlir    |  81 ++++++++++
 5 files changed, 238 insertions(+)
 create mode 100644 mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp
 create mode 100644 mlir/test/Dialect/MLProgram/bufferize.mlir

diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
index 894e35e52724e..75c107d917188 100644
--- a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
@@ -23,6 +23,10 @@ namespace ml_program {
 // Registration
 //===----------------------------------------------------------------------===//
 
+/// Creates a pass that converts `ml_program.global`, `global_load` and
+/// `global_store` ops on statically shaped ranked tensors into `memref` ops.
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
+
 std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass();
 
 /// Generate the code for registering passes.
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
index defe8191cb905..617c24a4d8641 100644
--- a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
@@ -11,6 +11,30 @@
 
 include "mlir/Pass/PassBase.td"
 
+def MLProgramBufferize: Pass<"mlprogram-bufferize", "ModuleOp"> {
+  let summary = "Bufferize the MLProgram dialect ops";
+  let description = [{
+    Converts `ml_program.global`, `ml_program.global_load` and
+    `ml_program.global_store` operations on ranked tensors into `memref`
+    dialect operations:
+
+    * `ml_program.global` becomes a `memref.global` with the same symbol
+      name, visibility and initial value; it is `constant` if immutable.
+    * `ml_program.global_load` becomes a `memref.get_global` whose contents
+      are copied into a fresh `memref.alloc` buffer and wrapped in a
+      `bufferization.to_tensor` op.
+    * `ml_program.global_store` copies the stored tensor (converted with
+      `bufferization.to_memref`) into the `memref.get_global` buffer.
+
+    Only globals with a statically shaped ranked tensor type and an initial
+    value are supported; any other global makes the pass fail.
+  }];
+  let constructor = "mlir::ml_program::createMLProgramBufferizePass()";
+  let dependentDialects = [
+    "bufferization::BufferizationDialect", "memref::MemRefDialect"
+  ];
+}
+
 def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
   let summary = "Optimize `ml_program` global operations for read and store";
   let description = [{
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp b/mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..c462550c706e2
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp
@@ -0,0 +1,194 @@
+//===- Bufferize.cpp - MLProgram bufferize pass ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a bufferization pass for the MLProgram dialect. It
+// rewrites ml_program.global, ml_program.global_load and
+// ml_program.global_store ops on statically shaped ranked tensors into
+// memref.global / memref.get_global based IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir {
+namespace ml_program {
+#define GEN_PASS_DEF_MLPROGRAMBUFFERIZE
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+/// Returns the identity-layout memref type with the same shape and element
+/// type as `tensorType`.
+static MemRefType getMemRefTypeFor(RankedTensorType tensorType) {
+  return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+}
+
+/// Creates a `memref.global` equivalent to `globalOp`.
+///
+/// The new op is inserted right before `globalOp`, so it lives in the same
+/// symbol table and keeps the original declaration order. It reuses the
+/// symbol name, visibility and initial value, and is `constant` when the
+/// source global is immutable. The caller erases `globalOp` afterwards; until
+/// then both ops carry the same symbol name.
+///
+/// Expects the caller to have checked that the type is a statically shaped
+/// ranked tensor. Fails with a diagnostic on `globalOp` if there is no
+/// initial value, or if the value is not an elements attribute (for example
+/// an `#ml_program.extern` placeholder), which `memref.global` cannot hold.
+static LogicalResult bufferizeMLProgramGlobalOp(GlobalOp globalOp,
+                                                OpBuilder &builder) {
+  auto value = globalOp.getValue();
+  if (!value.has_value())
+    return globalOp.emitError("global op must have a value");
+  if (!isa<ElementsAttr>(*value))
+    return globalOp.emitError(
+        "unsupported global op value: expected an elements attribute");
+
+  auto tensorType = cast<RankedTensorType>(globalOp.getType());
+  MemRefType memrefType = getMemRefTypeFor(tensorType);
+
+  builder.setInsertionPoint(globalOp);
+  builder.create<memref::GlobalOp>(
+      globalOp.getLoc(), globalOp.getSymName(),
+      /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
+      /*type=*/memrefType,
+      /*initial_value=*/*value,
+      /*constant=*/!globalOp.getIsMutable(),
+      /*alignment=*/nullptr);
+  return success();
+}
+
+/// Replaces the result of `globalLoadOp` with a private copy of the global.
+///
+/// Emits `memref.get_global`, copies its contents into a new `memref.alloc`
+/// buffer and wraps that buffer in a `bufferization.to_tensor` marked
+/// `restrict`. The load op itself is left in place, without uses, for the
+/// caller to erase.
+static LogicalResult bufferizeMLProgramGlobalLoadOp(GlobalLoadOp globalLoadOp,
+                                                    OpBuilder &builder) {
+  Location loc = globalLoadOp.getLoc();
+  auto tensorType = cast<RankedTensorType>(globalLoadOp.getType());
+  MemRefType memrefType = getMemRefTypeFor(tensorType);
+
+  builder.setInsertionPoint(globalLoadOp);
+  Value globalBuffer = builder.create<memref::GetGlobalOp>(
+      loc, memrefType, globalLoadOp.getGlobalAttr().getLeafReference());
+
+  // Copy into a fresh buffer so that the produced tensor does not alias the
+  // global or any other buffer, which is what `restrict` below promises.
+  Value alloc = builder.create<memref::AllocOp>(loc, memrefType, ValueRange{});
+  builder.create<memref::CopyOp>(loc, globalBuffer, alloc);
+
+  Value tensor = builder.create<bufferization::ToTensorOp>(
+      loc, tensorType, alloc, /*restrict=*/true);
+  globalLoadOp->getResult(0).replaceAllUsesWith(tensor);
+  return success();
+}
+
+/// Writes the value stored by `globalStoreOp` into the backing global.
+///
+/// Emits `memref.get_global`, converts the stored tensor with
+/// `bufferization.to_memref` and copies it into the global. The store op
+/// itself is left in place for the caller to erase.
+static LogicalResult
+bufferizeMLProgramGlobalStoreOp(GlobalStoreOp globalStoreOp,
+                                OpBuilder &builder) {
+  Location loc = globalStoreOp.getLoc();
+  auto tensorType = cast<RankedTensorType>(globalStoreOp.getValue().getType());
+  MemRefType memrefType = getMemRefTypeFor(tensorType);
+
+  builder.setInsertionPoint(globalStoreOp);
+  Value globalBuffer = builder.create<memref::GetGlobalOp>(
+      loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+  Value storedBuffer = builder.create<bufferization::ToMemrefOp>(
+      loc, memrefType, globalStoreOp.getValue());
+  builder.create<memref::CopyOp>(loc, storedBuffer, globalBuffer);
+  return success();
+}
+
+namespace {
+/// Converts MLProgram operations that work on tensor-type operands or results
+/// to work on buffers.
+///
+/// All `ml_program.global` ops are converted first (the first unsupported
+/// global aborts the pass), then all `ml_program.global_load` and
+/// `ml_program.global_store` ops. The original ops are erased only once every
+/// conversion has succeeded.
+class MLProgramBufferize
+    : public impl::MLProgramBufferizeBase<MLProgramBufferize> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    OpBuilder builder(module.getBodyRegion());
+    SmallVector<Operation *> toErase;
+
+    auto walkResult = module.walk([&](GlobalOp op) {
+      auto type = dyn_cast<RankedTensorType>(op.getType());
+      if (!type) {
+        // The ml_program.global is of non-tensor type.
+        op.emitError("unsupported global op type");
+        return WalkResult::interrupt();
+      }
+      if (!type.hasStaticShape()) {
+        // The ml_program.global has a dynamically shaped tensor type.
+        op.emitError(
+            "unimplemented: global op bufferization with dynamic shape");
+        return WalkResult::interrupt();
+      }
+
+      if (failed(bufferizeMLProgramGlobalOp(op, builder))) {
+        op.emitError("bufferization for this op failed");
+        return WalkResult::interrupt();
+      }
+      toErase.push_back(op);
+      return WalkResult::advance();
+    });
+
+    if (walkResult.wasInterrupted())
+      return signalPassFailure();
+
+    // Loads and stores keep walking after a failure so that every offending
+    // op gets a diagnostic, but any failure must still fail the pass.
+    bool conversionFailed = false;
+    module.walk([&](GlobalLoadOp op) {
+      if (failed(bufferizeMLProgramGlobalLoadOp(op, builder))) {
+        op.emitError("bufferization for this op failed");
+        conversionFailed = true;
+        return;
+      }
+      toErase.push_back(op);
+    });
+
+    module.walk([&](GlobalStoreOp op) {
+      if (failed(bufferizeMLProgramGlobalStoreOp(op, builder))) {
+        op.emitError("bufferization for this op failed");
+        conversionFailed = true;
+        return;
+      }
+      toErase.push_back(op);
+    });
+
+    if (conversionFailed)
+      return signalPassFailure();
+
+    // Erase in reverse order: stores and loads (whose results have already
+    // been replaced) go before the globals they reference.
+    for (Operation *op : llvm::reverse(toErase))
+      op->erase();
+  }
+};
+} // namespace
+
+/// Creates an instance of the MLProgram bufferization pass.
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass() {
+  return std::make_unique<MLProgramBufferize>();
+}
+} // namespace ml_program
+} // namespace mlir
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
index db567b62e0e74..dc14bf212434f 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMLProgramTransforms
+  Bufferize.cpp
   PipelineGlobalOps.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/MLProgram/bufferize.mlir b/mlir/test/Dialect/MLProgram/bufferize.mlir
new file mode 100644
index 0000000000000..5dc71803dc0cf
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/bufferize.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s --mlprogram-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: memref.global "private" @global : memref<i64> = dense<0>
+ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>
+
+// CHECK-LABEL: @global_load_store
+func.func @global_load_store() -> i64 {
+  // CHECK-DAG: %[[CST127:.+]] = arith.constant 127
+  // CHECK-DAG: %[[GLOBAL_1:.+]] = memref.get_global @global
+  // CHECK-DAG: %[[NEW_ALLOC:.+]] = memref.alloc
+  // CHECK:     memref.copy %[[GLOBAL_1]], %[[NEW_ALLOC]]
+  // CHECK:     %[[TENSOR:.+]] = bufferization.to_tensor %[[NEW_ALLOC]]
+  // CHECK:     %[[EXTRACTED:.+]] = tensor.extract %[[TENSOR]][]
+  // CHECK:     %[[NEW_VALUE:.+]] = arith.muli %[[EXTRACTED]], %[[CST127]]
+  // CHECK:     %[[INSERTED:.+]] = tensor.insert %[[NEW_VALUE]] into %[[TENSOR]][]
+  // CHECK:     %[[GLOBAL_2:.+]] = memref.get_global @global
+  // CHECK:     %[[MEMREF:.+]] = bufferization.to_memref %[[INSERTED]]
+  // CHECK:     memref.copy %[[MEMREF]], %[[GLOBAL_2]]
+  // CHECK:     return %[[NEW_VALUE]]
+  %c127_i64 = arith.constant 127 : i64
+  %0 = ml_program.global_load @global : tensor<i64>
+  %extracted = tensor.extract %0[] : tensor<i64>
+  %1 = arith.muli %extracted, %c127_i64 : i64
+  %inserted = tensor.insert %1 into %0[] : tensor<i64>
+  ml_program.global_store @global = %inserted : tensor<i64>
+  return %1 : i64
+}
+
+// -----
+
+// CHECK: memref.global "private" constant @global_const : memref<2xi32> = dense<[1, 2]>
+ml_program.global private @global_const(dense<[1, 2]> : tensor<2xi32>) : tensor<2xi32>
+
+// -----
+
+// expected-error @below {{unsupported global op type}}
+ml_program.global private mutable @global(0 : i64) : i64
+
+func.func @global_scalar() -> i64 {
+  %c127_i64 = arith.constant 127 : i64
+  %0 = ml_program.global_load @global : i64
+  %1 = arith.muli %0, %c127_i64 : i64
+  ml_program.global_store @global = %1 : i64
+  return %1 : i64
+}
+
+// -----
+
+// expected-error @below {{unsupported global op type}}
+ml_program.global private mutable @global(dense<0> : memref<i64>) : memref<i64>
+
+func.func @global_memref() -> i64 {
+  %c127_i64 = arith.constant 127 : i64
+  %0 = ml_program.global_load @global : memref<i64>
+  %extracted = memref.load %0[] : memref<i64>
+  %1 = arith.muli %extracted, %c127_i64 : i64
+  memref.store %1, %0[] : memref<i64>
+  ml_program.global_store @global = %0 : memref<i64>
+  return %1 : i64
+}
+
+// -----
+
+// expected-error @below {{invalid tensor element type}}
+ml_program.global private mutable @global(dense<0> : tensor<memref<i64>>) : tensor<memref<i64>>
+
+func.func @global_tensor_of_memref() -> i64 {
+  %c127_i64 = arith.constant 127 : i64
+  return %c127_i64 : i64
+}
+
+// -----
+
+// expected-error @below {{unimplemented: global op bufferization with dynamic shape}}
+ml_program.global private mutable @global(dense<0> : tensor<1xi64>) : tensor<?xi64>
+
+func.func @global_dynamic_shape() -> i64 {
+  %c127_i64 = arith.constant 127 : i64
+  %c0 = arith.constant 0 : index
+  %0 = ml_program.global_load @global : tensor<?xi64>
+  %extracted = tensor.extract %0[%c0] : tensor<?xi64>
+  %1 = arith.muli %extracted, %c127_i64 : i64
+  %inserted = tensor.insert %1 into %0[%c0] : tensor<?xi64>
+  ml_program.global_store @global = %inserted : tensor<?xi64>
+  return %1 : i64
+}



More information about the Mlir-commits mailing list