[Mlir-commits] [mlir] 20dca52 - [mlir][SideEffects] Enable specifying side effects directly on the arguments/results of an operation.
River Riddle
llvmlistbot at llvm.org
Fri Mar 6 14:05:10 PST 2020
Author: River Riddle
Date: 2020-03-06T14:04:36-08:00
New Revision: 20dca52288adfb64cdeb25fd25fbe0bb8628e7c3
URL: https://github.com/llvm/llvm-project/commit/20dca52288adfb64cdeb25fd25fbe0bb8628e7c3
DIFF: https://github.com/llvm/llvm-project/commit/20dca52288adfb64cdeb25fd25fbe0bb8628e7c3.diff
LOG: [mlir][SideEffects] Enable specifying side effects directly on the arguments/results of an operation.
Summary:
New classes are added to ODS to enable specifying additional information on the arguments and results of an operation. These classes, `Arg` and `Res`, allow adding a description and a set of 'decorators' along with the constraint. This enables specifying the side effects of an operation directly on the arguments and results themselves.
Example:
```
def LoadOp : Std_Op<"load"> {
let arguments = (ins Arg<AnyMemRef, "the MemRef to load from",
[MemRead]>:$memref,
Variadic<Index>:$indices);
}
```
Differential Revision: https://reviews.llvm.org/D74440
Added:
mlir/include/mlir/TableGen/SideEffects.h
mlir/lib/TableGen/SideEffects.cpp
mlir/test/mlir-tblgen/op-side-effects.td
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/SideEffects.td
mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/CMakeLists.txt
mlir/lib/TableGen/Operator.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index e6e28e8d00a6..50735d54f19a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -16,6 +16,7 @@
include "mlir/Analysis/CallInterfaces.td"
include "mlir/Analysis/ControlFlowInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SideEffects.td"
def Std_Dialect : Dialect {
let name = "std";
@@ -1052,7 +1053,9 @@ def LoadOp : Std_Op<"load",
%3 = load %0[%1, %1] : memref<4x4xi32>
}];
- let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
+ let arguments = (ins Arg<AnyMemRef, "the reference to load from",
+ [MemRead]>:$memref,
+ Variadic<Index>:$indices);
let results = (outs AnyType:$result);
let builders = [OpBuilder<
@@ -1563,8 +1566,10 @@ def StoreOp : Std_Op<"store",
store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
}];
- let arguments = (ins AnyType:$value, AnyMemRef:$memref,
- Variadic<Index>:$indices);
+ let arguments = (ins AnyType:$value,
+ Arg<AnyMemRef, "the reference to store to",
+ [MemWrite]>:$memref,
+ Variadic<Index>:$indices);
let builders = [OpBuilder<
"Builder *, OperationState &result, Value valueToStore, Value memref", [{
@@ -1846,7 +1851,8 @@ def TensorLoadOp : Std_Op<"tensor_load",
%12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0>
}];
- let arguments = (ins AnyMemRef:$memref);
+ let arguments = (ins Arg<AnyMemRef, "the reference to load from",
+ [MemRead]>:$memref);
let results = (outs AnyTensor:$result);
// TensorLoadOp is fully verified by traits.
let verifier = ?;
@@ -1890,7 +1896,9 @@ def TensorStoreOp : Std_Op<"tensor_store",
tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
}];
- let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
+ let arguments = (ins AnyTensor:$tensor,
+ Arg<AnyMemRef, "the reference to store to",
+ [MemWrite]>:$memref);
// TensorStoreOp is fully verified by traits.
let verifier = ?;
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 1c51fc86b444..419ca0d1f63f 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1712,6 +1712,29 @@ class OpBuilder<string p, code b = ""> {
code body = b;
}
+// A base decorator class that may optionally be added to OpVariables.
+// Side effect descriptions (`SideEffect` in SideEffects.td) derive from it.
+class OpVariableDecorator;
+
+// Class for providing additional information on the variables, i.e. arguments
+// and results, of an operation.
+class OpVariable<Constraint varConstraint, string desc = "",
+ list<OpVariableDecorator> varDecorators = []> {
+ // The constraint, either attribute or type, of the variable.
+ Constraint constraint = varConstraint;
+
+ // A description for the variable.
+ string description = desc;
+
+ // The list of decorators for this variable, e.g. side effects.
+ list<OpVariableDecorator> decorators = varDecorators;
+}
+// An operation argument (operand or attribute) with an optional description
+// and decorators; used inside an op's `ins` argument list.
+class Arg<Constraint constraint, string desc = "",
+ list<OpVariableDecorator> decorators = []>
+ : OpVariable<constraint, desc, decorators>;
+// An operation result with an optional description and decorators; used
+// inside an op's `outs` result list.
+class Res<Constraint constraint, string desc = "",
+ list<OpVariableDecorator> decorators = []>
+ : OpVariable<constraint, desc, decorators>;
+
// Base class for all ops.
class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
// The dialect of the op.
diff --git a/mlir/include/mlir/IR/SideEffects.td b/mlir/include/mlir/IR/SideEffects.td
index 9d06348c98b7..04d2bfe5127c 100644
--- a/mlir/include/mlir/IR/SideEffects.td
+++ b/mlir/include/mlir/IR/SideEffects.td
@@ -107,7 +107,7 @@ class EffectOpInterfaceBase<string name, string baseEffect>
// This class is the general base side effect class. This is used by derived
// effect interfaces to define their effects.
class SideEffect<EffectOpInterfaceBase interface, string effectName,
- string resourceName> {
+ string resourceName> : OpVariableDecorator {
/// The parent interface that the effect belongs to.
string interfaceTrait = interface.trait;
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index e83b25231a87..08af550499a7 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -57,6 +57,34 @@ class Operator {
// Returns this op's C++ class name prefixed with namespaces.
std::string getQualCppClassName() const;
+  /// Lightweight handle over the TableGen record of one decorator attached to
+  /// an operator variable, i.e. an argument or a result.
+  struct VariableDecorator {
+  public:
+    explicit VariableDecorator(const llvm::Record *record) : def(record) {}
+
+    /// Returns the underlying TableGen definition of this decorator.
+    const llvm::Record &getDef() const { return *def; }
+
+  protected:
+    /// The TableGen definition of this decorator; derived wrappers such as
+    /// `SideEffect` read their fields from it.
+    const llvm::Record *def;
+  };
+
+  /// Iterator over a list of variable decorators: walks the raw `llvm::Init`
+  /// list elements and presents each one as a `VariableDecorator` value.
+  struct VariableDecoratorIterator
+      : public llvm::mapped_iterator<llvm::Init *const *,
+                                     VariableDecorator (*)(llvm::Init *)> {
+    /// Dereferencing produces decorators by value, not by reference.
+    using reference = VariableDecorator;
+
+    /// Maps a single list element to its decorator wrapper.
+    static VariableDecorator unwrap(llvm::Init *init);
+
+    /// Begins iteration at the given list element.
+    VariableDecoratorIterator(llvm::Init *const *current)
+        : llvm::mapped_iterator<llvm::Init *const *,
+                                VariableDecorator (*)(llvm::Init *)>(
+              current, &unwrap) {}
+  };
+  using var_decorator_iterator = VariableDecoratorIterator;
+  using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
+
using value_iterator = NamedTypeConstraint *;
using value_range = llvm::iterator_range<value_iterator>;
@@ -84,6 +112,8 @@ class Operator {
TypeConstraint getResultTypeConstraint(int index) const;
// Returns the `index`-th result's name.
StringRef getResultName(int index) const;
+ // Returns the `index`-th result's decorators. The range is empty when the
+ // result is not declared through an `OpVariable` (e.g. `Res<...>`).
+ var_decorator_range getResultDecorators(int index) const;
// Returns the number of variadic results in this operation.
unsigned getNumVariadicResults() const;
@@ -128,6 +158,7 @@ class Operator {
// Op argument (attribute or operand) accessors.
Argument getArg(int index) const;
StringRef getArgName(int index) const;
+ // Returns the decorators of the `index`-th argument, where `index` counts
+ // operands and attributes alike (same numbering as `getArg`). The range is
+ // empty when the argument is not declared through an `OpVariable`
+ // (e.g. `Arg<...>`).
+ var_decorator_range getArgDecorators(int index) const;
// Returns the trait wrapper for the given MLIR C++ `trait`.
// TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
diff --git a/mlir/include/mlir/TableGen/SideEffects.h b/mlir/include/mlir/TableGen/SideEffects.h
new file mode 100644
index 000000000000..c93502cc7a7a
--- /dev/null
+++ b/mlir/include/mlir/TableGen/SideEffects.h
@@ -0,0 +1,55 @@
+//===- SideEffects.h - Side Effects classes ---------------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Wrapper around side effect related classes defined in TableGen.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_SIDEEFFECTS_H_
+#define MLIR_TABLEGEN_SIDEEFFECTS_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Operator.h"
+
+namespace mlir {
+namespace tblgen {
+
+/// A single effect instance exhibited by an operation, wrapping the TableGen
+/// `SideEffect` record that describes it.
+class SideEffect : public Operator::VariableDecorator {
+public:
+  /// LLVM-style RTTI hook: true if the decorator's record derives from the
+  /// TableGen `SideEffect` class.
+  static bool classof(const Operator::VariableDecorator *var);
+
+  /// Returns the name of the C++ effect class.
+  StringRef getName() const;
+
+  /// Returns the name of the C++ base effect class.
+  StringRef getBaseName() const;
+
+  /// Returns the name of the parent interface trait.
+  StringRef getInterfaceTrait() const;
+
+  /// Returns the name of the C++ resource class; falls back to
+  /// `::mlir::SideEffects::DefaultResource` when none was given.
+  StringRef getResource() const;
+};
+
+/// A side effect interface applied to an operation: an OpInterfaceTrait that
+/// additionally carries the list of effects being applied.
+class SideEffectTrait : public InterfaceOpTrait {
+public:
+  /// LLVM-style RTTI hook: true if the trait's record derives from
+  /// `SideEffectsTraitBase`.
+  static bool classof(const OpTrait *t);
+
+  /// Returns the effects attached to the side effect interface.
+  Operator::var_decorator_range getEffects() const;
+};
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TABLEGEN_SIDEEFFECTS_H_
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index 6e3bf27720d2..4c6ac720f0ea 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -10,6 +10,7 @@ add_llvm_library(LLVMMLIRTableGen
OpTrait.cpp
Pattern.cpp
Predicate.cpp
+ SideEffects.cpp
Successor.cpp
Type.cpp
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 007d8ab7bd69..6492b772e4b7 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -109,6 +109,15 @@ StringRef tblgen::Operator::getResultName(int index) const {
return results->getArgNameStr(index);
}
+auto tblgen::Operator::getResultDecorators(int index) const
+    -> var_decorator_range {
+  // Only results declared through `OpVariable` (e.g. `Res<...>`) carry
+  // decorators; a bare constraint yields an empty range.
+  auto *resultInit =
+      cast<DefInit>(def.getValueAsDag("results")->getArg(index));
+  Record *resultDef = resultInit->getDef();
+  if (resultDef->isSubClassOf("OpVariable"))
+    return *resultDef->getValueAsListInit("decorators");
+  return var_decorator_range(nullptr, nullptr);
+}
+
unsigned tblgen::Operator::getNumVariadicResults() const {
return std::count_if(
results.begin(), results.end(),
@@ -138,6 +147,15 @@ StringRef tblgen::Operator::getArgName(int index) const {
return argumentValues->getArgName(index)->getValue();
}
+auto tblgen::Operator::getArgDecorators(int index) const
+    -> var_decorator_range {
+  // Only arguments declared through `OpVariable` (e.g. `Arg<...>`) carry
+  // decorators; a bare constraint yields an empty range.
+  auto *argInit =
+      cast<DefInit>(def.getValueAsDag("arguments")->getArg(index));
+  Record *argDef = argInit->getDef();
+  if (argDef->isSubClassOf("OpVariable"))
+    return *argDef->getValueAsListInit("decorators");
+  return var_decorator_range(nullptr, nullptr);
+}
+
const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
for (const auto &t : traits) {
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
@@ -226,6 +244,7 @@ void tblgen::Operator::populateOpStructure() {
auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
auto attrClass = recordKeeper.getClass("Attr");
auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
+ auto opVarClass = recordKeeper.getClass("OpVariable");
numNativeAttributes = 0;
DagInit *argumentValues = def.getValueAsDag("arguments");
@@ -240,10 +259,12 @@ void tblgen::Operator::populateOpStructure() {
PrintFatalError(def.getLoc(),
Twine("undefined type for argument #") + Twine(i));
Record *argDef = argDefInit->getDef();
+ if (argDef->isSubClassOf(opVarClass))
+ argDef = argDef->getValueAsDef("constraint");
if (argDef->isSubClassOf(typeConstraintClass)) {
operands.push_back(
- NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
+ NamedTypeConstraint{givenName, TypeConstraint(argDef)});
} else if (argDef->isSubClassOf(attrClass)) {
if (givenName.empty())
PrintFatalError(argDef->getLoc(), "attributes must be named");
@@ -285,6 +306,8 @@ void tblgen::Operator::populateOpStructure() {
int operandIndex = 0, attrIndex = 0;
for (unsigned i = 0; i != numArgs; ++i) {
Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
+ if (argDef->isSubClassOf(opVarClass))
+ argDef = argDef->getValueAsDef("constraint");
if (argDef->isSubClassOf(typeConstraintClass)) {
arguments.emplace_back(&operands[operandIndex++]);
@@ -303,11 +326,14 @@ void tblgen::Operator::populateOpStructure() {
// Handle results.
for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
auto name = resultsDag->getArgNameStr(i);
- auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i));
- if (!resultDef) {
+ auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
+ if (!resultInit) {
PrintFatalError(def.getLoc(),
Twine("undefined type for result #") + Twine(i));
}
+ auto *resultDef = resultInit->getDef();
+ if (resultDef->isSubClassOf(opVarClass))
+ resultDef = resultDef->getValueAsDef("constraint");
results.push_back({name, TypeConstraint(resultDef)});
}
@@ -394,3 +420,8 @@ void tblgen::Operator::print(llvm::raw_ostream &os) const {
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
}
}
+
+/// Wraps one element of a decorator list; every element is expected to be a
+/// record reference (`DefInit`), which `cast` asserts.
+auto tblgen::Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
+    -> VariableDecorator {
+  const llvm::Record *decoratorDef = cast<llvm::DefInit>(init)->getDef();
+  return VariableDecorator(decoratorDef);
+}
diff --git a/mlir/lib/TableGen/SideEffects.cpp b/mlir/lib/TableGen/SideEffects.cpp
new file mode 100644
index 000000000000..0b334b8297a0
--- /dev/null
+++ b/mlir/lib/TableGen/SideEffects.cpp
@@ -0,0 +1,51 @@
+//===- SideEffects.cpp - SideEffect classes -------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/SideEffects.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// SideEffect
+//===----------------------------------------------------------------------===//
+
+/// Reads the `effect` field: the name of the C++ effect class.
+StringRef SideEffect::getName() const {
+  const llvm::Record &record = *def;
+  return record.getValueAsString("effect");
+}
+
+/// Reads the `baseEffect` field: the name of the C++ base effect class.
+StringRef SideEffect::getBaseName() const {
+  const llvm::Record &record = *def;
+  return record.getValueAsString("baseEffect");
+}
+
+/// Reads the `interfaceTrait` field: the owning interface trait's name.
+StringRef SideEffect::getInterfaceTrait() const {
+  const llvm::Record &record = *def;
+  return record.getValueAsString("interfaceTrait");
+}
+
+/// Reads the `resource` field, substituting the default resource when the
+/// field holds an empty string.
+StringRef SideEffect::getResource() const {
+  StringRef resource = def->getValueAsString("resource");
+  if (resource.empty())
+    return "::mlir::SideEffects::DefaultResource";
+  return resource;
+}
+
+bool SideEffect::classof(const Operator::VariableDecorator *var) {
+  const llvm::Record &record = var->getDef();
+  return record.isSubClassOf("SideEffect");
+}
+
+//===----------------------------------------------------------------------===//
+// SideEffectTrait
+//===----------------------------------------------------------------------===//
+
+/// Returns the effects listed in the trait record's `effects` field.
+Operator::var_decorator_range SideEffectTrait::getEffects() const {
+  // `getValueAsListInit` reports a TableGen fatal error when `effects` is
+  // missing or not a list. The previous unchecked `dyn_cast` would instead
+  // dereference a null pointer in that case.
+  return *def->getValueAsListInit("effects");
+}
+
+/// LLVM-style RTTI hook: true if the trait's record derives from
+/// `SideEffectsTraitBase`.
+bool SideEffectTrait::classof(const OpTrait *t) {
+  return t->getDef().isSubClassOf("SideEffectsTraitBase");
+}
diff --git a/mlir/test/mlir-tblgen/op-side-effects.td b/mlir/test/mlir-tblgen/op-side-effects.td
new file mode 100644
index 000000000000..67679dd3b017
--- /dev/null
+++ b/mlir/test/mlir-tblgen/op-side-effects.td
@@ -0,0 +1,26 @@
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/SideEffects.td"
+
+def TEST_Dialect : Dialect {
+ let name = "test";
+}
+class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<TEST_Dialect, mnemonic, traits>;
+
+// Effects declared directly on an operand (default resource) and on a result
+// (custom resource) through Arg/Res decorators.
+def SideEffectOpA : TEST_Op<"side_effect_op_a"> {
+ let arguments = (ins Arg<Variadic<AnyMemRef>, "", [MemRead]>);
+ let results = (outs Res<AnyMemRef, "", [MemAlloc<"CustomResource">]>);
+}
+
+// Effect declared on the operation as a whole through the MemoryEffects trait;
+// no value is attached to the generated effect instance.
+def SideEffectOpB : TEST_Op<"side_effect_op_b",
+ [MemoryEffects<[MemWrite<"CustomResource">]>]>;
+
+// CHECK: void SideEffectOpA::getEffects
+// CHECK: for (Value value : getODSOperands(0))
+// CHECK: effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get());
+// CHECK: for (Value value : getODSResults(0))
+// CHECK: effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get());
+
+// CHECK: void SideEffectOpB::getEffects
+// CHECK: effects.emplace_back(MemoryEffects::Write::get(), CustomResource::get());
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 2b751e00c0d0..e5c6c560f1d1 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -20,6 +20,7 @@
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/SideEffects.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Signals.h"
@@ -280,6 +281,9 @@ class OpEmitter {
// Generate the OpInterface methods.
void genOpInterfaceMethods();
+ // Generate the side effect interface methods: one `getEffects` method per
+ // side effect interface whose effects appear on the op's traits, operands,
+ // or results.
+ void genSideEffectInterfaceMethods();
+
private:
// The TableGen record for this op.
// TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
@@ -321,6 +325,7 @@ OpEmitter::OpEmitter(const Operator &op)
genFolderDecls();
genOpInterfaceMethods();
generateOpFormat(op, opClass);
+ genSideEffectInterfaceMethods();
}
void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
@@ -1161,6 +1166,75 @@ void OpEmitter::genOpInterfaceMethods() {
}
}
+/// Emits one `getEffects` method per side effect interface used by the op.
+/// Effects come from side effect traits (op-wide), from decorated operands,
+/// and from decorated results.
+void OpEmitter::genSideEffectInterfaceMethods() {
+  // Where an effect is attached: to an ODS operand group, to an ODS result
+  // group, or to the operation as a whole (via a trait).
+  enum class EffectKind { Operand, Result, Static };
+  struct EffectLocation {
+    /// The effect applied.
+    SideEffect effect;
+
+    /// The ODS operand/result group index; 0 and unused for static effects.
+    unsigned index;
+
+    /// The kind of the location. Deliberately not a bit-field: a 2-bit field
+    /// of an int-backed enum reads `Static` (2) back as -2 on compilers that
+    /// treat such bit-fields as signed (e.g. MSVC). That would make static
+    /// effects be emitted as per-value result effects.
+    EffectKind kind;
+  };
+
+  // Effects grouped by the name of the interface trait they belong to.
+  StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
+  auto resolveDecorators = [&](Operator::var_decorator_range decorators,
+                               unsigned index, EffectKind kind) {
+    for (auto decorator : decorators)
+      if (SideEffect *effect = dyn_cast<SideEffect>(&decorator))
+        interfaceEffects[effect->getInterfaceTrait()].push_back(
+            EffectLocation{*effect, index, kind});
+  };
+
+  // Collect effects that were specified via:
+  //  - Traits.
+  for (const auto &trait : op.getTraits())
+    if (const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait))
+      resolveDecorators(opTrait->getEffects(), /*index=*/0, EffectKind::Static);
+  //  - Operands. Attribute arguments do not advance the ODS operand index.
+  for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
+    if (op.getArg(i).is<NamedTypeConstraint *>()) {
+      resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
+      ++operandIt;
+    }
+  }
+  //  - Results.
+  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
+    resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
+
+  for (auto &it : interfaceEffects) {
+    StringRef baseEffect = it.second.front().effect.getBaseName();
+    auto effectsParam =
+        llvm::formatv(
+            "SmallVectorImpl<SideEffects::EffectInstance<{0}>> &effects",
+            baseEffect)
+            .str();
+
+    // Generate the 'getEffects' method.
+    auto &getEffects = opClass.newMethod("void", "getEffects", effectsParam);
+    auto &body = getEffects.body();
+
+    // Add effect instances for each of the locations marked on the operation.
+    for (auto &location : it.second) {
+      if (location.kind != EffectKind::Static) {
+        body << " for (Value value : getODS"
+             << (location.kind == EffectKind::Operand ? "Operands" : "Results")
+             << "(" << location.index << "))\n ";
+      }
+
+      body << " effects.emplace_back(" << location.effect.getName()
+           << "::get()";
+
+      // If the effect isn't static, it has a specific value attached to it.
+      if (location.kind != EffectKind::Static)
+        body << ", value";
+      body << ", " << location.effect.getResource() << "::get());\n";
+    }
+  }
+}
+
void OpEmitter::genParser() {
if (!hasStringAttribute(def, "parser") ||
hasStringAttribute(def, "assemblyFormat"))
More information about the Mlir-commits
mailing list