[Mlir-commits] [mlir] [MLIR][DLTI] Make queries visit all ancestors/respect nested scopes (PR #115043)

Mehdi Amini llvmlistbot at llvm.org
Tue Nov 5 15:48:10 PST 2024


================
@@ -489,77 +485,77 @@ void TargetSystemSpecAttr::print(AsmPrinter &printer) const {
 // DLTIDialect
 //===----------------------------------------------------------------------===//
 
-/// Retrieve the first `DLTIQueryInterface`-implementing attribute that is
-/// attached to `op` or such an attr on as close as possible an ancestor. The
-/// op the attribute is attached to is returned as well.
-static std::pair<DLTIQueryInterface, Operation *>
-getClosestQueryable(Operation *op) {
-  DLTIQueryInterface queryable = {};
-
-  // Search op and its ancestors for the first attached DLTIQueryInterface attr.
-  do {
-    for (NamedAttribute attr : op->getAttrs())
-      if ((queryable = dyn_cast<DLTIQueryInterface>(attr.getValue())))
-        break;
-  } while (!queryable && (op = op->getParentOp()));
-
-  return std::pair(queryable, op);
-}
-
 FailureOr<Attribute>
 dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
+  InFlightDiagnostic diag = op->emitError() << "target op of failed DLTI query";
+
   if (keys.empty()) {
-    if (emitError) {
-      auto diag = op->emitError() << "target op of failed DLTI query";
+    if (emitError)
       diag.attachNote(op->getLoc()) << "no keys provided to attempt query with";
-    }
+    else
+      diag.abandon();
     return failure();
   }
 
-  auto [queryable, queryOp] = getClosestQueryable(op);
-  Operation *reportOp = (queryOp ? queryOp : op);
-
-  if (!queryable) {
-    if (emitError) {
-      auto diag = op->emitError() << "target op of failed DLTI query";
-      diag.attachNote(reportOp->getLoc())
-          << "no DLTI-queryable attrs on target op or any of its ancestors";
-    }
-    return failure();
-  }
-
-  Attribute currentAttr = queryable;
-  for (auto &&[idx, key] : llvm::enumerate(keys)) {
-    if (auto map = dyn_cast<DLTIQueryInterface>(currentAttr)) {
-      auto maybeAttr = map.query(key);
-      if (failed(maybeAttr)) {
-        if (emitError) {
-          auto diag = op->emitError() << "target op of failed DLTI query";
-          diag.attachNote(reportOp->getLoc())
-              << "key " << keyToStr(key)
-              << " has no DLTI-mapping per attr: " << map;
+  auto interleaveComma = [](ArrayRef<DataLayoutEntryKey> keys) {
+    std::string buf;
+    llvm::interleave(
+        keys, [&](auto key) { buf += keyToStr(key); }, [&]() { buf += ","; });
+    return buf;
+  };
+
+  // Recursively replace `currentAttr` by the attribute obtained by querying a
+  // new key on each new `currentAttr` until all `keys` have been exhausted -
+  // `atOp` is only used for error reporting.
+  auto queryKeysOnAttribute = [&](Attribute currentAttr,
+                                  Operation *atOp) -> FailureOr<Attribute> {
+    for (auto &&[idx, key] : llvm::enumerate(keys)) {
+      if (auto map = dyn_cast<DLTIQueryInterface>(currentAttr)) {
+        auto maybeAttr = map.query(key);
+        if (failed(maybeAttr)) {
+          if (emitError)
+            diag.attachNote(atOp->getLoc())
+                << "key not present - failed at keys: ["
+                << interleaveComma(keys.take_front(idx + 1)) << "]";
+          return failure();
         }
+        currentAttr = *maybeAttr;
+      } else {
+        // The previous key, if any, is responsible for the current currentAttr.
+        if (idx > 0 && emitError)
+          diag.attachNote(atOp->getLoc())
+              << "attribute at keys [" << interleaveComma(keys.take_front(idx))
+              << "] is not queryable";
         return failure();
       }
-      currentAttr = *maybeAttr;
-    } else {
-      if (emitError) {
-        std::string commaSeparatedKeys;
-        llvm::interleave(
-            keys.take_front(idx), // All prior keys.
-            [&](auto key) { commaSeparatedKeys += keyToStr(key); },
-            [&]() { commaSeparatedKeys += ","; });
-
-        auto diag = op->emitError() << "target op of failed DLTI query";
-        diag.attachNote(reportOp->getLoc())
-            << "got non-DLTI-queryable attribute upon looking up keys ["
-            << commaSeparatedKeys << "] at op";
-      }
-      return failure();
     }
+    return currentAttr;
+  };
+
+  // Run over all ancestors of `op`, starting the recursive attribute query for
+  // each ancestor which has an attribute on which we can perform a query.
+  for (Operation *ancestor = op; ancestor; ancestor = ancestor->getParentOp()) {
+    DLTIQueryInterface queryableAttr;
+    // NB: only the op's first DLTI attr will be inspected
+    for (NamedAttribute attr : ancestor->getAttrs())
+      if (auto queryableAttr = dyn_cast<DLTIQueryInterface>(attr.getValue())) {
+        auto maybeAttr = queryKeysOnAttribute(queryableAttr, ancestor);
+        if (succeeded(maybeAttr)) {
+          diag.abandon();
+          return maybeAttr;
+        }
+      }
+  }
+
+  if (emitError) {
+    if (diag.getUnderlyingDiagnostic()->getNotes().empty())
+      diag.attachNote(op->getLoc())
+          << "no DLTI-queryable attrs on target op or any of its ancestors";
+  } else {
+    diag.abandon();
   }
 
-  return currentAttr;
+  return failure();
 }
----------------
joker-eph wrote:

The overall idea seems appropriate to me.

In terms of implementation, this kind of traversal can become costly (it is a source of quadratic behavior in the compiler). Can we look into caching the results here with a side data structure, similar to SymbolTable for example?

https://github.com/llvm/llvm-project/pull/115043


More information about the Mlir-commits mailing list