[Mlir-commits] [mlir] d30c022 - [mlir] Split MLProgram global load and store to Graph variants

Jacques Pienaar llvmlistbot at llvm.org
Thu Jun 16 20:02:02 PDT 2022


Author: Jacques Pienaar
Date: 2022-06-16T20:01:54-07:00
New Revision: d30c0221cf5aa36c079b7cc0d36fb89f7b32149b

URL: https://github.com/llvm/llvm-project/commit/d30c0221cf5aa36c079b7cc0d36fb89f7b32149b
DIFF: https://github.com/llvm/llvm-project/commit/d30c0221cf5aa36c079b7cc0d36fb89f7b32149b.diff

LOG: [mlir] Split MLProgram global load and store to Graph variants

* Split ops into X_graph variants as discussed;
* Remove tokens from non-Graph region variants and rely on side-effect
  modelling there while removing side-effect modelling from Graph
  variants and relying on explicit ordering there;
* Require Graph variants to produce a token, but keep the explicit token
  type specification, given previous discussion about it potentially
  becoming configurable in the future;

This results in some duplicated code. I considered adding helper
functions but decided against introducing an abstraction this early, given
the limited size of the duplication and the risk of creating accidental coupling.

Differential Revision: https://reviews.llvm.org/D127813

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
    mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
    mlir/test/Dialect/MLProgram/invalid.mlir
    mlir/test/Dialect/MLProgram/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
index d820573b200ab..69b1eab379b3c 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -171,7 +171,8 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
     advanced cases.
 
     This op is side effecting and may not be valid to use in graph regions
-    without additional consideration to evaluation order constraints.
+    without additional consideration to evaluation order constraints. See
+    `global_load_graph` for an op that allows explicit ordering constraints.
 
     Example:
 
@@ -181,16 +182,14 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
   }];
 
   let arguments = (ins
-    Arg<SymbolRefAttr, "", [MemRead]>:$global,
-    Variadic<MLProgram_TokenType>:$consumeTokens
+    Arg<SymbolRefAttr, "", [MemRead]>:$global
   );
   let results = (outs
-    AnyType:$result,
-    Optional<MLProgram_TokenType>:$produceToken
+    AnyType:$result
   );
 
   let assemblyFormat = [{
-    $global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
+    $global `:` type($result) attr-dict
   }];
 
   let extraClassDeclaration = [{
@@ -238,6 +237,52 @@ def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalLoadGraphOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalLoadGraphOp : MLProgram_Op<"global_load_graph", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  ]> {
+  let summary = "Direct load of a mutable value from a global in Graph region";
+  let description = [{
+    Performs a non-atomic, non-volatile, non-synchronized load from a global
+    that may be mutable.
+
+    It is fully expected that these constraints are not suitable for all
+    situations, and alternative ops should be defined and used for more advanced
+    cases.
+
+    Unlike `global_load`, this op is intended for graph regions: evaluation
+    order is expressed explicitly via the consumed and produced tokens.
+
+    Example:
+
+    ```mlir
+    %0, %out_token = ml_program.global_load_graph @foobar
+      ordering (%in_token -> !ml_program.token) : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<SymbolRefAttr, "", [MemRead]>:$global,
+    Variadic<MLProgram_TokenType>:$consumeTokens
+  );
+  let results = (outs
+    AnyType:$result,
+    MLProgram_TokenType:$produceToken
+  );
+
+  let assemblyFormat = [{
+    $global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Gets the corresponding GlobalOp (or nullptr).
+    GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalStoreOp
 //===----------------------------------------------------------------------===//
@@ -255,7 +300,8 @@ def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
     advanced cases.
 
     This op is side effecting and may not be valid to use in graph regions
-    without additional consideration to evaluation order constraints.
+    without additional consideration to evaluation order constraints. See
+    `global_store_graph` for an op that allows explicit ordering constraints.
 
     Example:
 
@@ -266,11 +312,53 @@ def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
 
   let arguments = (ins
     Arg<SymbolRefAttr, "", [MemWrite]>:$global,
+    AnyType:$value
+  );
+
+  let assemblyFormat = [{
+    $global `=` $value `:` type($value) attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Gets the corresponding GlobalOp (or nullptr).
+    GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalStoreGraphOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalStoreGraphOp : MLProgram_Op<"global_store_graph", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  ]> {
+  let summary = "Direct store of a value into a mutable global in Graph region";
+  let description = [{
+    Performs a non-atomic, non-volatile, non-synchronized store to a mutable
+    global.
+
+    It is fully expected that these constraints are not suitable for
+    all situations, and alternative ops should be defined and used for more
+    advanced cases.
+
+    Unlike `global_store`, this op is intended for graph regions: evaluation
+    order is expressed explicitly via the consumed and produced tokens.
+
+    Example:
+
+    ```mlir
+    %token = ml_program.global_store_graph @foobar = %0
+      ordering (%in_token -> !ml_program.token) : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<SymbolRefAttr, "", [MemWrite]>:$global,
     AnyType:$value,
     Variadic<MLProgram_TokenType>:$consumeTokens
   );
   let results = (outs
-    Optional<MLProgram_TokenType>:$produceToken
+    MLProgram_TokenType:$produceToken
   );
 
   let assemblyFormat = [{

diff  --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
index 9411ea17b8165..2f1e4b93a6ac3 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -18,12 +18,11 @@ using namespace mlir::ml_program;
 //===----------------------------------------------------------------------===//
 
 /// Parse and print an ordering clause for a variadic of consuming tokens
-/// and an optional producing token.
+/// and a required producing token.
 ///
 /// Syntax:
 ///   ordering(%0, %1 -> !ml_program.token)
 ///   ordering(() -> !ml_program.token)
-///   ordering(%0, %1)
 ///
 /// If both the consuming and producing token are not present on the op, then
 /// the clause prints nothing.
@@ -46,10 +45,11 @@ static ParseResult parseTokenOrdering(
       return failure();
   }
 
-  // Parse optional producer token.
-  if (succeeded(parser.parseOptionalArrow()))
-    if (failed(parser.parseType(produceTokenType)))
-      return failure();
+  // Parse the required producer token: '->' followed by its type.
+  if (failed(parser.parseArrow()))
+    return failure();
+  if (failed(parser.parseType(produceTokenType)))
+    return failure();
 
   if (failed(parser.parseRParen()))
     return failure();
@@ -220,6 +220,30 @@ GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalLoadGraphOp
+//===----------------------------------------------------------------------===//
+
+GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
+  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+      getOperation()->getParentOp(), getGlobalAttr());
+}
+
+LogicalResult
+GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp referrent = getGlobalOp(symbolTable);
+  if (!referrent)
+    return emitOpError() << "undefined global: " << getGlobal();
+
+  if (referrent.getType() != getResult().getType()) {
+    return emitOpError() << "cannot load from global typed "
+                         << referrent.getType() << " as "
+                         << getResult().getType();
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalStoreOp
 //===----------------------------------------------------------------------===//
@@ -249,6 +273,35 @@ GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalStoreGraphOp
+//===----------------------------------------------------------------------===//
+
+GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
+  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+      getOperation()->getParentOp(), getGlobalAttr());
+}
+
+LogicalResult
+GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp referrent = getGlobalOp(symbolTable);
+  if (!referrent)
+    return emitOpError() << "undefined global: " << getGlobal();
+
+  if (!referrent.getIsMutable()) {
+    return emitOpError() << "cannot store to an immutable global "
+                         << getGlobal();
+  }
+
+  if (referrent.getType() != getValue().getType()) {
+    return emitOpError() << "cannot store to a global typed "
+                         << referrent.getType() << " from "
+                         << getValue().getType();
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SubgraphOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir
index e193c6d58f11a..79725a9bfbe18 100644
--- a/mlir/test/Dialect/MLProgram/invalid.mlir
+++ b/mlir/test/Dialect/MLProgram/invalid.mlir
@@ -96,3 +96,17 @@ ml_program.func @store_immutable(%arg0: i64) {
   ml_program.global_store @var = %arg0 : i64
   ml_program.return
 }
+
+// -----
+// The '->' producer token type is mandatory in the ordering clause.
+ml_program.global private mutable @global_mutable_undef : tensor<?xi32>
+ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
+  %token1 = ml_program.token
+  %0, %token2 = ml_program.global_load_graph @global_mutable_undef
+      ordering(() -> !ml_program.token) : tensor<?xi32>
+  %token3 = ml_program.global_store_graph @global_mutable_undef = %0
+  // expected-error @+1 {{expected '->'}}
+      ordering(%token1, %token2) : tensor<?xi32>
+
+  ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
+}

diff  --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir
index ca2d72afb6c66..9a48497a3efc8 100644
--- a/mlir/test/Dialect/MLProgram/ops.mlir
+++ b/mlir/test/Dialect/MLProgram/ops.mlir
@@ -45,12 +45,12 @@ ml_program.func @global_load_store() {
 // CHECK-LABEL: @global_load_store_tokens
 ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
   %token1 = ml_program.token
-  %0, %token2 = ml_program.global_load @global_mutable_undef
+  %0, %token2 = ml_program.global_load_graph @global_mutable_undef
       ordering(() -> !ml_program.token) : tensor<?xi32>
-  %token3 = ml_program.global_store @global_mutable_undef = %0
+  %token3 = ml_program.global_store_graph @global_mutable_undef = %0
       ordering(%token1, %token2 -> !ml_program.token) : tensor<?xi32>
-  ml_program.global_store @global_mutable_undef = %0
-      ordering(%token3) : tensor<?xi32>
+  %token4 = ml_program.global_store_graph @global_mutable_undef = %0
+      ordering(%token3 -> !ml_program.token) : tensor<?xi32>
 
-  ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
+  ml_program.output %0, %token4 : tensor<?xi32>, !ml_program.token
 }


        


More information about the Mlir-commits mailing list