[Mlir-commits] [mlir] a782922 - [mlir][SubElementInterfaces] Prefer calling the derived get if possible

River Riddle llvmlistbot at llvm.org
Sat Nov 5 17:04:22 PDT 2022


Author: River Riddle
Date: 2022-11-05T16:35:25-07:00
New Revision: a782922708af4e80bc9eaba977704420b6c765d9

URL: https://github.com/llvm/llvm-project/commit/a782922708af4e80bc9eaba977704420b6c765d9
DIFF: https://github.com/llvm/llvm-project/commit/a782922708af4e80bc9eaba977704420b6c765d9.diff

LOG: [mlir][SubElementInterfaces] Prefer calling the derived get if possible

This allows better support for attributes/types that override the
default builders.

Added: 
    

Modified: 
    mlir/include/mlir/IR/SubElementInterfaces.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h
index ed387eb9a122..07d246aafbfa 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.h
+++ b/mlir/include/mlir/IR/SubElementInterfaces.h
@@ -220,6 +220,8 @@ template <typename T>
 struct is_tuple : public std::false_type {};
 template <typename... Ts>
 struct is_tuple<std::tuple<Ts...>> : public std::true_type {};
+template <typename T, typename... Ts> // Result of `T::get(Ts...)`, if valid.
+using has_get_method = decltype(T::get(std::declval<Ts>()...));
 
 /// This function provides the underlying implementation for the
 /// SubElementInterface walk method, using the key type of the derived
@@ -239,6 +241,23 @@ void walkImmediateSubElementsImpl(T derived,
   }
 }
 
+/// This function invokes the proper `get` method for a type `T` with the given
+/// values, trying `T::get` (without, then with, `ctx`) before `T::Base::get`.
+template <typename T, typename... Ts>
+T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
+  // Prefer a `T::get` that accepts the values as-is, then one taking `ctx`.
+  if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
+    (void)ctx; // Unused on this path.
+    return T::get(std::forward<Ts>(params)...);
+  } else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *,
+                                         Ts...>::value) {
+    return T::get(ctx, std::forward<Ts>(params)...);
+  } else {
+    // Otherwise, no `T::get` accepts these values; use `T::Base::get`.
+    return T::Base::get(ctx, std::forward<Ts>(params)...);
+  }
+}
+
 /// This function provides the underlying implementation for the
 /// SubElementInterface replace method, using the key type of the derived
 /// attribute/type to interact with the individual parameters.
@@ -260,12 +279,13 @@ T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
     if constexpr (is_tuple<decltype(key)>::value) {
       return std::apply(
           [&](auto &&...params) {
-            return T::Base::get(derived.getContext(),
-                                std::forward<decltype(params)>(params)...);
+            return constructSubElementReplacement<T>(
+                derived.getContext(),
+                std::forward<decltype(params)>(params)...);
           },
           newKey);
     } else {
-      return T::Base::get(derived.getContext(), newKey);
+      return constructSubElementReplacement<T>(derived.getContext(), newKey);
     }
   }
 }


        


More information about the Mlir-commits mailing list