[Mlir-commits] [mlir] [mlir][Transforms] Merge 1:1 and 1:N type converters (PR #113032)

Markus Böck llvmlistbot at llvm.org
Wed Oct 23 10:16:02 PDT 2024


================
@@ -409,32 +419,51 @@ class TypeConverter {
   /// callback.
   ///
   /// With callback of form:
-  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+  /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
   template <typename T, typename FnT>
   std::enable_if_t<
       std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
       TargetMaterializationCallbackFn>
   wrapTargetMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc, Type originalType) -> Value {
-      if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc, originalType);
-      return Value();
+               OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+               Location loc, Type originalType) -> SmallVector<Value> {
+      SmallVector<Value> result;
+      if constexpr (std::is_same<T, TypeRange>::value) {
+        // This is a 1:N target materialization. Return the produced values
+        // directly.
+        result = callback(builder, resultTypes, inputs, loc, originalType);
+      } else if constexpr (std::is_assignable<Type, T>::value) {
+        // This is a 1:1 target materialization. Invoke it only if the result
+        // type class of the callback matches the requested result type.
+        if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+          // 1:1 materializations produce single values, but we store 1:N
+          // target materialization functions in the type converter. Wrap the
+          // result value in a SmallVector<Value>.
+          Value val = callback(builder, derivedType, inputs, loc, originalType);
+          if (val)
+            result.push_back(val);
+        }
+      } else {
+        static_assert(false, "T must be a Type or a TypeRange");
----------------
zero9178 wrote:

```suggestion
        static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
```
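This needs to be a type-dependent expression; otherwise the assertion fails immediately when the template is parsed, even though this `else` branch is never instantiated.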
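For illustration, here is a minimal standalone sketch of the dispatch pattern in this hunk: an `if constexpr` chain that forwards 1:N callbacks unchanged, wraps 1:1 callbacks so they return a vector, and uses the suggested type-dependent `static_assert` in the final branch. `Type`, `TypeRange`, `Value`, `MaterializeFn` and `wrap` are hypothetical stand-ins, not the MLIR classes or the PR's code. The nullable `mlir::Value` is modeled with `std::optional`, and the `dyn_cast` check on the result type class is omitted.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical stand-ins (NOT mlir::Type / mlir::TypeRange / mlir::Value).
struct Type { std::string name; };
using TypeRange = std::vector<Type>;
struct Value { std::string typeName; };

// Every stored callback has the 1:N shape (assumption mirroring the diff).
using MaterializeFn = std::function<std::vector<Value>(const TypeRange &)>;

// Hypothetical helper: adapt a 1:1 or 1:N callback to the 1:N shape.
template <typename T, typename FnT>
MaterializeFn wrap(FnT callback) {
  return [callback](const TypeRange &resultTypes) -> std::vector<Value> {
    std::vector<Value> result;
    if constexpr (std::is_same<T, TypeRange>::value) {
      // 1:N callback: forward the produced values unchanged.
      result = callback(resultTypes);
    } else if constexpr (std::is_assignable<Type &, T>::value) {
      // 1:1 callback: call it with the single requested type (assumes
      // resultTypes is non-empty) and wrap a successful result in a vector.
      if (std::optional<Value> v = callback(resultTypes.front()))
        result.push_back(*v);
    } else {
      // Depends on T, so it is only checked if this branch is instantiated.
      static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
    }
    return result;
  };
}
```

Replacing `sizeof(T) == 0` with a literal `false` makes the condition independent of `T`. Under the C++17 rules that LLVM builds with, such a `static_assert` can then be rejected even when only the first two branches are ever instantiated.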

https://github.com/llvm/llvm-project/pull/113032

