[Mlir-commits] [mlir] cbd4750 - [mlir][mlprogram] Add `mlprogram-pipeline-globals` optimization pass
Rob Suderman
llvmlistbot at llvm.org
Mon Sep 18 17:11:46 PDT 2023
Author: Rob Suderman
Date: 2023-09-18T17:11:29-07:00
New Revision: cbd475040f8952cfc55b9e13dd5ce6c4f6434cd3
URL: https://github.com/llvm/llvm-project/commit/cbd475040f8952cfc55b9e13dd5ce6c4f6434cd3
DIFF: https://github.com/llvm/llvm-project/commit/cbd475040f8952cfc55b9e13dd5ce6c4f6434cd3.diff
LOG: [mlir][mlprogram] Add `mlprogram-pipeline-globals` optimization pass
The added pass optimizes MLProgram global operations by reducing them to
the minimal load/store operations required for global tensors. This avoids
unnecessary global operations throughout a program and potentially
improves operation fusion.
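As an illustrative sketch (not part of the patch itself; the symbol
@global and the function @example are placeholders), the pass forwards a
stored value to a later load of the same global so the redundant load can
be deleted:

  ml_program.global private mutable @global(dense<0> : tensor<4xi32>) : tensor<4xi32>

  func.func @example() -> tensor<4xi32> {
    %0 = ml_program.global_load @global : tensor<4xi32>
    ml_program.global_store @global = %0 : tensor<4xi32>
    // The pass replaces this load with %0 and deletes it, since the
    // global's value is already known in the IR at this point.
    %1 = ml_program.global_load @global : tensor<4xi32>
    func.return %1 : tensor<4xi32>
  }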
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D159228
Added:
mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp
mlir/test/Dialect/MLProgram/pipeline-globals.mlir
Modified:
mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
mlir/include/mlir/InitAllPasses.h
mlir/lib/Dialect/MLProgram/CMakeLists.txt
mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
index f33061b2d87cffc..9f57627c321fb0c 100644
--- a/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..c5c11f17a9fa975
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name MLProgram)
+add_public_tablegen_target(MLIRMLProgramPassIncGen)
+add_dependencies(mlir-headers MLIRMLProgramPassIncGen)
+
+add_mlir_doc(Passes MLProgramPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
new file mode 100644
index 000000000000000..894e35e52724e90
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace ml_program {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass();
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+} // namespace ml_program
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
new file mode 100644
index 000000000000000..defe8191cb905df
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
@@ -0,0 +1,27 @@
+//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
+#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
+  let summary = "Optimize `ml_program` global operations for load and store";
+ let description = [{
+    `ml_program`'s load and store operations can be optimized away across
+    write-write or write-read pairs. This avoids re-reading a global
+    tensor when its value is already known in the IR.
+
+ The pass is designed to handle both nested regions and function calls
+ safely.
+ }];
+ let constructor = "mlir::ml_program::createMLProgramPipelineGlobalsPass()";
+}
+
+#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 8a45da7d1b982f1..5489a13a8040bdb 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -26,6 +26,7 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
@@ -72,6 +73,7 @@ inline void registerAllPasses() {
LLVM::registerLLVMPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
+ ml_program::registerMLProgramPasses();
registerSCFPasses();
registerShapePasses();
spirv::registerSPIRVPasses();
diff --git a/mlir/lib/Dialect/MLProgram/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/CMakeLists.txt
index f33061b2d87cffc..9f57627c321fb0c 100644
--- a/mlir/lib/Dialect/MLProgram/CMakeLists.txt
+++ b/mlir/lib/Dialect/MLProgram/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
index f8f754956603959..5352b7b0454fd19 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -178,8 +178,14 @@ LogicalResult GlobalOp::verify() {
//===----------------------------------------------------------------------===//
GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
- return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
- getOperation()->getParentOp(), getGlobalAttr());
+ for (auto parent = getOperation()->getParentOp(); parent;
+ parent = parent->getParentOp()) {
+ if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+ parent, getGlobalAttr())) {
+ return nearest;
+ }
+ }
+ return {};
}
LogicalResult
@@ -253,8 +259,14 @@ GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
- return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
- getOperation()->getParentOp(), getGlobalAttr());
+  for (auto parent = getOperation()->getParentOp(); parent;
+       parent = parent->getParentOp()) {
+    if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+            parent, getGlobalAttr())) {
+      return nearest;
+    }
+  }
+ return {};
}
LogicalResult
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..db567b62e0e747c
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRMLProgramTransforms
+ PipelineGlobalOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram/Transforms
+
+ DEPENDS
+ MLIRMLProgramPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRMLProgramDialect
+ MLIRPass
+)
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp b/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp
new file mode 100644
index 000000000000000..7e00a731f6d731e
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp
@@ -0,0 +1,234 @@
+//===- PipelineGlobalOps.cpp - Pipeline Global Ops Pass --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace ml_program {
+#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+namespace {
+
+class MLProgramPipelineGlobals
+ : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
+public:
+ void runOnOperation() override;
+
+private:
+ LogicalResult buildGlobalMap(ModuleOp op);
+
+ void ProcessBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
+ llvm::DenseSet<SymbolRefAttr> &symbolStore);
+
+ llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
+ llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
+};
+
+// Traverses upwards, searching for the operation referenced by the symbol.
+static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
+ for (auto op = baseOp; op; op = op->getParentOp()) {
+ auto lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
+ if (lookup)
+ return lookup;
+ }
+ return nullptr;
+}
+
+// Builds a map from each callable symbol to the MLProgram global symbols it
+// loads or stores, directly or transitively.
+LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
+ llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
+ auto res = module->walk([&](Operation *op) {
+ if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
+ auto callable = caller.getCallableForCallee();
+      // For now we do not know how to handle Value-based tracing, so fail.
+ if (mlir::isa<Value>(callable)) {
+ return WalkResult::interrupt();
+ }
+
+ auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
+ auto func = getFromSymbol(op, symbol);
+ callableMap[symbol] = func;
+ }
+ return WalkResult::advance();
+ });
+
+ if (res.wasInterrupted()) {
+ return failure();
+ }
+
+  // First, grab all symbols loaded or stored directly by each function;
+  // calls between functions are handled in the next step.
+ llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
+ llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
+ for (auto callable : callableMap) {
+ llvm::DenseSet<SymbolRefAttr> loadSymbols;
+ llvm::DenseSet<SymbolRefAttr> storeSymbols;
+
+ callable.getSecond()->walk(
+ [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
+
+ callable.getSecond()->walk(
+ [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
+
+ opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
+ opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
+ }
+
+ // For each callable function we find each global loaded/stored within the
+ // function or a nested called function. This includes recursion checking to
+ // avoid infinitely recursing.
+ for (auto callable : callableMap) {
+    SymbolRefAttr thisSymbol = callable.first;
+ llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
+ llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
+ llvm::DenseSet<SymbolRefAttr> loadSymbols;
+ llvm::DenseSet<SymbolRefAttr> storeSymbols;
+
+ for (size_t i = 0; i < work.size(); ++i) {
+ callableMap[work[i]]->walk([&](CallOpInterface call) {
+ auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
+ if (!visited.contains(symbol)) {
+ visited.insert(symbol);
+ work.push_back(symbol);
+ }
+ });
+
+ for (auto load : opLoadSymbols[work[i]])
+ loadSymbols.insert(load);
+
+ for (auto store : opStoreSymbols[work[i]])
+ storeSymbols.insert(store);
+ }
+
+ loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
+ storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
+ }
+
+ return success();
+}
+
+// Processes each operation in the block, deleting unneeded loads/stores,
+// recursing into nested regions, and conservatively handling function calls.
+void MLProgramPipelineGlobals::ProcessBlock(
+ Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
+ llvm::DenseSet<SymbolRefAttr> &symbolStore) {
+
+ llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
+ llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
+ llvm::SmallVector<Operation *> toDelete;
+ for (auto &op : block) {
+    // If this is a global load, remap it to a previously known value if one
+    // exists and delete the load. Otherwise, record its result as the
+    // currently known value for this global.
+ if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
+ auto ref = load.getGlobal();
+ symbolLoad.insert(ref);
+ if (previousLoads.contains(ref)) {
+ toDelete.push_back(&op);
+ load.getResult().replaceAllUsesWith(previousLoads[ref]);
+ } else {
+ previousLoads[ref] = load.getResult();
+ }
+ continue;
+ }
+
+    // Delete the previous store to this global if it exists and is no longer
+    // needed, and update the most recent known value for this global ref.
+ if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
+ auto ref = store.getGlobal();
+ symbolStore.insert(ref);
+ if (previousStores.contains(ref)) {
+ toDelete.push_back(previousStores.find(ref)->getSecond());
+ }
+
+ previousLoads[ref] = store.getValue();
+ previousStores[ref] = &op;
+ continue;
+ }
+
+ // If a function is called, clear known values for loads/stores used by
+ // the function or its sub-functions.
+ if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
+ auto loadSymbols =
+ loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
+ auto storeSymbols =
+ storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
+
+ for (auto sym : loadSymbols) {
+ previousStores.erase(sym);
+ }
+
+ for (auto sym : storeSymbols) {
+ previousLoads.erase(sym);
+ previousStores.erase(sym);
+ }
+ continue;
+ }
+
+    // If the op has sub-regions, recurse inside. As we make no guarantees
+    // about whether or how often a region executes, conservatively
+    // invalidate the known values for any symbols the regions touch.
+ llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
+ llvm::DenseSet<SymbolRefAttr> opSymbolStore;
+ for (auto ®ion : op.getRegions()) {
+ for (auto &block : region) {
+ ProcessBlock(block, opSymbolLoad, opSymbolStore);
+ }
+ }
+
+ // Update current state from the subblock.
+ for (auto change : opSymbolLoad) {
+ symbolLoad.insert(change);
+ previousStores.erase(change);
+ }
+
+ for (auto change : opSymbolStore) {
+ symbolStore.insert(change);
+ previousLoads.erase(change);
+ previousStores.erase(change);
+ }
+ }
+
+ for (auto op : toDelete) {
+ op->erase();
+ }
+}
+
+void MLProgramPipelineGlobals::runOnOperation() {
+ auto targetOp = getOperation();
+ if (failed(buildGlobalMap(targetOp))) {
+ return;
+ }
+
+ for (auto &funcOp : *targetOp.getBody()) {
+ for (auto ®ion : funcOp.getRegions()) {
+ for (auto &block : region.getBlocks()) {
+ llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
+ llvm::DenseSet<SymbolRefAttr> symbolsStored;
+ ProcessBlock(block, symbolsLoaded, symbolsStored);
+ }
+ }
+ }
+}
+
+} // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+createMLProgramPipelineGlobalsPass() {
+ return std::make_unique<MLProgramPipelineGlobals>();
+}
+
+} // namespace ml_program
+} // namespace mlir
diff --git a/mlir/test/Dialect/MLProgram/pipeline-globals.mlir b/mlir/test/Dialect/MLProgram/pipeline-globals.mlir
new file mode 100644
index 000000000000000..a5c9b3e890558ea
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/pipeline-globals.mlir
@@ -0,0 +1,246 @@
+// RUN: mlir-opt -split-input-file -pass-pipeline="builtin.module(mlprogram-pipeline-globals)" --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_double_load
+func.func @global_double_load() {
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ // CHECK-NOT: ml_program.global_load @global_variable
+ %0 = ml_program.global_load @global_variable : tensor<4xi32>
+ %1 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]], %[[LOAD]])
+ %2 = "unregistered.dummy"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %2 : tensor<4xi32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_double_store
+func.func @global_double_store() {
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %1 : tensor<4xi32>
+
+ // CHECK-NOT: ml_program.global_store
+ ml_program.global_store @global_variable = %1 : tensor<4xi32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_store_load
+func.func @global_store_load() {
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  // CHECK: %[[DUMMY2:.+]] = "unregistered.dummy"(%[[DUMMY]])
+ %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+ ml_program.global_store @global_variable = %1 : tensor<4xi32>
+ %2 = ml_program.global_load @global_variable : tensor<4xi32>
+ %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY2]]
+ ml_program.global_store @global_variable = %3 : tensor<4xi32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_store_load_region
+func.func @global_store_load_region() {
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %1 : tensor<4xi32>
+
+ // CHECK: "unregistered.dummy2"
+ "unregistered.dummy2"() ({
+ ^bb():
+ %cst = arith.constant dense<0> : tensor<4xi32>
+ // CHECK: ml_program.global_store @global_variable
+ ml_program.global_store @global_variable = %cst : tensor<4xi32>
+ "unregistered.terminator"() : () -> ()
+ }) : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY2:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY2]]
+ ml_program.global_store @global_variable = %3 : tensor<4xi32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @interrupt
+func.func @interrupt() {
+ %cst = arith.constant dense<0> : tensor<4xi32>
+ // CHECK: ml_program.global_store
+ ml_program.global_store @global_variable = %cst : tensor<4xi32>
+ func.return
+}
+
+// CHECK-LABEL: @call_global_store
+func.func @call_global_store() {
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %1 : tensor<4xi32>
+ call @interrupt() : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %3 : tensor<4xi32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @interrupt_indirect
+func.func @interrupt_indirect() {
+ %cst = arith.constant dense<0> : tensor<4xi32>
+ // CHECK: ml_program.global_store
+ ml_program.global_store @global_variable = %cst : tensor<4xi32>
+ func.return
+}
+
+// CHECK-LABEL: @interrupt
+func.func @interrupt() {
+ call @interrupt_indirect() : () -> ()
+ func.return
+}
+
+// CHECK-LABEL: @call_indirect_store
+func.func @call_indirect_store() {
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %1 : tensor<4xi32>
+ call @interrupt() : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %3 : tensor<4xi32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @interrupt_indirect
+func.func @interrupt_indirect() -> tensor<4xi32> {
+ // CHECK: ml_program.global_load
+ %0 = ml_program.global_load @global_variable : tensor<4xi32>
+ func.return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @interrupt
+func.func @interrupt() {
+ %0 = call @interrupt_indirect() : () -> (tensor<4xi32>)
+ "unregistered.dummy"(%0) : (tensor<4xi32>) -> ()
+ func.return
+}
+
+// CHECK-LABEL: @call_indirect_load
+func.func @call_indirect_load() {
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %1 : tensor<4xi32>
+ call @interrupt() : () -> ()
+
+  // CHECK-NOT: ml_program.global_load
+  // CHECK: %[[DUMMY2:.+]] = "unregistered.dummy"(%[[DUMMY]])
+ %2 = ml_program.global_load @global_variable : tensor<4xi32>
+ %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY2]]
+ ml_program.global_store @global_variable = %3 : tensor<4xi32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @call_recursive
+func.func @call_recursive() {
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %1 : tensor<4xi32>
+ call @call_recursive() : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+ %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+ // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+ %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+ ml_program.global_store @global_variable = %3 : tensor<4xi32>
+ func.return
+}