[Mlir-commits] [mlir] d30c022 - [mlir] Split MLProgram global load and store to Graph variants
Jacques Pienaar
llvmlistbot at llvm.org
Thu Jun 16 20:02:02 PDT 2022
Author: Jacques Pienaar
Date: 2022-06-16T20:01:54-07:00
New Revision: d30c0221cf5aa36c079b7cc0d36fb89f7b32149b
URL: https://github.com/llvm/llvm-project/commit/d30c0221cf5aa36c079b7cc0d36fb89f7b32149b
DIFF: https://github.com/llvm/llvm-project/commit/d30c0221cf5aa36c079b7cc0d36fb89f7b32149b.diff
LOG: [mlir] Split MLProgram global load and store to Graph variants
* Split ops into X_graph variants as discussed;
* Remove tokens from non-Graph region variants and rely on side-effect
modelling there while removing side-effect modelling from Graph
variants and relying on explicit ordering there;
* Make Graph variants always produce a token, but keep the explicit token
type specification, given previous discussion about it potentially becoming
configurable in the future;
This results in some duplicated code. I considered adding helper
functions but decided against introducing an abstraction this early, given
the limited size of the duplication and the risk of creating accidental coupling.
Differential Revision: https://reviews.llvm.org/D127813
Added:
Modified:
mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
mlir/test/Dialect/MLProgram/invalid.mlir
mlir/test/Dialect/MLProgram/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
index d820573b200ab..69b1eab379b3c 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -171,7 +171,8 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
advanced cases.
This op is side effecting and may not be valid to use in graph regions
- without additional consideration to evaluation order constraints.
+ without additional consideration to evaluation order constraints. See
+ `global_load_graph` for an op that allows explicit ordering constraints.
Example:
@@ -181,16 +182,14 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
}];
let arguments = (ins
- Arg<SymbolRefAttr, "", [MemRead]>:$global,
- Variadic<MLProgram_TokenType>:$consumeTokens
+ Arg<SymbolRefAttr, "", [MemRead]>:$global
);
let results = (outs
- AnyType:$result,
- Optional<MLProgram_TokenType>:$produceToken
+ AnyType:$result
);
let assemblyFormat = [{
- $global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
+ $global `:` type($result) attr-dict
}];
let extraClassDeclaration = [{
@@ -238,6 +237,52 @@ def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [
}];
}
+//===----------------------------------------------------------------------===//
+// GlobalLoadGraphOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalLoadGraphOp : MLProgram_Op<"global_load_graph", [
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ ]> {
+ let summary = "Direct load of a mutable value from a global in a Graph region";
+ let description = [{
+ Performs a non-atomic, non-volatile, non-synchronized load from a global
+ that may be mutable.
+
+ It is fully expected that these constraints are not suitable for all
+ situations, and alternative ops should be defined and used for more advanced
+ cases.
+
+ This op is intended for graph regions: ordering against other global
+ accesses is explicit via the consumed tokens and the produced token.
+
+ Example:
+
+ ```mlir
+ %0, %token_out = ml_program.global_load_graph @foobar
+ ordering(%token -> !ml_program.token) : tensor<?xi32>
+ ```
+ }];
+
+ let arguments = (ins
+ Arg<SymbolRefAttr, "", [MemRead]>:$global,
+ Variadic<MLProgram_TokenType>:$consumeTokens
+ );
+ let results = (outs
+ AnyType:$result,
+ MLProgram_TokenType:$produceToken
+ );
+
+ let assemblyFormat = [{
+ $global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
+ }];
+
+ let extraClassDeclaration = [{
+ /// Gets the corresponding GlobalOp (or nullptr).
+ GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// GlobalStoreOp
//===----------------------------------------------------------------------===//
@@ -255,7 +300,8 @@ def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
advanced cases.
This op is side effecting and may not be valid to use in graph regions
- without additional consideration to evaluation order constraints.
+ without additional consideration to evaluation order constraints. See
+ `global_store_graph` for an op that allows explicit ordering constraints.
Example:
@@ -266,11 +312,53 @@ def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
let arguments = (ins
Arg<SymbolRefAttr, "", [MemWrite]>:$global,
+ AnyType:$value
+ );
+
+ let assemblyFormat = [{
+ $global `=` $value `:` type($value) attr-dict
+ }];
+
+ let extraClassDeclaration = [{
+ /// Gets the corresponding GlobalOp (or nullptr).
+ GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalStoreGraphOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalStoreGraphOp : MLProgram_Op<"global_store_graph", [
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>
+ ]> {
+ let summary = "Direct store of a value into a mutable global in a Graph region";
+ let description = [{
+ Performs a non-atomic, non-volatile, non-synchronized store to a mutable
+ global.
+
+ It is fully expected that these constraints are not suitable for
+ all situations, and alternative ops should be defined and used for more
+ advanced cases.
+
+ This op is intended for graph regions: ordering against other global
+ accesses is explicit via the consumed tokens and the produced token.
+
+ Example:
+
+ ```mlir
+ %token = ml_program.global_store_graph @foobar = %0
+ ordering(%in_token -> !ml_program.token) : tensor<?xi32>
+ ```
+ }];
+
+ // A store writes the referenced global. Declaring only MemRead here may let
+ // an instance whose token result is unused be treated as trivially dead.
+ let arguments = (ins
+ Arg<SymbolRefAttr, "", [MemWrite]>:$global,
AnyType:$value,
Variadic<MLProgram_TokenType>:$consumeTokens
);
let results = (outs
- Optional<MLProgram_TokenType>:$produceToken
+ MLProgram_TokenType:$produceToken
);
let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
index 9411ea17b8165..2f1e4b93a6ac3 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -18,12 +18,11 @@ using namespace mlir::ml_program;
//===----------------------------------------------------------------------===//
/// Parse and print an ordering clause for a variadic of consuming tokens
-/// and an optional producing token.
+/// and a producing token.
///
/// Syntax:
/// ordering(%0, %1 -> !ml_program.token)
/// ordering(() -> !ml_program.token)
-/// ordering(%0, %1)
///
/// If both the consuming and producing token are not present on the op, then
/// the clause prints nothing.
@@ -46,10 +45,11 @@ static ParseResult parseTokenOrdering(
return failure();
}
- // Parse optional producer token.
- if (succeeded(parser.parseOptionalArrow()))
- if (failed(parser.parseType(produceTokenType)))
- return failure();
+ // Parse producer token.
+ if (failed(parser.parseArrow()))
+ return failure();
+ if (failed(parser.parseType(produceTokenType)))
+ return failure();
if (failed(parser.parseRParen()))
return failure();
@@ -220,6 +220,30 @@ GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
+//===----------------------------------------------------------------------===//
+// GlobalLoadGraphOp
+//===----------------------------------------------------------------------===//
+
+/// Resolves the `global` symbol reference, searching outward from this op's
+/// parent, and returns the matching GlobalOp (or a null op if none is found).
+GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
+  Operation *scope = getOperation()->getParentOp();
+  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(scope, getGlobalAttr());
+}
+
+/// Checks that `global` names a GlobalOp and that its type matches the type
+/// of the loaded result.
+LogicalResult
+GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp referent = getGlobalOp(symbolTable);
+  if (!referent)
+    return emitOpError() << "undefined global: " << getGlobal();
+
+  auto globalType = referent.getType();
+  auto resultType = getResult().getType();
+  if (globalType == resultType)
+    return success();
+
+  return emitOpError() << "cannot load from global typed " << globalType
+                       << " as " << resultType;
+}
+
//===----------------------------------------------------------------------===//
// GlobalStoreOp
//===----------------------------------------------------------------------===//
@@ -249,6 +273,35 @@ GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
+//===----------------------------------------------------------------------===//
+// GlobalStoreGraphOp
+//===----------------------------------------------------------------------===//
+
+/// Resolves the `global` symbol reference, searching outward from this op's
+/// parent, and returns the matching GlobalOp (or a null op if none is found).
+GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
+  Operation *scope = getOperation()->getParentOp();
+  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(scope, getGlobalAttr());
+}
+
+/// Checks that `global` names a GlobalOp that is declared mutable and whose
+/// type matches the type of the stored value.
+LogicalResult
+GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp referent = getGlobalOp(symbolTable);
+  if (!referent)
+    return emitOpError() << "undefined global: " << getGlobal();
+
+  // Only globals marked mutable may be stored to.
+  if (!referent.getIsMutable())
+    return emitOpError() << "cannot store to an immutable global "
+                         << getGlobal();
+
+  auto globalType = referent.getType();
+  auto valueType = getValue().getType();
+  if (globalType != valueType)
+    return emitOpError() << "cannot store to a global typed " << globalType
+                         << " from " << valueType;
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir
index e193c6d58f11a..79725a9bfbe18 100644
--- a/mlir/test/Dialect/MLProgram/invalid.mlir
+++ b/mlir/test/Dialect/MLProgram/invalid.mlir
@@ -96,3 +96,17 @@ ml_program.func @store_immutable(%arg0: i64) {
ml_program.global_store @var = %arg0 : i64
ml_program.return
}
+
+// -----
+
+// The ordering clause of a Graph variant must name the produced token type;
+// omitting `-> !ml_program.token` is a parse error.
+ml_program.global private mutable @global_mutable_undef : tensor<?xi32>
+ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
+ %token1 = ml_program.token
+ %0, %token2 = ml_program.global_load_graph @global_mutable_undef
+ ordering(() -> !ml_program.token) : tensor<?xi32>
+ %token3 = ml_program.global_store_graph @global_mutable_undef = %0
+ // expected-error @+1 {{expected '->'}}
+ ordering(%token1, %token2) : tensor<?xi32>
+
+ ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
+}
diff --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir
index ca2d72afb6c66..9a48497a3efc8 100644
--- a/mlir/test/Dialect/MLProgram/ops.mlir
+++ b/mlir/test/Dialect/MLProgram/ops.mlir
@@ -45,12 +45,12 @@ ml_program.func @global_load_store() {
// CHECK-LABEL: @global_load_store_tokens
ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
%token1 = ml_program.token
- %0, %token2 = ml_program.global_load @global_mutable_undef
+ %0, %token2 = ml_program.global_load_graph @global_mutable_undef
ordering(() -> !ml_program.token) : tensor<?xi32>
- %token3 = ml_program.global_store @global_mutable_undef = %0
+ %token3 = ml_program.global_store_graph @global_mutable_undef = %0
ordering(%token1, %token2 -> !ml_program.token) : tensor<?xi32>
- ml_program.global_store @global_mutable_undef = %0
- ordering(%token3) : tensor<?xi32>
+ %token4 = ml_program.global_store_graph @global_mutable_undef = %0
+ ordering(%token3 -> !ml_program.token) : tensor<?xi32>
- ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
+ ml_program.output %0, %token4 : tensor<?xi32>, !ml_program.token
}
More information about the Mlir-commits
mailing list