[Mlir-commits] [mlir] [MLIR][DLTI] Make queries visit all ancestors/respect nested scopes (PR #115043)
Mehdi Amini
llvmlistbot at llvm.org
Tue Nov 5 15:48:10 PST 2024
================
@@ -489,77 +485,77 @@ void TargetSystemSpecAttr::print(AsmPrinter &printer) const {
// DLTIDialect
//===----------------------------------------------------------------------===//
-/// Retrieve the first `DLTIQueryInterface`-implementing attribute that is
-/// attached to `op` or such an attr on as close as possible an ancestor. The
-/// op the attribute is attached to is returned as well.
-static std::pair<DLTIQueryInterface, Operation *>
-getClosestQueryable(Operation *op) {
- DLTIQueryInterface queryable = {};
-
- // Search op and its ancestors for the first attached DLTIQueryInterface attr.
- do {
- for (NamedAttribute attr : op->getAttrs())
- if ((queryable = dyn_cast<DLTIQueryInterface>(attr.getValue())))
- break;
- } while (!queryable && (op = op->getParentOp()));
-
- return std::pair(queryable, op);
-}
-
FailureOr<Attribute>
dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
+ InFlightDiagnostic diag = op->emitError() << "target op of failed DLTI query";
+
if (keys.empty()) {
- if (emitError) {
- auto diag = op->emitError() << "target op of failed DLTI query";
+ if (emitError)
diag.attachNote(op->getLoc()) << "no keys provided to attempt query with";
- }
+ else
+ diag.abandon();
return failure();
}
- auto [queryable, queryOp] = getClosestQueryable(op);
- Operation *reportOp = (queryOp ? queryOp : op);
-
- if (!queryable) {
- if (emitError) {
- auto diag = op->emitError() << "target op of failed DLTI query";
- diag.attachNote(reportOp->getLoc())
- << "no DLTI-queryable attrs on target op or any of its ancestors";
- }
- return failure();
- }
-
- Attribute currentAttr = queryable;
- for (auto &&[idx, key] : llvm::enumerate(keys)) {
- if (auto map = dyn_cast<DLTIQueryInterface>(currentAttr)) {
- auto maybeAttr = map.query(key);
- if (failed(maybeAttr)) {
- if (emitError) {
- auto diag = op->emitError() << "target op of failed DLTI query";
- diag.attachNote(reportOp->getLoc())
- << "key " << keyToStr(key)
- << " has no DLTI-mapping per attr: " << map;
+ auto interleaveComma = [](ArrayRef<DataLayoutEntryKey> keys) {
+ std::string buf;
+ llvm::interleave(
+ keys, [&](auto key) { buf += keyToStr(key); }, [&]() { buf += ","; });
+ return buf;
+ };
+
+ // Recursively replace `currentAttr` by the attribute obtained by querying a
+ // new key on each new `currentAttr` until all `keys` have been exhausted -
+ // `atOp` is only used for error reporting.
+ auto queryKeysOnAttribute = [&](Attribute currentAttr,
+ Operation *atOp) -> FailureOr<Attribute> {
+ for (auto &&[idx, key] : llvm::enumerate(keys)) {
+ if (auto map = dyn_cast<DLTIQueryInterface>(currentAttr)) {
+ auto maybeAttr = map.query(key);
+ if (failed(maybeAttr)) {
+ if (emitError)
+ diag.attachNote(atOp->getLoc())
+ << "key not present - failed at keys: ["
+ << interleaveComma(keys.take_front(idx + 1)) << "]";
+ return failure();
}
+ currentAttr = *maybeAttr;
+ } else {
+ // The previous key, if any, is responsible for the current currentAttr.
+ if (idx > 0 && emitError)
+ diag.attachNote(atOp->getLoc())
+ << "attribute at keys [" << interleaveComma(keys.take_front(idx))
+ << "] is not queryable";
return failure();
}
- currentAttr = *maybeAttr;
- } else {
- if (emitError) {
- std::string commaSeparatedKeys;
- llvm::interleave(
- keys.take_front(idx), // All prior keys.
- [&](auto key) { commaSeparatedKeys += keyToStr(key); },
- [&]() { commaSeparatedKeys += ","; });
-
- auto diag = op->emitError() << "target op of failed DLTI query";
- diag.attachNote(reportOp->getLoc())
- << "got non-DLTI-queryable attribute upon looking up keys ["
- << commaSeparatedKeys << "] at op";
- }
- return failure();
}
+ return currentAttr;
+ };
+
+ // Run over all ancestors of `op`, starting the recursive attribute query for
+ // each ancestor which has an attribute on which we can perform a query.
+ for (Operation *ancestor = op; ancestor; ancestor = ancestor->getParentOp()) {
+ DLTIQueryInterface queryableAttr;
+ // NB: only the op's first DLTI attr will be inspected
+ for (NamedAttribute attr : ancestor->getAttrs())
+ if (auto queryableAttr = dyn_cast<DLTIQueryInterface>(attr.getValue())) {
+ auto maybeAttr = queryKeysOnAttribute(queryableAttr, ancestor);
+ if (succeeded(maybeAttr)) {
+ diag.abandon();
+ return maybeAttr;
+ }
+ }
+ }
+
+ if (emitError) {
+ if (diag.getUnderlyingDiagnostic()->getNotes().empty())
+ diag.attachNote(op->getLoc())
+ << "no DLTI-queryable attrs on target op or any of its ancestors";
+ } else {
+ diag.abandon();
}
- return currentAttr;
+ return failure();
}
----------------
joker-eph wrote:
The overall idea seems appropriate to me.
In terms of implementation, this kind of traversal can become costly (it is a source of quadratic behavior in the compiler). Can we look into caching here by implementing a side data structure, similar to the SymbolTable for example?
https://github.com/llvm/llvm-project/pull/115043
More information about the Mlir-commits
mailing list