[Mlir-commits] [mlir] c0958b7 - [mlir] Add support for referencing a SymbolRefAttr in a SideEffectInstance
River Riddle
llvmlistbot at llvm.org
Wed Nov 18 18:39:35 PST 2020
Author: River Riddle
Date: 2020-11-18T18:38:43-08:00
New Revision: c0958b7b4c6a31b0b89462c3ee770e486d4eb535
URL: https://github.com/llvm/llvm-project/commit/c0958b7b4c6a31b0b89462c3ee770e486d4eb535
DIFF: https://github.com/llvm/llvm-project/commit/c0958b7b4c6a31b0b89462c3ee770e486d4eb535.diff
LOG: [mlir] Add support for referencing a SymbolRefAttr in a SideEffectInstance
This allows for operations that exclusively affect symbol operations to better describe their side effects.
Differential Revision: https://reviews.llvm.org/D91581
Added:
Modified:
mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td
mlir/include/mlir/Interfaces/SideEffectInterfaces.h
mlir/include/mlir/TableGen/Attribute.h
mlir/lib/TableGen/Attribute.cpp
mlir/test/IR/test-side-effects.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/IR/TestSideEffects.cpp
mlir/test/mlir-tblgen/op-side-effects.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td b/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td
index 41b07bcedc355..89318f7796f77 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td
@@ -110,9 +110,20 @@ class EffectOpInterfaceBase<string name, string baseEffect>
llvm::erase_if(effects, [&](auto &it) { return it.getValue() != value; });
}
+ /// Collect all of the effect instances that operate on the provided symbol
+ /// reference and place them in 'effects'.
+ void getEffectsOnSymbol(::mlir::SymbolRefAttr value,
+ llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<
+ }] # baseEffect # [{>> & effects) {
+ getEffects(effects);
+ llvm::erase_if(effects, [&](auto &it) {
+ return it.getSymbolRef() != value;
+ });
+ }
+
/// Collect all of the effect instances that operate on the provided
/// resource and place them in 'effects'.
- void getEffectsOnValue(::mlir::SideEffects::Resource *resource,
+ void getEffectsOnResource(::mlir::SideEffects::Resource *resource,
llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<
}] # baseEffect # [{>> & effects) {
getEffects(effects);
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index c19f7f4f03ee4..33a6ba69050e3 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -131,9 +131,9 @@ struct AutomaticAllocationScopeResource
/// This class represents a specific instance of an effect. It contains the
/// effect being applied, a resource that corresponds to where the effect is
-/// applied, an optional value (either operand, result, or region entry
-/// argument) that the effect is applied to, and an optional parameters
-/// attribute further specifying the details of the effect.
+/// applied, an optional symbol reference or value (either an operand, result,
+/// or region entry argument) that the effect is applied to, and an optional
+/// parameters attribute further specifying the details of the effect.
template <typename EffectT> class EffectInstance {
public:
EffectInstance(EffectT *effect, Resource *resource = DefaultResource::get())
@@ -141,6 +141,9 @@ template <typename EffectT> class EffectInstance {
EffectInstance(EffectT *effect, Value value,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(value) {}
+ EffectInstance(EffectT *effect, SymbolRefAttr symbol,
+ Resource *resource = DefaultResource::get())
+ : effect(effect), resource(resource), value(symbol) {}
EffectInstance(EffectT *effect, Attribute parameters,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), parameters(parameters) {}
@@ -148,13 +151,23 @@ template <typename EffectT> class EffectInstance {
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(value),
parameters(parameters) {}
+ EffectInstance(EffectT *effect, SymbolRefAttr symbol, Attribute parameters,
+ Resource *resource = DefaultResource::get())
+ : effect(effect), resource(resource), value(symbol),
+ parameters(parameters) {}
/// Return the effect being applied.
EffectT *getEffect() const { return effect; }
/// Return the value the effect is applied on, or nullptr if there isn't a
/// known value being affected.
- Value getValue() const { return value; }
+ Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); }
+
+ /// Return the symbol reference the effect is applied on, or nullptr if there
+ /// isn't a known symbol being affected.
+ SymbolRefAttr getSymbolRef() const {
+ return value ? value.dyn_cast<SymbolRefAttr>() : SymbolRefAttr();
+ }
/// Return the resource that the effect applies to.
Resource *getResource() const { return resource; }
@@ -169,8 +182,8 @@ template <typename EffectT> class EffectInstance {
/// The resource that the given value resides in.
Resource *resource;
- /// The value that the effect applies to. This is optionally null.
- Value value;
+ /// The Symbol or Value that the effect applies to. This is optionally null.
+ PointerUnion<SymbolRefAttr, Value> value;
/// Additional parameters of the effect instance. An attribute is used for
/// type-safe structured storage and context-based uniquing. Concrete effects
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 4571ca8ee9b38..dc6c9692581c3 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -94,6 +94,10 @@ class Attribute : public AttrConstraint {
// of `TypeAttrBase`).
bool isTypeAttr() const;
+ // Returns true if this attribute is a symbol reference attribute (i.e., a
+ // subclass of `SymbolRefAttr` or `FlatSymbolRefAttr`).
+ bool isSymbolRefAttr() const;
+
// Returns true if this attribute is an enum attribute (i.e., a subclass of
// `EnumAttrInfo`)
bool isEnumAttr() const;
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index f34d9c00b4388..3377ec98c2291 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -55,6 +55,13 @@ bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
+bool Attribute::isSymbolRefAttr() const {
+ StringRef defName = def->getName();
+ if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
+ return true;
+ return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
+}
+
bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
StringRef Attribute::getStorageType() const {
diff --git a/mlir/test/IR/test-side-effects.mlir b/mlir/test/IR/test-side-effects.mlir
index ca2e32c9a768e..db55414da03be 100644
--- a/mlir/test/IR/test-side-effects.mlir
+++ b/mlir/test/IR/test-side-effects.mlir
@@ -19,6 +19,11 @@
{effect="allocate", on_result, test_resource}
]} : () -> i32
+// expected-remark@+1 {{found an instance of 'read' on a symbol '@foo_ref', on resource '<Test>'}}
+"test.side_effect_op"() {effects = [
+ {effect="read", on_reference = @foo_ref, test_resource}
+]} : () -> i32
+
// No _memory_ effects, but a parametric test effect.
// expected-remark@+2 {{operation has no memory effects}}
// expected-remark@+1 {{found a parametric effect with affine_map<(d0, d1) -> (d1, d0)>}}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e71fceb9fa3da..e815adeece469 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -744,17 +744,18 @@ void SideEffectOp::getEffects(
.Case("read", MemoryEffects::Read::get())
.Case("write", MemoryEffects::Write::get());
- // Check for a result to affect.
- Value value;
- if (effectElement.get("on_result"))
- value = getResult();
-
// Check for a non-default resource to use.
SideEffects::Resource *resource = SideEffects::DefaultResource::get();
if (effectElement.get("test_resource"))
resource = TestResource::get();
- effects.emplace_back(effect, value, resource);
+ // Check for a result to affect.
+ if (effectElement.get("on_result"))
+ effects.emplace_back(effect, getResult(), resource);
+ else if (Attribute ref = effectElement.get("on_reference"))
+ effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
+ else
+ effects.emplace_back(effect, resource);
}
}
diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp
index d9d6aed592157..114c7f2430b13 100644
--- a/mlir/test/lib/IR/TestSideEffects.cpp
+++ b/mlir/test/lib/IR/TestSideEffects.cpp
@@ -43,6 +43,8 @@ struct SideEffectsPass
if (instance.getValue())
diag << " on a value,";
+ else if (SymbolRefAttr symbolRef = instance.getSymbolRef())
+ diag << " on a symbol '" << symbolRef << "',";
diag << " on resource '" << instance.getResource()->getName() << "'";
}
diff --git a/mlir/test/mlir-tblgen/op-side-effects.td b/mlir/test/mlir-tblgen/op-side-effects.td
index 6bae35aa763b0..9e97e904c7447 100644
--- a/mlir/test/mlir-tblgen/op-side-effects.td
+++ b/mlir/test/mlir-tblgen/op-side-effects.td
@@ -11,7 +11,12 @@ class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
def CustomResource : Resource<"CustomResource">;
def SideEffectOpA : TEST_Op<"side_effect_op_a"> {
- let arguments = (ins Arg<Variadic<AnyMemRef>, "", [MemRead]>);
+ let arguments = (ins
+ Arg<Variadic<AnyMemRef>, "", [MemRead]>,
+ Arg<SymbolRefAttr, "", [MemRead]>:$symbol,
+ Arg<FlatSymbolRefAttr, "", [MemWrite]>:$flat_symbol,
+ Arg<OptionalAttr<SymbolRefAttr>, "", [MemRead]>:$optional_symbol
+ );
let results = (outs Res<AnyMemRef, "", [MemAlloc<CustomResource>]>);
}
@@ -21,6 +26,10 @@ def SideEffectOpB : TEST_Op<"side_effect_op_b",
// CHECK: void SideEffectOpA::getEffects
// CHECK: for (::mlir::Value value : getODSOperands(0))
// CHECK: effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get());
+// CHECK: effects.emplace_back(MemoryEffects::Read::get(), symbol(), ::mlir::SideEffects::DefaultResource::get());
+// CHECK: effects.emplace_back(MemoryEffects::Write::get(), flat_symbol(), ::mlir::SideEffects::DefaultResource::get());
+// CHECK: if (auto symbolRef = optional_symbolAttr())
+// CHECK: effects.emplace_back(MemoryEffects::Read::get(), symbolRef, ::mlir::SideEffects::DefaultResource::get());
// CHECK: for (::mlir::Value value : getODSResults(0))
// CHECK: effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get());
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 737c36f07e464..65ae32f4f5a42 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1627,12 +1627,12 @@ void OpEmitter::genOpInterfaceMethods() {
}
void OpEmitter::genSideEffectInterfaceMethods() {
- enum EffectKind { Operand, Result, Static };
+ enum EffectKind { Operand, Result, Symbol, Static };
struct EffectLocation {
/// The effect applied.
SideEffect effect;
- /// The index if the kind is either operand or result.
+ /// The index if the kind is not static.
unsigned index : 30;
/// The kind of the location.
@@ -1661,17 +1661,29 @@ void OpEmitter::genSideEffectInterfaceMethods() {
effects.push_back(EffectLocation{cast<SideEffect>(decorator),
/*index=*/0, EffectKind::Static});
}
- /// Operands.
+ /// Attributes and Operands.
for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
- if (op.getArg(i).is<NamedTypeConstraint *>()) {
+ Argument arg = op.getArg(i);
+ if (arg.is<NamedTypeConstraint *>()) {
resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
++operandIt;
+ continue;
}
+ const NamedAttribute *attr = arg.get<NamedAttribute *>();
+ if (attr->attr.getBaseAttr().isSymbolRefAttr())
+ resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
}
/// Results.
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
+ // The code used to add an effect instance.
+ // {0}: The effect class.
+ // {1}: Optional value or symbol reference.
+ // {2}: The resource class.
+ const char *addEffectCode =
+ " effects.emplace_back({0}::get(), {1}{2}::get());\n";
+
for (auto &it : interfaceEffects) {
// Generate the 'getEffects' method.
std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
@@ -1684,19 +1696,30 @@ void OpEmitter::genSideEffectInterfaceMethods() {
// Add effect instances for each of the locations marked on the operation.
for (auto &location : it.second) {
- if (location.kind != EffectKind::Static) {
+ StringRef effect = location.effect.getName();
+ StringRef resource = location.effect.getResource();
+ if (location.kind == EffectKind::Static) {
+ // A static instance has no attached value.
+ body << llvm::formatv(addEffectCode, effect, "", resource).str();
+ } else if (location.kind == EffectKind::Symbol) {
+ // A symbol reference requires adding the proper attribute.
+ const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
+ if (attr->attr.isOptional()) {
+ body << " if (auto symbolRef = " << attr->name << "Attr())\n "
+ << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
+ .str();
+ } else {
+ body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
+ resource)
+ .str();
+ }
+ } else {
+ // Otherwise this is an operand/result, so we need to attach the Value.
body << " for (::mlir::Value value : getODS"
<< (location.kind == EffectKind::Operand ? "Operands" : "Results")
- << "(" << location.index << "))\n ";
+ << "(" << location.index << "))\n "
+ << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
}
-
- body << " effects.emplace_back(" << location.effect.getName()
- << "::get()";
-
- // If the effect isn't static, it has a specific value attached to it.
- if (location.kind != EffectKind::Static)
- body << ", value";
- body << ", " << location.effect.getResource() << "::get());\n";
}
}
}
More information about the Mlir-commits
mailing list