[Mlir-commits] [mlir] 2bb2528 - [mlir] Add GlobalOp, GlobalLoadConstOp to ml_program.
Stella Laurenzo
llvmlistbot at llvm.org
Wed May 18 23:09:27 PDT 2022
Author: Stella Laurenzo
Date: 2022-05-18T23:08:28-07:00
New Revision: 2bb252852c72a4563fd7cd36604a1698c34d22a8
URL: https://github.com/llvm/llvm-project/commit/2bb252852c72a4563fd7cd36604a1698c34d22a8
DIFF: https://github.com/llvm/llvm-project/commit/2bb252852c72a4563fd7cd36604a1698c34d22a8.diff
LOG: [mlir] Add GlobalOp, GlobalLoadConstOp to ml_program.
The approach I took was to define a dialect 'extern' attribute that a GlobalOp can take as a value to signify external linkage. I think this approach should compose well and should also work with wherever the OpaqueElements work goes in the future (since that is just another kind of attribute). I special cased the GlobalOp parser/printer for this case because it is significantly easier on the eyes.
In the discussion, Jeff Niu proposed an alternative syntax for GlobalOp that I ended up not adopting. I did try to implement it, but a) I don't think it made anything easier to read in the common case, and b) it made the parsing/printing logic a lot more complicated (I think I would need a completely custom parser/printer to do it well). Please have a look at the common cases where the global type and initial value type match: I don't think the current syntax is too bad there. The less common cases seem OK to me.
I chose to implement only the direct, constant load op, since it is non-side-effecting and discussion on the remaining (side-effecting) ops was still pending.
Differential Revision: https://reviews.llvm.org/D124318
Added:
mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h
mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td
mlir/test/Dialect/MLProgram/attrs.mlir
Modified:
mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
mlir/test/Dialect/MLProgram/invalid.mlir
mlir/test/Dialect/MLProgram/ops.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
index fce18e65e952e..80e2f24e465d7 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
@@ -1,3 +1,10 @@
set(LLVM_TARGET_DEFINITIONS MLProgramOps.td)
add_mlir_dialect(MLProgramOps ml_program)
add_mlir_doc(MLProgramOps MLProgramOps Dialects/ -gen-dialect-doc)
+
+# Attribute definitions (MLProgramAttributes.td) get their own tablegen
+# target, which mlir-headers depends on, plus a generated attribute doc page.
+set(LLVM_TARGET_DEFINITIONS MLProgramAttributes.td)
+mlir_tablegen(MLProgramAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(MLProgramAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRMLProgramAttributesIncGen)
+add_dependencies(mlir-headers MLIRMLProgramAttributesIncGen)
+add_mlir_doc(MLProgramAttributes MLProgramAttributes Dialects/ -gen-attrdef-doc)
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
index fad8cbcf1c669..8dbbf8825b84a 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
@@ -8,6 +8,7 @@
#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
+#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h
new file mode 100644
index 0000000000000..253daedcc605a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h
@@ -0,0 +1,21 @@
+//===- MLProgramAttributes.h - Attribute Classes ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_
+#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_
+
+#include "mlir/IR/Attributes.h"
+
+//===----------------------------------------------------------------------===//
+// Tablegen Attribute Declarations
+//===----------------------------------------------------------------------===//
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h.inc"
+
+#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td
new file mode 100644
index 0000000000000..1323afec2e9a3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td
@@ -0,0 +1,44 @@
+//===- MLProgramAttributes.td - Attr definitions -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLPROGRAM_ATTRIBUTES
+#define MLPROGRAM_ATTRIBUTES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
+
+// Base class for MLProgram dialect attributes. The mnemonic is left unset
+// (`?`) here; concrete attributes assign their own (e.g. `extern`).
+class MLProgram_Attr<string name, list<Trait> traits = []>
+ : AttrDef<MLProgram_Dialect, name, traits> {
+ let mnemonic = ?;
+}
+
+//===----------------------------------------------------------------------===//
+// ExternAttr
+//===----------------------------------------------------------------------===//
+
+def MLProgram_ExternAttr : MLProgram_Attr<"Extern"> {
+  let summary = "Value used for a global signalling external resolution";
+  let description = [{
+    When used as the value for a GlobalOp, this indicates that the actual
+    value should be resolved externally in an implementation defined manner.
+    The `sym_name` of the global is the key for locating the value.
+
+    The attribute carries only its type, which is written after a colon.
+
+    Examples:
+
+    ```mlir
+    #ml_program.extern : tensor<4xi32>
+    ```
+  }];
+
+  let parameters = (ins AttributeSelfTypeParameter<"">:$type);
+  let mnemonic = "extern";
+  let assemblyFormat = "";
+}
+
+#endif // MLPROGRAM_ATTRIBUTES
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
index b670bc89204c2..ba3781f30ec56 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
@@ -27,6 +27,7 @@ def MLProgram_Dialect : Dialect {
it is recommended to inquire further prior to using this dialect.
}];
+ let useDefaultAttributePrinterParser = 1;
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
index b096c3dd53e8c..08e0974ab9b3e 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -96,6 +96,101 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
let hasCustomAssemblyFormat = 1;
}
+//===----------------------------------------------------------------------===//
+// GlobalOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalOp : MLProgram_Op<"global", [
+    Symbol
+  ]> {
+  let summary = "Module level declaration of a global variable";
+  let description = [{
+    Declares a named global variable (or constant).
+
+    A global contains a value of a specified type which can be accessed at
+    runtime via appropriate load/store operations. It can be mutable or
+    constant, optionally taking an initial value or declared as
+    extern (in which case, the initial value is found in external storage
+    by symbol name).
+
+    Generally, the type of the global and the type of the initial value
+    will be the same. However, for type hierarchies which can have a more
+    generalized bounding type that can be assigned from a narrow type, this
+    is allowed (but not verified).
+
+    The symbol visibility (`public`, `private` or `nested`) must always be
+    spelled out. An immutable global must have an initial value; a mutable
+    global may omit it, in which case its initial value is undefined.
+
+    Examples:
+
+    ```mlir
+    // Constant global.
+    ml_program.global private @foobar(dense<4> : tensor<4xi32>) : tensor<?xi32>
+
+    // Mutable global with external linkage.
+    ml_program.global private mutable @foobar(#ml_program.extern : tensor<4xi32>)
+        : tensor<?xi32>
+
+    // Mutable global with an undefined initial value.
+    ml_program.global private mutable @foobar : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttr:$type,
+    UnitAttr:$is_mutable,
+    OptionalAttr<AnyAttr>:$value,
+    OptionalAttr<StrAttr>:$sym_visibility
+  );
+
+  let assemblyFormat = [{
+    custom<SymbolVisibility>($sym_visibility)
+    (`mutable` $is_mutable^)?
+    $sym_name ``
+    custom<TypedInitialValue>($type, $value)
+    attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalLoadConstOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  ]> {
+  let summary = "Directly load a constant value from a global";
+  let description = [{
+    Loads a constant (immutable) value from a global directly by symbol.
+
+    This op is only legal for globals that are not mutable and exists because
+    such a load can be considered to have no side effects.
+
+    The referenced global must exist, must not be `mutable`, and its declared
+    type must be identical to the result type; these are checked when symbol
+    uses are verified.
+
+    Example:
+
+    ```mlir
+    %0 = ml_program.global_load_const @foobar : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$global
+  );
+  let results = (outs
+    AnyType:$result
+  );
+
+  let assemblyFormat = [{
+    $global attr-dict `:` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    /// Gets the corresponding GlobalOp (or nullptr).
+    GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
index 61dba7539908b..a49627d683533 100644
--- a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRMLProgram
DEPENDS
MLIRMLProgramOpsIncGen
+ MLIRMLProgramAttributesIncGen
LINK_LIBS PUBLIC
MLIRDialect
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
index eb012bac3984a..0462347609ccc 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
@@ -7,15 +7,42 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::ml_program;
+//===----------------------------------------------------------------------===//
+/// Tablegen Definitions
+//===----------------------------------------------------------------------===//
+
#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc"
+
+namespace {
+/// Asm hooks for the ml_program dialect: gives ExternAttr a short alias.
+struct MLProgramOpAsmDialectInterface : public OpAsmDialectInterface {
+  using OpAsmDialectInterface::OpAsmDialectInterface;
+
+  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
+    if (!attr.isa<ExternAttr>())
+      return AliasResult::NoAlias;
+    // ExternAttr values are printed through the `#extern` alias.
+    os << "extern";
+    return AliasResult::OverridableAlias;
+  }
+};
+} // namespace
+/// Registers the dialect's tablegen'd attributes and operations, plus the
+/// OpAsmDialectInterface that provides the `extern` attribute alias.
void ml_program::MLProgramDialect::initialize() {
+  // Attribute classes generated from MLProgramAttributes.td.
+#define GET_ATTRDEF_LIST
+ addAttributes<
+#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc"
+ >();
+  // Operation classes generated from MLProgramOps.td.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
>();
+ addInterfaces<MLProgramOpAsmDialectInterface>();
}
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
index 4d8038c21f17f..8c8a4591addc0 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -13,6 +13,69 @@
using namespace mlir;
using namespace mlir::ml_program;
+//===----------------------------------------------------------------------===//
+// Custom asm helpers
+//===----------------------------------------------------------------------===//
+
+/// some.op custom<TypedInitialValue>($type, $attr)
+///
+/// Uninitialized:
+///   some.op : tensor<3xi32>
+/// Initialized (the value type may be narrower than the declared type):
+///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
+///
+/// On success `typeAttr` always holds the declared type; `attr` is only
+/// assigned when a parenthesized initial value is present.
+static ParseResult parseTypedInitialValue(OpAsmParser &parser,
+                                          TypeAttr &typeAttr, Attribute &attr) {
+  // Optional parenthesized initial value.
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (failed(parser.parseAttribute(attr)) || failed(parser.parseRParen()))
+      return failure();
+  }
+
+  // Mandatory `: type` giving the declared type of the global.
+  Type type;
+  if (failed(parser.parseColonType(type)))
+    return failure();
+  typeAttr = TypeAttr::get(type);
+  return success();
+}
+
+/// Prints the counterpart of parseTypedInitialValue: an optional
+/// `(initial-value)` followed by ` : declared-type`.
+static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
+                                   TypeAttr type, Attribute attr) {
+  if (attr)
+    p << "(" << attr << ")";
+  p << " : ";
+  p.printAttribute(type);
+}
+
+/// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
+/// ->
+/// some.op public @foo
+/// some.op private @foo
+/// some.op nested @foo
+///
+/// The visibility keyword is mandatory; a missing keyword is a parse error.
+static ParseResult parseSymbolVisibility(OpAsmParser &parser,
+                                         StringAttr &symVisibilityAttr) {
+  StringRef symVisibility;
+  (void)parser.parseOptionalKeyword(&symVisibility,
+                                    {"public", "private", "nested"});
+  if (symVisibility.empty())
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 'public', 'private', or 'nested'";
+  // A keyword was parsed (the empty case returned above), so record it.
+  symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
+  return success();
+}
+
+/// Prints the visibility keyword; an absent attribute is spelled `public`.
+static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
+                                  StringAttr symVisibilityAttr) {
+  StringRef spelling =
+      symVisibilityAttr ? symVisibilityAttr.getValue() : StringRef("public");
+  p << spelling;
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
@@ -38,6 +101,43 @@ void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
+//===----------------------------------------------------------------------===//
+// GlobalOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GlobalOp::verify() {
+  // Mutable globals may start undefined; immutable ones need a value.
+  if (getIsMutable() || getValue())
+    return success();
+  return emitOpError() << "immutable global must have an initial value";
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalLoadConstOp
+//===----------------------------------------------------------------------===//
+
+GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
+  // Start the lookup from the parent op, i.e. the nearest enclosing scope.
+  Operation *lookupFrom = getOperation()->getParentOp();
+  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(lookupFrom,
+                                                       getGlobalAttr());
+}
+
+/// Checks that the referenced global exists, is immutable, and has exactly
+/// the type of this op's result.
+LogicalResult
+GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp global = getGlobalOp(symbolTable);
+  if (!global)
+    return emitOpError() << "undefined global: " << getGlobal();
+
+  // Only immutable globals can be loaded without side effects.
+  if (global.getIsMutable())
+    return emitOpError() << "cannot load as const from mutable global "
+                         << getGlobal();
+
+  Type globalType = global.getType();
+  Type resultType = getResult().getType();
+  if (globalType != resultType)
+    return emitOpError() << "cannot load from global typed " << globalType
+                         << " as " << resultType;
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MLProgram/attrs.mlir b/mlir/test/Dialect/MLProgram/attrs.mlir
new file mode 100644
index 0000000000000..8e3c06ae10f7b
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/attrs.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt %s --allow-unregistered-dialect | mlir-opt --allow-unregistered-dialect | FileCheck %s
+
+// Round-trip #ml_program.extern (ExternAttr) through an unregistered op.
+// CHECK: #ml_program.extern : i32
+"unregistered.attributes"() {
+ value = #ml_program.extern : i32
+} : () -> ()
+
diff --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir
index 851998a8326a2..969d88766e822 100644
--- a/mlir/test/Dialect/MLProgram/invalid.mlir
+++ b/mlir/test/Dialect/MLProgram/invalid.mlir
@@ -31,3 +31,30 @@ ml_program.subgraph @output_type_match(%arg0 : i64) -> i32 {
// expected-error @+1 {{doesn't match function result}}
ml_program.output %arg0 : i64
}
+
+// -----
+// expected-error @+1 {{immutable global must have an initial value}}
+ml_program.global private @const : i32
+
+// -----
+// The symbol visibility keyword is mandatory in the custom assembly format.
+// expected-error @+1 {{expected 'public', 'private', or 'nested'}}
+ml_program.global @missing_visibility(42 : i64) : i64
+
+// -----
+ml_program.func @undef_global() -> i32 {
+  // expected-error @+1 {{undefined global: nothere}}
+  %0 = ml_program.global_load_const @nothere : i32
+  ml_program.return %0 : i32
+}
+
+// -----
+ml_program.global private mutable @var : i32
+ml_program.func @mutable_const_load() -> i32 {
+  // expected-error @+1 {{op cannot load as const from mutable global var}}
+  %0 = ml_program.global_load_const @var : i32
+  ml_program.return %0 : i32
+}
+
+// -----
+ml_program.global private @var(42 : i64) : i64
+ml_program.func @const_load_type_mismatch() -> i32 {
+  // expected-error @+1 {{cannot load from global typed 'i64' as 'i32'}}
+  %0 = ml_program.global_load_const @var : i32
+  ml_program.return %0 : i32
+}
diff --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir
index 24f5f8af1be7c..90c65a7714bd2 100644
--- a/mlir/test/Dialect/MLProgram/ops.mlir
+++ b/mlir/test/Dialect/MLProgram/ops.mlir
@@ -18,3 +18,12 @@ ml_program.subgraph @compute_subgraph(%arg0 : i32) -> i32 {
%0 = "unregistered.dummy"(%arg0) : (i32) -> i32
ml_program.output %0 : i32
}
+
+// CHECK: ml_program.global private @global_same_type(dense<4> : tensor<4xi32>) : tensor<4xi32>
+ml_program.global private @global_same_type(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK: ml_program.global private mutable @global_mutable_undef : tensor<?xi32>
+ml_program.global private mutable @global_mutable_undef : tensor<?xi32>
+
+// CHECK: ml_program.global private mutable @global_extern(#extern) : tensor<?xi32>
+ml_program.global private mutable @global_extern(#ml_program.extern : tensor<4xi32>) : tensor<?xi32>
+
+// Round-trip a constant load from an immutable global of matching type.
+// CHECK-LABEL: @global_load_const_same_type
+ml_program.func @global_load_const_same_type() -> tensor<4xi32> {
+  // CHECK: ml_program.global_load_const @global_same_type : tensor<4xi32>
+  %0 = ml_program.global_load_const @global_same_type : tensor<4xi32>
+  ml_program.return %0 : tensor<4xi32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c9f659bff04c9..53d8467739ef7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8560,6 +8560,7 @@ td_library(
name = "MLProgramOpsTdFiles",
srcs = [
"include/mlir/Dialect/MLProgram/IR/MLProgramBase.td",
+ "include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td",
"include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
],
includes = ["include"],
@@ -8599,6 +8600,24 @@ gentbl_cc_library(
deps = [":MLProgramOpsTdFiles"],
)
+# C++ attribute declarations/definitions generated from MLProgramAttributes.td.
+gentbl_cc_library(
+ name = "MLProgramAttributesIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-attrdef-decls"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h.inc",
+ ),
+ (
+ ["-gen-attrdef-defs"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td",
+ deps = [":MLProgramOpsTdFiles"],
+)
+
cc_library(
name = "MLProgramDialect",
srcs = glob([
@@ -8612,6 +8631,7 @@ cc_library(
deps = [
":ControlFlowInterfaces",
":IR",
+ ":MLProgramAttributesIncGen",
":MLProgramOpsIncGen",
":Pass",
":Support",
More information about the Mlir-commits
mailing list