[llvm] 706c9c5 - [mlir] Add support for walking locations similarly to Operations

River Riddle via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 15 16:16:21 PDT 2021


Author: River Riddle
Date: 2021-04-15T16:09:34-07:00
New Revision: 706c9c5ce0382644d4e693741f5d885be7c20e46

URL: https://github.com/llvm/llvm-project/commit/706c9c5ce0382644d4e693741f5d885be7c20e46
DIFF: https://github.com/llvm/llvm-project/commit/706c9c5ce0382644d4e693741f5d885be7c20e46.diff

LOG: [mlir] Add support for walking locations similarly to Operations

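This allows for walking all nested locations of a given location, and is generally useful when processing locations. As part of this change, TypeSwitch gains a `Default` overload that takes the default result value directly, and `getFileLineColLoc` in Diagnostics.cpp is reimplemented on top of the new `LocationAttr::walk`.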

Differential Revision: https://reviews.llvm.org/D100437

Added: 
    

Modified: 
    llvm/include/llvm/ADT/TypeSwitch.h
    llvm/unittests/ADT/TypeSwitchTest.cpp
    mlir/include/mlir/IR/Location.h
    mlir/lib/IR/Diagnostics.cpp
    mlir/lib/IR/Location.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h
index bfcb2064301d..815b9a40afaf 100644
--- a/llvm/include/llvm/ADT/TypeSwitch.h
+++ b/llvm/include/llvm/ADT/TypeSwitch.h
@@ -124,6 +124,12 @@ class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
       return std::move(*result);
     return defaultFn(this->value);
   }
+  /// As a default, return the given value.
+  LLVM_NODISCARD ResultT Default(ResultT defaultResult) {
+    if (result)
+      return std::move(*result);
+    return defaultResult;
+  }
 
   LLVM_NODISCARD
   operator ResultT() {

diff  --git a/llvm/unittests/ADT/TypeSwitchTest.cpp b/llvm/unittests/ADT/TypeSwitchTest.cpp
index fde423d011be..442ac1910a05 100644
--- a/llvm/unittests/ADT/TypeSwitchTest.cpp
+++ b/llvm/unittests/ADT/TypeSwitchTest.cpp
@@ -47,7 +47,7 @@ TEST(TypeSwitchTest, CasesResult) {
     return TypeSwitch<Base *, int>(&value)
         .Case<DerivedA, DerivedB, DerivedD>([](auto *) { return 0; })
         .Case([](DerivedC *) { return 1; })
-        .Default([](Base *) { return -1; });
+        .Default(-1);
   };
   EXPECT_EQ(0, translate(DerivedA()));
   EXPECT_EQ(0, translate(DerivedB()));

diff  --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index f7c1da622215..b625a4f0ed0e 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -20,6 +20,8 @@
 namespace mlir {
 
 class Identifier;
+class Location;
+class WalkResult;
 
 //===----------------------------------------------------------------------===//
 // LocationAttr
@@ -31,6 +33,9 @@ class LocationAttr : public Attribute {
 public:
   using Attribute::Attribute;
 
+  /// Walk all of the locations nested under, and including, the current.
+  WalkResult walk(function_ref<WalkResult(Location)> walkFn);
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Attribute attr);
 };

diff  --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 15dbac1f6cb8..8b3c485573b7 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -366,21 +366,15 @@ struct SourceMgrDiagnosticHandlerImpl {
 
 /// Return a processable FileLineColLoc from the given location.
 static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
-  if (auto nameLoc = loc.dyn_cast<NameLoc>())
-    return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
-  if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
-    return fileLoc;
-  if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
-    return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
-  if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
-    for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
-      if (auto callLoc = getFileLineColLoc(subLoc)) {
-        return callLoc;
-      }
+  Optional<FileLineColLoc> firstFileLoc;
+  loc->walk([&](Location loc) {
+    if (FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>()) {
+      firstFileLoc = fileLoc;
+      return WalkResult::interrupt();
     }
-    return llvm::None;
-  }
-  return llvm::None;
+    return WalkResult::advance();
+  });
+  return firstFileLoc;
 }
 
 /// Return a processable CallSiteLoc from the given location.

diff  --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index cf730199e693..6a39efa0cc39 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -9,7 +9,9 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Identifier.h"
+#include "mlir/IR/Visitors.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -36,6 +38,31 @@ void BuiltinDialect::registerLocationAttributes() {
 // LocationAttr
 //===----------------------------------------------------------------------===//
 
+WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
+  if (walkFn(*this).wasInterrupted())
+    return WalkResult::interrupt();
+
+  return TypeSwitch<LocationAttr, WalkResult>(*this)
+      .Case([&](CallSiteLoc callLoc) -> WalkResult {
+        if (callLoc.getCallee()->walk(walkFn).wasInterrupted())
+          return WalkResult::interrupt();
+        return callLoc.getCaller()->walk(walkFn);
+      })
+      .Case([&](FusedLoc fusedLoc) -> WalkResult {
+        for (Location subLoc : fusedLoc.getLocations())
+          if (subLoc->walk(walkFn).wasInterrupted())
+            return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .Case([&](NameLoc nameLoc) -> WalkResult {
+        return nameLoc.getChildLoc()->walk(walkFn);
+      })
+      .Case([&](OpaqueLoc opaqueLoc) -> WalkResult {
+        return opaqueLoc.getFallbackLocation()->walk(walkFn);
+      })
+      .Default(WalkResult::advance());
+}
+
 /// Methods for support type inquiry through isa, cast, and dyn_cast.
 bool LocationAttr::classof(Attribute attr) {
   return attr.isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,


        


More information about the llvm-commits mailing list