[Mlir-commits] [mlir] [mlir][MLProgram] Add MLProgram to MemRef bufferization pass (PR #75103)

Ryan Holt llvmlistbot at llvm.org
Mon Jan 29 20:07:16 PST 2024


https://github.com/ryan-holt-1 updated https://github.com/llvm/llvm-project/pull/75103

>From 38227add6ed579190817826d1416f04936123761 Mon Sep 17 00:00:00 2001
From: ryanholt <ryanholt at mathworks.com>
Date: Mon, 11 Dec 2023 15:46:28 -0500
Subject: [PATCH] [mlir][MLProgram] Implement BufferizableOpInterface

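This commit implements `BufferizableOpInterface` for `ml_program.global`,
`ml_program.global_load` and `ml_program.global_store`, which bufferize to
`memref.global`, `memref.get_global`, and `memref.get_global` plus a copy,
respectively, so that these ops can be lowered all the way to LLVM.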
---
 .../Transforms/BufferizableOpInterfaceImpl.h  |  20 +++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 .../Bufferization/Transforms/Bufferize.cpp    |   4 +
 .../Transforms/OneShotModuleBufferize.cpp     |   2 +-
 .../BufferizableOpInterfaceImpl.cpp           | 159 ++++++++++++++++++
 .../MLProgram/Transforms/CMakeLists.txt       |   1 +
 .../Dialect/MLProgram/one-shot-bufferize.mlir |  52 ++++++
 7 files changed, 239 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
 create mode 100644 mlir/test/Dialect/MLProgram/one-shot-bufferize.mlir

diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 0000000000000..ca541238cf63b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+namespace ml_program {
+/// Attaches BufferizableOpInterface external models to ml_program global ops.
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace ml_program
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e0..0d21ecb4ebb44 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -48,6 +48,7 @@
 #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
@@ -160,6 +161,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerMemorySlotExternalModels(registry);
+  ml_program::registerBufferizableOpInterfaceExternalModels(registry);
   scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   scf::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 3f1626a6af34d..a151ba94ff910 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -494,6 +494,10 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                << "\n//===-------------------------------------------===//\n");
   }
 
+  // Return early if the top-level op is entirely gone.
+  if (erasedOps.contains(op))
+    return success();
+
   // Fold all to_memref(to_tensor(x)) pairs.
   for (Operation *op : toMemrefOps) {
     rewriter.setInsertionPoint(op);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index aeda995fd585a..33feea0b956ca 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -459,7 +459,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   }
 
   // Bufferize all other ops.
-  for (Operation &op : moduleOp.getOps()) {
+  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
     if (isa<func::FuncOp>(&op))
       continue;
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..54759099fc4aa
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,176 @@
+//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::ml_program;
+
+namespace mlir {
+namespace ml_program {
+namespace {
+
+/// Defaults shared by all ml_program external models: no OpOperand has
+/// aliasing values and the buffer relation is reported as unknown.
+template <typename Interface, typename Op>
+struct ExternalModelBase
+    : public BufferizableOpInterface::ExternalModel<Interface, Op> {
+
+  AliasingValueList getAliasingValues(Operation *, OpOperand &,
+                                      const AnalysisState &) const {
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *, OpResult,
+                                const AnalysisState &) const {
+    return BufferRelation::Unknown;
+  }
+};
+
+/// Bufferization of ml_program.global into a memref.global.
+struct GlobalOpInterface
+    : public ExternalModelBase<GlobalOpInterface, GlobalOp> {
+
+  bool bufferizesToMemoryRead(Operation *, OpOperand &,
+                              const AnalysisState &) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
+                               const AnalysisState &) const {
+    return false;
+  }
+
+  /// Report tensor semantics only for tensor-typed globals; `bufferize`
+  /// below casts the global's type to a tensor type.
+  bool hasTensorSemantics(Operation *op) const {
+    return isa<TensorType>(cast<GlobalOp>(op).getType());
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &) const {
+    auto globalOp = cast<GlobalOp>(op);
+    if (!globalOp.getValue().has_value())
+      return globalOp.emitError("global op must have a value");
+
+    // memref.global needs a MemRefType, which only a ranked tensor yields
+    // here; report unranked tensors instead of asserting in cast<>.
+    auto tensorType = dyn_cast<RankedTensorType>(globalOp.getType());
+    if (!tensorType)
+      return globalOp.emitError("global op must have a ranked tensor type");
+    auto memrefType =
+        cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(tensorType));
+
+    replaceOpWithNewBufferizedOp<memref::GlobalOp>(
+        rewriter, globalOp, globalOp.getSymName(),
+        /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
+        /*type=*/memrefType,
+        /*initial_value=*/globalOp.getValue().value(),
+        /*constant=*/!globalOp.getIsMutable(),
+        /*alignment=*/nullptr);
+    return success();
+  }
+};
+
+/// Bufferization of ml_program.global_load into a memref.get_global.
+struct GlobalLoadOpInterface
+    : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
+
+  bool bufferizesToMemoryRead(Operation *, OpOperand &,
+                              const AnalysisState &) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
+                               const AnalysisState &) const {
+    return false;
+  }
+
+  /// The result is the global's own buffer (memref.get_global), so it must
+  /// not be modified in place; writes to it have to go out of place.
+  bool isWritable(Operation *, Value, const AnalysisState &) const {
+    return false;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &) const {
+    auto globalLoadOp = cast<GlobalLoadOp>(op);
+
+    auto tensorType = cast<TensorType>(globalLoadOp.getType());
+    auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
+
+    replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
+        rewriter, globalLoadOp, memrefType,
+        globalLoadOp.getGlobalAttr().getLeafReference());
+
+    return success();
+  }
+};
+
+/// Bufferization of ml_program.global_store into a memref.get_global of the
+/// target global followed by a copy of the stored value's buffer into it.
+struct GlobalStoreOpInterface
+    : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
+
+  /// The stored value's buffer is the source of the copy, so it is read.
+  bool bufferizesToMemoryRead(Operation *, OpOperand &,
+                              const AnalysisState &) const {
+    return true;
+  }
+
+  /// The copy only writes the global's buffer, never the operand's buffer.
+  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
+                               const AnalysisState &) const {
+    return false;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto globalStoreOp = cast<GlobalStoreOp>(op);
+
+    // Look up the source buffer before creating any new op, so that a failure
+    // does not leave a dangling memref.get_global behind.
+    FailureOr<Value> sourceMemref =
+        getBuffer(rewriter, globalStoreOp.getValue(), options);
+    if (failed(sourceMemref))
+      return failure();
+
+    auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
+    auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
+
+    Location loc = globalStoreOp.getLoc();
+    auto targetMemref = rewriter.create<memref::GetGlobalOp>(
+        loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+    if (failed(options.createMemCpy(rewriter, loc, *sourceMemref,
+                                    targetMemref)))
+      return failure();
+
+    rewriter.eraseOp(globalStoreOp);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Defined outside the anonymous namespace so that it has external linkage
+/// and matches the declaration in the public BufferizableOpInterfaceImpl.h.
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
+    GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
+    GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
+    GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
+  });
+}
+
+} // namespace ml_program
+} // namespace mlir
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
index db567b62e0e74..53ca492339f50 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMLProgramTransforms
+  BufferizableOpInterfaceImpl.cpp
   PipelineGlobalOps.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/MLProgram/one-shot-bufferize.mlir b/mlir/test/Dialect/MLProgram/one-shot-bufferize.mlir
new file mode 100644
index 0000000000000..442c0854f7af8
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/one-shot-bufferize.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s -one-shot-bufferize -split-input-file | FileCheck %s
+
+// Load/modify/store round trip. The loaded global is not writable, so the
+// tensor.insert bufferizes out of place into a new allocation, which is then
+// copied back into the global by the store.
+
+// CHECK-LABEL: memref.global "private" @global
+ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>
+
+// CHECK-LABEL: func.func @global_load_store
+func.func @global_load_store() -> i64 {
+// CHECK-DAG: %[[CST127:.*]] = arith.constant 127
+// CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global
+// CHECK:     %[[VALUE:.*]] = memref.load %[[GLOBAL_1]][]
+// CHECK:     %[[NEW_VALUE:.*]] = arith.muli %[[VALUE]], %[[CST127]]
+// CHECK:     %[[ALLOC:.*]] = memref.alloc()
+// CHECK:     memref.copy %[[GLOBAL_1]], %[[ALLOC]]
+// CHECK:     memref.store %[[NEW_VALUE]], %[[ALLOC]][]
+// CHECK:     %[[GLOBAL_2:.*]] = memref.get_global @global
+// CHECK:     memref.copy %[[ALLOC]], %[[GLOBAL_2]]
+// CHECK:     return %[[NEW_VALUE]]
+  %c127 = arith.constant 127 : i64
+  %0 = ml_program.global_load @global : tensor<i64>
+  %extracted = tensor.extract %0[] : tensor<i64>
+  %1 = arith.muli %extracted, %c127 : i64
+  %inserted = tensor.insert %1 into %0[] : tensor<i64>
+  ml_program.global_store @global = %inserted : tensor<i64>
+  return %1 : i64
+}
+
+// -----
+
+// Two loads of the same global. Writing through the first load must not
+// clobber the value read through the second one, so the insert gets its own
+// allocation and the global is only updated by the final store.
+
+// CHECK-LABEL: memref.global "private" @global
+ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>
+
+// CHECK-LABEL: func.func @raw_hazard
+func.func @raw_hazard() -> i64 {
+// CHECK-DAG: %[[CST127:.*]] = arith.constant 127
+// CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global
+// CHECK-DAG: %[[GLOBAL_2:.*]] = memref.get_global @global
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc()
+// CHECK:     memref.copy %[[GLOBAL_1]], %[[ALLOC]]
+// CHECK:     memref.store %[[CST127]], %[[ALLOC]][]
+// CHECK:     %[[VAL:.*]] = memref.load %[[GLOBAL_2]][]
+// CHECK:     %[[GLOBAL_3:.*]] = memref.get_global @global
+// CHECK:     memref.copy %[[ALLOC]], %[[GLOBAL_3]]
+// CHECK:     return %[[VAL]]
+  %c127 = arith.constant 127 : i64
+  %0 = ml_program.global_load @global : tensor<i64>
+  %1 = ml_program.global_load @global : tensor<i64>
+  %inserted = tensor.insert %c127 into %0[] : tensor<i64>
+  %extracted = tensor.extract %1[] : tensor<i64>
+  ml_program.global_store @global = %inserted : tensor<i64>
+  return %extracted : i64
+}



More information about the Mlir-commits mailing list