[Mlir-commits] [mlir] cbd4750 - [mlir][mlprogram] Add `mlprogram-pipeline-globals` optimization pass

Rob Suderman llvmlistbot at llvm.org
Mon Sep 18 17:11:46 PDT 2023


Author: Rob Suderman
Date: 2023-09-18T17:11:29-07:00
New Revision: cbd475040f8952cfc55b9e13dd5ce6c4f6434cd3

URL: https://github.com/llvm/llvm-project/commit/cbd475040f8952cfc55b9e13dd5ce6c4f6434cd3
DIFF: https://github.com/llvm/llvm-project/commit/cbd475040f8952cfc55b9e13dd5ce6c4f6434cd3.diff

LOG: [mlir][mlprogram] Add `mlprogram-pipeline-globals` optimization pass

Added pass optimizes MLProgram global operations by reducing to only
the minimal load/store operations for global tensors. This avoids
unnecessary global operations throughout a program and potentially
improves operation fusion.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D159228

Added: 
    mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
    mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
    mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
    mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp
    mlir/test/Dialect/MLProgram/pipeline-globals.mlir

Modified: 
    mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Dialect/MLProgram/CMakeLists.txt
    mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
index f33061b2d87cffc..9f57627c321fb0c 100644
--- a/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..c5c11f17a9fa975
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name MLProgram)
+add_public_tablegen_target(MLIRMLProgramPassIncGen)
+add_dependencies(mlir-headers MLIRMLProgramPassIncGen)
+
+add_mlir_doc(Passes MLProgramPasses ./ -gen-pass-doc)

diff  --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
new file mode 100644
index 000000000000000..894e35e52724e90
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace ml_program {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass();
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+} // namespace ml_program
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
new file mode 100644
index 000000000000000..defe8191cb905df
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
@@ -0,0 +1,27 @@
+//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
+#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
+  let summary = "Optimize `ml_program` global operations for read and store";
+  let description = [{
+    `ml_program`'s load and store operations can be optimized for
+    write-write or write-read sets of operations. This allows known
+    tensors to not be re-read when the value is already known in IR.
+
+    The pass is designed to handle both nested regions and function calls
+    safely.
+  }];
+  let constructor = "mlir::ml_program::createMLProgramPipelineGlobalsPass()";
+}
+
+#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 8a45da7d1b982f1..5489a13a8040bdb 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
@@ -72,6 +73,7 @@ inline void registerAllPasses() {
   LLVM::registerLLVMPasses();
   math::registerMathPasses();
   memref::registerMemRefPasses();
+  ml_program::registerMLProgramPasses();
   registerSCFPasses();
   registerShapePasses();
   spirv::registerSPIRVPasses();

diff  --git a/mlir/lib/Dialect/MLProgram/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/CMakeLists.txt
index f33061b2d87cffc..9f57627c321fb0c 100644
--- a/mlir/lib/Dialect/MLProgram/CMakeLists.txt
+++ b/mlir/lib/Dialect/MLProgram/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
index f8f754956603959..5352b7b0454fd19 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -178,8 +178,14 @@ LogicalResult GlobalOp::verify() {
 //===----------------------------------------------------------------------===//
 
 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
-  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
-      getOperation()->getParentOp(), getGlobalAttr());
+  for (auto parent = getOperation()->getParentOp(); parent;
+       parent = parent->getParentOp()) {
+    if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+            parent, getGlobalAttr())) {
+      return nearest;
+    }
+  }
+  return {};
 }
 
 LogicalResult
@@ -253,8 +259,14 @@ GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
-  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
-      getOperation()->getParentOp(), getGlobalAttr());
+  for (auto parent = getOperation()->getParentOp(); parent;) {
+    if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+            parent, getGlobalAttr())) {
+      return nearest;
+    }
+    parent = parent->getParentOp();
+  }
+  return {};
 }
 
 LogicalResult

diff  --git a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..db567b62e0e747c
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRMLProgramTransforms
+  PipelineGlobalOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram/Transforms
+
+  DEPENDS
+  MLIRMLProgramPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRMLProgramDialect
+  MLIRPass
+)

diff  --git a/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp b/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp
new file mode 100644
index 000000000000000..7e00a731f6d731e
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp
@@ -0,0 +1,234 @@
+//===- PipelineGlobalOps.cpp - Pipeline Global Ops Pass -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace ml_program {
+#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+namespace {
+
+class MLProgramPipelineGlobals
+    : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
+public:
+  void runOnOperation() override;
+
+private:
+  LogicalResult buildGlobalMap(ModuleOp op);
+
+  void ProcessBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
+                    llvm::DenseSet<SymbolRefAttr> &symbolStore);
+
+  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
+  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
+};
+
+// Traverses upwards searching for the operation mapped by the symbol.
+static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
+  for (auto op = baseOp; op; op = op->getParentOp()) {
+    auto lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
+    if (lookup)
+      return lookup;
+  }
+  return nullptr;
+}
+
+// Builds map from a symbol to MLProgram global symbols loaded or stored
+// during processing.
+LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
+  llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
+  auto res = module->walk([&](Operation *op) {
+    if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
+      auto callable = caller.getCallableForCallee();
+      // For now we do not know how to handle Value based tracing, so fail.
+      if (mlir::isa<Value>(callable)) {
+        return WalkResult::interrupt();
+      }
+
+      auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
+      auto func = getFromSymbol(op, symbol);
+      callableMap[symbol] = func;
+    }
+    return WalkResult::advance();
+  });
+
+  if (res.wasInterrupted()) {
+    return failure();
+  }
+
+  // First grab all symbols loaded or stored by each function. This
+  // will not handle calls initially.
+  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
+  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
+  for (auto callable : callableMap) {
+    llvm::DenseSet<SymbolRefAttr> loadSymbols;
+    llvm::DenseSet<SymbolRefAttr> storeSymbols;
+
+    callable.getSecond()->walk(
+        [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
+
+    callable.getSecond()->walk(
+        [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
+
+    opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
+    opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
+  }
+
+  // For each callable function we find each global loaded/stored within the
+  // function or a nested called function. This includes recursion checking to
+  // avoid infinitely recursing.
+  for (auto callable : callableMap) {
+    SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
+    llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
+    llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
+    llvm::DenseSet<SymbolRefAttr> loadSymbols;
+    llvm::DenseSet<SymbolRefAttr> storeSymbols;
+
+    for (size_t i = 0; i < work.size(); ++i) {
+      callableMap[work[i]]->walk([&](CallOpInterface call) {
+        auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
+        if (!visited.contains(symbol)) {
+          visited.insert(symbol);
+          work.push_back(symbol);
+        }
+      });
+
+      for (auto load : opLoadSymbols[work[i]])
+        loadSymbols.insert(load);
+
+      for (auto store : opStoreSymbols[work[i]])
+        storeSymbols.insert(store);
+    }
+
+    loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
+    storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
+  }
+
+  return success();
+}
+
+// Process each operation in the block deleting unneeded loads / stores,
+// recursing on subblocks and checking function calls.
+void MLProgramPipelineGlobals::ProcessBlock(
+    Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
+    llvm::DenseSet<SymbolRefAttr> &symbolStore) {
+
+  llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
+  llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
+  llvm::SmallVector<Operation *> toDelete;
+  for (auto &op : block) {
+    // If this is a global load, remap to a previous value if known
+    // and delete this load. Remember that this value is the currently
+    // known load.
+    if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
+      auto ref = load.getGlobal();
+      symbolLoad.insert(ref);
+      if (previousLoads.contains(ref)) {
+        toDelete.push_back(&op);
+        load.getResult().replaceAllUsesWith(previousLoads[ref]);
+      } else {
+        previousLoads[ref] = load.getResult();
+      }
+      continue;
+    }
+
+    // Delete a previous store if it exists and is not needed, update
+    // the most recent known value for this global ref.
+    if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
+      auto ref = store.getGlobal();
+      symbolStore.insert(ref);
+      if (previousStores.contains(ref)) {
+        toDelete.push_back(previousStores.find(ref)->getSecond());
+      }
+
+      previousLoads[ref] = store.getValue();
+      previousStores[ref] = &op;
+      continue;
+    }
+
+    // If a function is called, clear known values for loads/stores used by
+    // the function or its sub-functions.
+    if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
+      auto loadSymbols =
+          loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
+      auto storeSymbols =
+          storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
+
+      for (auto sym : loadSymbols) {
+        previousStores.erase(sym);
+      }
+
+      for (auto sym : storeSymbols) {
+        previousLoads.erase(sym);
+        previousStores.erase(sym);
+      }
+      continue;
+    }
+
+    // If the op has sub-regions, recurse inside. We make no assumptions about
+    // whether the nested regions actually execute.
+    llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
+    llvm::DenseSet<SymbolRefAttr> opSymbolStore;
+    for (auto &region : op.getRegions()) {
+      for (auto &block : region) {
+        ProcessBlock(block, opSymbolLoad, opSymbolStore);
+      }
+    }
+
+    // Update current state from the subblock.
+    for (auto change : opSymbolLoad) {
+      symbolLoad.insert(change);
+      previousStores.erase(change);
+    }
+
+    for (auto change : opSymbolStore) {
+      symbolStore.insert(change);
+      previousLoads.erase(change);
+      previousStores.erase(change);
+    }
+  }
+
+  for (auto op : toDelete) {
+    op->erase();
+  }
+}
+
+void MLProgramPipelineGlobals::runOnOperation() {
+  auto targetOp = getOperation();
+  if (failed(buildGlobalMap(targetOp))) {
+    return;
+  }
+
+  for (auto &funcOp : *targetOp.getBody()) {
+    for (auto &region : funcOp.getRegions()) {
+      for (auto &block : region.getBlocks()) {
+        llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
+        llvm::DenseSet<SymbolRefAttr> symbolsStored;
+        ProcessBlock(block, symbolsLoaded, symbolsStored);
+      }
+    }
+  }
+}
+
+} // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createMLProgramPipelineGlobalsPass() {
+  return std::make_unique<MLProgramPipelineGlobals>();
+}
+
+} // namespace ml_program
+} // namespace mlir

diff  --git a/mlir/test/Dialect/MLProgram/pipeline-globals.mlir b/mlir/test/Dialect/MLProgram/pipeline-globals.mlir
new file mode 100644
index 000000000000000..a5c9b3e890558ea
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/pipeline-globals.mlir
@@ -0,0 +1,246 @@
+// RUN: mlir-opt -split-input-file -pass-pipeline="builtin.module(mlprogram-pipeline-globals)" --allow-unregistered-dialect %s
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_double_load
+func.func @global_double_load() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  // CHECK-NOT: ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+  %1 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]], %[[LOAD]])
+  %2 = "unregistered.dummy"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %2 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_double_store
+func.func @global_double_store() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+
+  // CHECK-NOT: ml_program.global_store
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_store_load
+func.func @global_store_load() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  // CHECK: %[[DUMMY2:.+]] = "unregistered.dummy"(%[[DUMMY]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY2]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_store_load_region
+func.func @global_store_load_region() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+
+  // CHECK: "unregistered.dummy2"
+  "unregistered.dummy2"() ({
+    ^bb():
+    %cst = arith.constant dense<0> : tensor<4xi32>
+    // CHECK: ml_program.global_store @global_variable
+    ml_program.global_store @global_variable = %cst : tensor<4xi32>
+    "unregistered.terminator"() : () -> ()
+  }) : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY2:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY2]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @interrupt
+func.func @interrupt() {
+  %cst = arith.constant dense<0> : tensor<4xi32>
+  // CHECK: ml_program.global_store
+  ml_program.global_store @global_variable = %cst : tensor<4xi32>
+  func.return
+}
+
+// CHECK-LABEL: @call_global_store
+func.func @call_global_store() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  call @interrupt() : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @interrupt_indirect
+func.func @interrupt_indirect() {
+  %cst = arith.constant dense<0> : tensor<4xi32>
+  // CHECK: ml_program.global_store
+  ml_program.global_store @global_variable = %cst : tensor<4xi32>
+  func.return
+}
+
+// CHECK-LABEL: @interrupt
+func.func @interrupt() {
+  call @interrupt_indirect() : () -> ()
+  func.return
+}
+
+// CHECK-LABEL: @call_indirect_store
+func.func @call_indirect_store() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  call @interrupt() : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @interrupt_indirect
+func.func @interrupt_indirect() -> tensor<4xi32> {
+  // CHECK: ml_program.global_load
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+  func.return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @interrupt
+func.func @interrupt() {
+  %0 = call @interrupt_indirect() : () -> (tensor<4xi32>)
+  "unregistered.dummy"(%0) : (tensor<4xi32>) -> ()
+  func.return
+}
+
+// CHECK-LABEL: @call_indirect_load
+func.func @call_indirect_load() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  call @interrupt() : () -> ()
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @call_recursive
+func.func @call_recursive() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  call @call_recursive() : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable %[[DUMMY]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}


        


More information about the Mlir-commits mailing list