[Mlir-commits] [mlir] a782922 - [mlir][SubElementInterfaces] Prefer calling the derived get if possible
River Riddle
llvmlistbot at llvm.org
Sat Nov 5 17:04:22 PDT 2022
Author: River Riddle
Date: 2022-11-05T16:35:25-07:00
New Revision: a782922708af4e80bc9eaba977704420b6c765d9
URL: https://github.com/llvm/llvm-project/commit/a782922708af4e80bc9eaba977704420b6c765d9
DIFF: https://github.com/llvm/llvm-project/commit/a782922708af4e80bc9eaba977704420b6c765d9.diff
LOG: [mlir][SubElementInterfaces] Prefer calling the derived get if possible
This better supports attributes/types that override the
default builders.
Added:
Modified:
mlir/include/mlir/IR/SubElementInterfaces.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h
index ed387eb9a122..07d246aafbfa 100644
--- a/mlir/include/mlir/IR/SubElementInterfaces.h
+++ b/mlir/include/mlir/IR/SubElementInterfaces.h
@@ -220,6 +220,8 @@ template <typename T>
struct is_tuple : public std::false_type {};
+/// Partial specialization: any `std::tuple<Ts...>` is a tuple.
template <typename... Ts>
struct is_tuple<std::tuple<Ts...>> : public std::true_type {};
+/// Detection alias naming the result type of `T::get(args...)` for arguments
+/// of types `Ts...`; it is ill-formed when no such call is valid. Meant to be
+/// probed with `llvm::is_detected<has_get_method, T, Ts...>`.
+template <typename T, typename... Ts>
+using has_get_method = decltype(T::get(std::declval<Ts>()...));
/// This function provides the underlying implementation for the
/// SubElementInterface walk method, using the key type of the derived
@@ -239,6 +241,23 @@ void walkImmediateSubElementsImpl(T derived,
}
}
+/// This function invokes the proper `get` method for a type `T` with the given
+/// values. Candidates are tried at compile time in this order:
+///   1. `T::get(params...)`      -- a derived `get` that takes no context;
+///   2. `T::get(ctx, params...)` -- a derived `get` taking the context first;
+///   3. `T::Base::get(ctx, params...)` -- the base class's builder.
+/// Preferring a derived `get` lets attributes/types that override the default
+/// builders be rebuilt through their own builder. Whichever `get` is selected
+/// must return a value convertible to `T`.
+template <typename T, typename... Ts>
+T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
+ // Prefer a direct `get` method if one exists.
+ if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
+ // `ctx` is unused on this path; silence the unused-parameter warning.
+ (void)ctx;
+ return T::get(std::forward<Ts>(params)...);
+ } else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *,
+ Ts...>::value) {
+ // A derived `get` that takes the context as its first argument.
+ return T::get(ctx, std::forward<Ts>(params)...);
+ } else {
+ // Otherwise, pass to the base get.
+ return T::Base::get(ctx, std::forward<Ts>(params)...);
+ }
+}
+
/// This function provides the underlying implementation for the
/// SubElementInterface replace method, using the key type of the derived
/// attribute/type to interact with the individual parameters.
@@ -260,12 +279,13 @@ T replaceImmediateSubElementsImpl(T derived, ArrayRef<Attribute> &replAttrs,
if constexpr (is_tuple<decltype(key)>::value) {
return std::apply(
[&](auto &&...params) {
- return T::Base::get(derived.getContext(),
- std::forward<decltype(params)>(params)...);
+ return constructSubElementReplacement<T>(
+ derived.getContext(),
+ std::forward<decltype(params)>(params)...);
},
newKey);
} else {
- return T::Base::get(derived.getContext(), newKey);
+ return constructSubElementReplacement<T>(derived.getContext(), newKey);
}
}
}
More information about the Mlir-commits
mailing list