[Mlir-commits] [mlir] 00a52c7 - [mlir:SubElementsInterface] Add support for "skipping" when replacing attributes/types

River Riddle llvmlistbot at llvm.org
Thu Jul 28 10:52:39 PDT 2022


Author: River Riddle
Date: 2022-07-28T10:52:12-07:00
New Revision: 00a52c75655bb352f875729a93c3f2ae990e5b78

URL: https://github.com/llvm/llvm-project/commit/00a52c75655bb352f875729a93c3f2ae990e5b78
DIFF: https://github.com/llvm/llvm-project/commit/00a52c75655bb352f875729a93c3f2ae990e5b78.diff

LOG: [mlir:SubElementsInterface] Add support for "skipping" when replacing attributes/types

This is used to fix a bug in SymbolTable::replaceAllSymbolUses where we replaced symbol
references that we shouldn't have, such as an inner nested reference that merely shares the old symbol's name.

Differential Revision: https://reviews.llvm.org/D130693

Added: 
    

Modified: 
    mlir/include/mlir/IR/SubElementInterfaces.h
    mlir/include/mlir/IR/SubElementInterfaces.td
    mlir/lib/IR/SubElementInterfaces.cpp
    mlir/lib/IR/SymbolTable.cpp
    mlir/test/IR/test-symbol-rauw.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h
index d0fc0ba921708..2c40e4edfa0fa 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.h
+++ b/mlir/include/mlir/IR/SubElementInterfaces.h
@@ -16,6 +16,14 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Types.h"
+#include "mlir/IR/Visitors.h"
+
+namespace mlir {
+template <typename T>
+using SubElementReplFn = function_ref<T(T)>;
+template <typename T>
+using SubElementResultReplFn = function_ref<std::pair<T, WalkResult>(T)>;
+} // namespace mlir
 
 /// Include the definitions of the sub elemnt interfaces.
 #include "mlir/IR/SubElementAttrInterfaces.h.inc"

diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td
index f1aede639cf09..e857beb379de3 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.td
+++ b/mlir/include/mlir/IR/SubElementInterfaces.td
@@ -56,8 +56,22 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
     /// Recursively replace all of the nested sub-attributes and sub-types using the
     /// provided map functions. Returns nullptr in the case of failure.
     }] # attrOrType # [{ replaceSubElements(
-      llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn,
-      llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn
+      mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn,
+      mlir::SubElementReplFn<mlir::Type> replaceTypeFn
+    ) {
+      return replaceSubElements(
+        [&](Attribute attr) { return std::make_pair(replaceAttrFn(attr), WalkResult::advance()); },
+        [&](Type type) { return std::make_pair(replaceTypeFn(type), WalkResult::advance()); }
+      );
+    }
+    /// Recursively replace all of the nested sub-attributes and sub-types using the
+    /// provided map functions. This variant allows for the map function to return an
+    /// additional walk result. Returns nullptr in the case of failure.
+    }] # attrOrType # [{ replaceSubElements(
+      llvm::function_ref<
+        std::pair<mlir::Attribute, mlir::WalkResult>(mlir::Attribute)> replaceAttrFn,
+      llvm::function_ref<
+        std::pair<mlir::Type, mlir::WalkResult>(mlir::Type)> replaceTypeFn
     );
   }];
   code extraTraitClassDeclaration = [{
@@ -71,18 +85,16 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
     /// Recursively replace all of the nested sub-attributes and sub-types using the
     /// provided map functions. Returns nullptr in the case of failure.
     }] # attrOrType # [{ replaceSubElements(
-      llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn,
-      llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn) {
+      mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn,
+      mlir::SubElementReplFn<mlir::Type> replaceTypeFn) {
       }] # interfaceName # " interface(" # derivedValue # [{);
       return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
     }
-
-    /// Recursively replace all of the nested sub-attributes and sub-types using the
-    /// provided map functions. Returns nullptr in the case of failure.
-    }] # attrOrType # [{ replaceImmediateSubElements(
-      llvm::ArrayRef<mlir::Attribute> replAttrs,
-      llvm::function_ref<mlir::Type(mlir::Type)> replTypes) {
-      return nullptr;
+    }] # attrOrType # [{ replaceSubElements(
+      mlir::SubElementResultReplFn<mlir::Attribute> replaceAttrFn,
+      mlir::SubElementResultReplFn<mlir::Type> replaceTypeFn) {
+      }] # interfaceName # " interface(" # derivedValue # [{);
+      return interface.replaceSubElements(replaceAttrFn, replaceTypeFn);
     }
   }];
   code extraSharedClassDeclaration = [{
@@ -98,17 +110,31 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
     /// Recursively replace all of the nested sub-attributes using the provided
     /// map function. Returns nullptr in the case of failure.
     }] # attrOrType # [{ replaceSubElements(
-      llvm::function_ref<mlir::Attribute(mlir::Attribute)> replaceAttrFn) {
+      mlir::SubElementReplFn<mlir::Attribute> replaceAttrFn) {
       return replaceSubElements(
         replaceAttrFn, [](mlir::Type type) { return type; });
     }
+    }] # attrOrType # [{ replaceSubElements(
+      mlir::SubElementResultReplFn<mlir::Attribute> replaceAttrFn) {
+      return replaceSubElements(
+        replaceAttrFn,
+        [](mlir::Type type) { return std::make_pair(type, WalkResult::advance()); }
+      );
+    }
     /// Recursively replace all of the nested sub-types using the provided map
     /// function. Returns nullptr in the case of failure.
     }] # attrOrType # [{ replaceSubElements(
-      llvm::function_ref<mlir::Type(mlir::Type)> replaceTypeFn) {
+      mlir::SubElementReplFn<mlir::Type> replaceTypeFn) {
       return replaceSubElements(
         [](mlir::Attribute attr) { return attr; }, replaceTypeFn);
     }
+    }] # attrOrType # [{ replaceSubElements(
+      mlir::SubElementResultReplFn<mlir::Type> replaceTypeFn) {
+      return replaceSubElements(
+        [](mlir::Attribute attr) { return std::make_pair(attr, WalkResult::advance()); },
+        replaceTypeFn
+      );
+    }
   }];
 }
 

diff  --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp
index f8d47083f11c4..f8526dc6d3869 100644
--- a/mlir/lib/IR/SubElementInterfaces.cpp
+++ b/mlir/lib/IR/SubElementInterfaces.cpp
@@ -102,11 +102,10 @@ static bool isMutable(Type type) {
 }
 
 template <typename InterfaceT, typename T, typename ReplaceSubElementFnT>
-static void updateSubElementImpl(T element, function_ref<T(T)> walkFn,
-                                 DenseMap<T, T> &visited,
-                                 SmallVectorImpl<T> &newElements,
-                                 FailureOr<bool> &changed,
-                                 ReplaceSubElementFnT &&replaceSubElementFn) {
+static void updateSubElementImpl(
+    T element, function_ref<std::pair<T, WalkResult>(T)> walkFn,
+    DenseMap<T, T> &visited, SmallVectorImpl<T> &newElements,
+    FailureOr<bool> &changed, ReplaceSubElementFnT &&replaceSubElementFn) {
   // Bail early if we failed at any point.
   if (failed(changed))
     return;
@@ -120,17 +119,22 @@ static void updateSubElementImpl(T element, function_ref<T(T)> walkFn,
   // yet.
   T &mappedElement = visited[element];
   if (!mappedElement) {
+    WalkResult result = WalkResult::advance();
+    std::tie(mappedElement, result) = walkFn(element);
+
     // Try walking this element.
-    if (!(mappedElement = walkFn(element))) {
+    if (result.wasInterrupted() || !mappedElement) {
       changed = failure();
       return;
     }
 
     // Handle replacing sub-elements if this element is also a container.
-    if (auto interface = mappedElement.template dyn_cast<InterfaceT>()) {
-      if (!(mappedElement = replaceSubElementFn(interface))) {
-        changed = failure();
-        return;
+    if (!result.wasSkipped()) {
+      if (auto interface = mappedElement.template dyn_cast<InterfaceT>()) {
+        if (!(mappedElement = replaceSubElementFn(interface))) {
+          changed = failure();
+          return;
+        }
       }
     }
   }
@@ -145,8 +149,8 @@ static void updateSubElementImpl(T element, function_ref<T(T)> walkFn,
 template <typename InterfaceT>
 static typename InterfaceT::ValueType
 replaceSubElementsImpl(InterfaceT interface,
-                       function_ref<Attribute(Attribute)> walkAttrsFn,
-                       function_ref<Type(Type)> walkTypesFn,
+                       SubElementResultReplFn<Attribute> walkAttrsFn,
+                       SubElementResultReplFn<Type> walkTypesFn,
                        DenseMap<Attribute, Attribute> &visitedAttrs,
                        DenseMap<Type, Type> &visitedTypes) {
   // Walk the current sub-elements, replacing them as necessary.
@@ -186,8 +190,8 @@ replaceSubElementsImpl(InterfaceT interface,
 }
 
 Attribute SubElementAttrInterface::replaceSubElements(
-    function_ref<Attribute(Attribute)> replaceAttrFn,
-    function_ref<Type(Type)> replaceTypeFn) {
+    SubElementResultReplFn<Attribute> replaceAttrFn,
+    SubElementResultReplFn<Type> replaceTypeFn) {
   assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
   DenseMap<Attribute, Attribute> visitedAttrs;
   DenseMap<Type, Type> visitedTypes;
@@ -196,8 +200,8 @@ Attribute SubElementAttrInterface::replaceSubElements(
 }
 
 Type SubElementTypeInterface::replaceSubElements(
-    function_ref<Attribute(Attribute)> replaceAttrFn,
-    function_ref<Type(Type)> replaceTypeFn) {
+    SubElementResultReplFn<Attribute> replaceAttrFn,
+    SubElementResultReplFn<Type> replaceTypeFn) {
   assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions");
   DenseMap<Attribute, Attribute> visitedAttrs;
   DenseMap<Type, Type> visitedTypes;

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index fb56d91f68a6c..792bf42488372 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -853,23 +853,30 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
     SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
 
     auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
-      auto remapAttrFn = [&](Attribute attr) -> Attribute {
+      auto remapAttrFn =
+          [&](Attribute attr) -> std::pair<Attribute, WalkResult> {
+        // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
+        // want to accidentally replace an inner reference.
         if (attr == oldAttr)
-          return newAttr;
+          return {newAttr, WalkResult::skip()};
         // Handle prefix matches.
         if (SymbolRefAttr symRef = attr.dyn_cast<SymbolRefAttr>()) {
           if (isReferencePrefixOf(oldAttr, symRef)) {
             auto oldNestedRefs = oldAttr.getNestedReferences();
             auto nestedRefs = symRef.getNestedReferences();
             if (oldNestedRefs.empty())
-              return SymbolRefAttr::get(newSymbol, nestedRefs);
+              return {SymbolRefAttr::get(newSymbol, nestedRefs),
+                      WalkResult::skip()};
 
             auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
             newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
-            return SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs);
+            return {
+                SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs),
+                WalkResult::skip()};
           }
+          return {attr, WalkResult::skip()};
         }
-        return attr;
+        return {attr, WalkResult::advance()};
       };
       // Generate a new attribute dictionary by replacing references to the old
       // symbol.

diff  --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir
index b33651f1129d6..c7d48b6c4eb1d 100644
--- a/mlir/test/IR/test-symbol-rauw.mlir
+++ b/mlir/test/IR/test-symbol-rauw.mlir
@@ -94,3 +94,19 @@ module {
       } : () -> ()
     }
 }
+
+// -----
+
+module {
+  // CHECK: module @replaced_foo
+  module @foo attributes {sym.new_name = "replaced_foo" } {
+    // CHECK: func.func private @foo
+    func.func private @foo()
+  }
+
+  // CHECK: foo.op
+  // CHECK-SAME: use = @replaced_foo::@foo
+  "foo.op"() {
+    use = @foo::@foo
+  } : () -> ()
+}


        


More information about the Mlir-commits mailing list