[Mlir-commits] [mlir] [mlir][func] Replace `ValueDecomposer` with target materialization (PR #114192)

llvmlistbot at llvm.org
Wed Oct 30 00:56:16 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-func

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround for the missing 1:N support in the dialect conversion. Since #113032, the dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations, so the `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually until the 1:1 and 1:N drivers are fully merged.)
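
For readers who want the control flow of the manual materialization step in isolation, here is a minimal standalone sketch of the fallback logic used by the new `decomposeValue` helper in the diff. It does not use MLIR: `ToyType`, `ToyValue`, `ToyConverter`, and `makeExampleConverter` are hypothetical stand-ins invented for illustration, and only the order of the checks mirrors the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Toy stand-ins for mlir::Type and mlir::Value (hypothetical, not MLIR API).
using ToyType = std::string;
struct ToyValue {
  std::string name;
  ToyType type;
};

// Hypothetical converter exposing the two hooks the helper relies on.
struct ToyConverter {
  // Fills `out` with the converted types; returns false if conversion fails.
  std::function<bool(const ToyType &, std::vector<ToyType> &)> convertType;
  // Returns replacement values of the requested types, or an empty vector if
  // no materialization could be built.
  std::function<std::vector<ToyValue>(const std::vector<ToyType> &,
                                      const ToyValue &)>
      materializeTarget;
};

// Same order of checks as the patch's helper: failed conversion, 1:0
// conversion, already-legal type, then the target materialization.
std::vector<ToyValue> decomposeValue(const ToyValue &value,
                                     const ToyConverter &converter) {
  assert(converter.convertType && converter.materializeTarget);
  std::vector<ToyType> converted;
  if (!converter.convertType(value.type, converted))
    return {value}; // Type cannot be converted: keep the value.
  if (converted.empty())
    return {}; // 1:0 conversion: the value is dropped.
  if (converted.size() == 1 && converted.front() == value.type)
    return {value}; // Type is already legal.
  std::vector<ToyValue> result = converter.materializeTarget(converted, value);
  if (result.empty())
    return {value}; // Materialization failed: keep the value.
  return result;
}

// Example rules (assumptions for this sketch): "pair" splits into i32 and
// f32, "unit" converts to nothing, "opaque" fails, everything else is legal.
ToyConverter makeExampleConverter() {
  ToyConverter c;
  c.convertType = [](const ToyType &t, std::vector<ToyType> &out) {
    if (t == "opaque")
      return false;
    if (t == "pair")
      out = {"i32", "f32"};
    else if (t != "unit")
      out = {t};
    return true;
  };
  c.materializeTarget = [](const std::vector<ToyType> &types,
                           const ToyValue &v) {
    std::vector<ToyValue> out;
    if (v.type != "pair")
      return out;
    for (std::size_t i = 0; i < types.size(); ++i)
      out.push_back({v.name + "." + std::to_string(i), types[i]});
    return out;
  };
  return c;
}
```

With this example converter, a `pair` value yields two values, a `unit` value yields none, and `opaque` or already-legal values are returned unchanged.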

---
Full diff: https://github.com/llvm/llvm-project/pull/114192.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h (+1-61) 
- (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+53-56) 
- (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+37-23) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
index 1d311b37b37a4f..1be406bf3adf92 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
@@ -23,70 +23,10 @@
 
 namespace mlir {
 
-/// This class provides a hook that expands one Value into multiple Value's,
-/// with a TypeConverter-inspired callback registration mechanism.
-///
-/// For folks that are familiar with the dialect conversion framework /
-/// TypeConverter, this is effectively the inverse of a source/argument
-/// materialization. A target materialization is not what we want here because
-/// it always produces a single Value, but in this case the whole point is to
-/// decompose a Value into multiple Value's.
-///
-/// The reason we need this inverse is easily understood by looking at what we
-/// need to do for decomposing types for a return op. When converting a return
-/// op, the dialect conversion framework will give the list of converted
-/// operands, and will ensure that each converted operand, even if it expanded
-/// into multiple types, is materialized as a single result. We then need to
-/// undo that materialization to a single result, which we do with the
-/// decomposeValue hooks registered on this object.
-///
-/// TODO: Eventually, the type conversion infra should have this hook built-in.
-/// See
-/// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2
-class ValueDecomposer {
-public:
-  /// This method tries to decompose a value of a certain type using provided
-  /// decompose callback functions. If it is unable to do so, the original value
-  /// is returned.
-  void decomposeValue(OpBuilder &, Location, Type, Value,
-                      SmallVectorImpl<Value> &);
-
-  /// This method registers a callback function that will be called to decompose
-  /// a value of a certain type into 0, 1, or multiple values.
-  template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<2>>
-  void addDecomposeValueConversion(FnT &&callback) {
-    decomposeValueConversions.emplace_back(
-        wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
-  }
-
-private:
-  using DecomposeValueConversionCallFn =
-      std::function<std::optional<LogicalResult>(
-          OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
-
-  /// Generate a wrapper for the given decompose value conversion callback.
-  template <typename T, typename FnT>
-  DecomposeValueConversionCallFn
-  wrapDecomposeValueConversionCallback(FnT &&callback) {
-    return
-        [callback = std::forward<FnT>(callback)](
-            OpBuilder &builder, Location loc, Type type, Value value,
-            SmallVectorImpl<Value> &newValues) -> std::optional<LogicalResult> {
-          if (T derivedType = dyn_cast<T>(type))
-            return callback(builder, loc, derivedType, value, newValues);
-          return std::nullopt;
-        };
-  }
-
-  SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
-};
-
 /// Populates the patterns needed to drive the conversion process for
-/// decomposing call graph types with the given `ValueDecomposer`.
+/// decomposing call graph types with the given `TypeConverter`.
 void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
                                              const TypeConverter &typeConverter,
-                                             ValueDecomposer &decomposer,
                                              RewritePatternSet &patterns);
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..8800ffd0be96dc 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -14,52 +14,46 @@ using namespace mlir;
 using namespace mlir::func;
 
 //===----------------------------------------------------------------------===//
-// ValueDecomposer
+// Helper functions
 //===----------------------------------------------------------------------===//
 
-void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
-                                     Type type, Value value,
-                                     SmallVectorImpl<Value> &results) {
-  for (auto &conversion : decomposeValueConversions)
-    if (conversion(builder, loc, type, value, results))
-      return;
-  results.push_back(value);
+/// If the given value can be decomposed with the type converter, decompose it.
+/// Otherwise, return the given value.
+static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
+                                         Value value,
+                                         const TypeConverter *converter) {
+  // Try to convert the given value's type. If that fails, just return the
+  // given value.
+  SmallVector<Type> convertedTypes;
+  if (failed(converter->convertType(value.getType(), convertedTypes)))
+    return {value};
+  if (convertedTypes.empty())
+    return {};
+
+  // If the given value's type is already legal, just return the given value.
+  TypeRange convertedTypeRange(convertedTypes);
+  if (convertedTypeRange == TypeRange(value.getType()))
+    return {value};
+
+  // Try to materialize a target conversion. If the materialization did not
+  // produce values of the requested type, the materialization failed. Just
+  // return the given value in that case.
+  SmallVector<Value> result = converter->materializeTargetConversion(
+      builder, loc, convertedTypeRange, value);
+  if (result.empty())
+    return {value};
+  return result;
 }
 
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesOpConversionPattern
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Base OpConversionPattern class to make a ValueDecomposer available to
-/// inherited patterns.
-template <typename SourceOp>
-class DecomposeCallGraphTypesOpConversionPattern
-    : public OpConversionPattern<SourceOp> {
-public:
-  DecomposeCallGraphTypesOpConversionPattern(const TypeConverter &typeConverter,
-                                             MLIRContext *context,
-                                             ValueDecomposer &decomposer,
-                                             PatternBenefit benefit = 1)
-      : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
-        decomposer(decomposer) {}
-
-protected:
-  ValueDecomposer &decomposer;
-};
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // DecomposeCallGraphTypesForFuncArgs
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Expand function arguments according to the provided TypeConverter and
-/// ValueDecomposer.
+/// Expand function arguments according to the provided TypeConverter.
 struct DecomposeCallGraphTypesForFuncArgs
-    : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
-  using DecomposeCallGraphTypesOpConversionPattern::
-      DecomposeCallGraphTypesOpConversionPattern;
+    : public OpConversionPattern<func::FuncOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +94,22 @@ struct DecomposeCallGraphTypesForFuncArgs
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Expand return operands according to the provided TypeConverter and
-/// ValueDecomposer.
+/// Expand return operands according to the provided TypeConverter.
 struct DecomposeCallGraphTypesForReturnOp
-    : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
-  using DecomposeCallGraphTypesOpConversionPattern::
-      DecomposeCallGraphTypesOpConversionPattern;
+    : public OpConversionPattern<ReturnOp> {
+  using OpConversionPattern::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands())
-      decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
-                                operand, newOperands);
+    for (Value operand : adaptor.getOperands()) {
+      // TODO: We can directly take the values from the adaptor once this is a
+      // 1:N conversion pattern.
+      llvm::append_range(newOperands,
+                         decomposeValue(rewriter, operand.getLoc(), operand,
+                                        getTypeConverter()));
+    }
     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
     return success();
   }
@@ -124,12 +121,9 @@ struct DecomposeCallGraphTypesForReturnOp
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Expand call op operands and results according to the provided TypeConverter
-/// and ValueDecomposer.
-struct DecomposeCallGraphTypesForCallOp
-    : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
-  using DecomposeCallGraphTypesOpConversionPattern::
-      DecomposeCallGraphTypesOpConversionPattern;
+/// Expand call op operands and results according to the provided TypeConverter.
+struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CallOp op, OpAdaptor adaptor,
@@ -137,9 +131,13 @@ struct DecomposeCallGraphTypesForCallOp
 
     // Create the operands list of the new `CallOp`.
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands())
-      decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
-                                operand, newOperands);
+    for (Value operand : adaptor.getOperands()) {
+      // TODO: We can directly take the values from the adaptor once this is a
+      // 1:N conversion pattern.
+      llvm::append_range(newOperands,
+                         decomposeValue(rewriter, operand.getLoc(), operand,
+                                        getTypeConverter()));
+    }
 
     // Create the new result types for the new `CallOp` and track the indices in
     // the new call op's results that correspond to the old call op's results.
@@ -189,9 +187,8 @@ struct DecomposeCallGraphTypesForCallOp
 
 void mlir::populateDecomposeCallGraphTypesPatterns(
     MLIRContext *context, const TypeConverter &typeConverter,
-    ValueDecomposer &decomposer, RewritePatternSet &patterns) {
+    RewritePatternSet &patterns) {
   patterns
       .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
-           DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
-                                               decomposer);
+           DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
 }
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 92216da9f201e6..de511c58ae6ee0 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -21,23 +21,40 @@ namespace {
 /// given tuple value. If some tuple elements are, in turn, tuples, the elements
 /// of those are extracted recursively such that the returned values have the
 /// same types as `resultTypes.getFlattenedTypes()`.
-static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
-                                         TupleType resultType, Value value,
-                                         SmallVectorImpl<Value> &values) {
-  for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
-    Type elementType = resultType.getType(i);
-    Value element = builder.create<test::GetTupleElementOp>(
-        loc, elementType, value, builder.getI32IntegerAttr(i));
-    if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
-      // Recurse if the current element is also a tuple.
-      if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
-                                     values)))
-        return failure();
-    } else {
-      values.push_back(element);
+static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder,
+                                              TypeRange resultTypes,
+                                              ValueRange inputs, Location loc) {
+  // Skip materialization if the single input value is not a tuple.
+  if (inputs.size() != 1)
+    return {};
+  Value tuple = inputs.front();
+  auto tupleType = dyn_cast<TupleType>(tuple.getType());
+  if (!tupleType)
+    return {};
+  // Skip materialization if the flattened types do not match the requested
+  // result types.
+  SmallVector<Type> flattenedTypes;
+  tupleType.getFlattenedTypes(flattenedTypes);
+  if (TypeRange(resultTypes) != TypeRange(flattenedTypes))
+    return {};
+  // Recursively decompose the tuple.
+  SmallVector<Value> result;
+  std::function<void(Value)> decompose = [&](Value tuple) {
+    auto tupleType = dyn_cast<TupleType>(tuple.getType());
+    if (!tupleType) {
+      // This is not a tuple.
+      result.push_back(tuple);
+      return;
     }
-  }
-  return success();
+    for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
+      Type elementType = tupleType.getType(i);
+      Value element = builder.create<test::GetTupleElementOp>(
+          loc, elementType, tuple, builder.getI32IntegerAttr(i));
+      decompose(element);
+    }
+  };
+  decompose(tuple);
+  return result;
 }
 
 /// Creates a `test.make_tuple` op out of the given inputs building a tuple of
@@ -82,8 +99,8 @@ static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
 
 /// A pass for testing call graph type decomposition.
 ///
-/// This instantiates the patterns with a TypeConverter and ValueDecomposer
-/// that splits tuple types into their respective element types.
+/// This instantiates the patterns with a TypeConverter that splits tuple types
+/// into their respective element types.
 /// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
 struct TestDecomposeCallGraphTypes
     : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
@@ -123,12 +140,9 @@ struct TestDecomposeCallGraphTypes
           return success();
         });
     typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+    typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
-    ValueDecomposer decomposer;
-    decomposer.addDecomposeValueConversion(buildDecomposeTuple);
-
-    populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
-                                            patterns);
+    populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
 
     if (failed(applyPartialConversion(module, target, std::move(patterns))))
       return signalPassFailure();

``````````

</details>
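
The test's updated `buildDecomposeTuple` callback only materializes when the flattened element types of the input tuple match the requested result types, and then extracts the elements recursively. The following standalone sketch models that behaviour without MLIR, assuming C++17. `ToyType`, `ToyValue`, `makeScalar`, `makeTuple`, and the `[i]` access-path strings are hypothetical illustrations, not part of the patch or of any MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy type tree (hypothetical): either a named scalar or a tuple of types.
struct ToyType {
  std::string name;              // Scalar name; empty for tuples.
  std::vector<ToyType> elements; // Element types; used only for tuples.
  bool isTuple;
};

ToyType makeScalar(const std::string &name) { return ToyType{name, {}, false}; }
ToyType makeTuple(const std::vector<ToyType> &elems) {
  return ToyType{"", elems, true};
}

// A value is modelled as an access-path string plus its type.
struct ToyValue {
  std::string expr;
  ToyType type;
};

// Collects the scalar leaf types in depth-first order.
void flattenTypes(const ToyType &t, std::vector<std::string> &out) {
  if (!t.isTuple) {
    out.push_back(t.name);
    return;
  }
  for (const ToyType &e : t.elements)
    flattenTypes(e, out);
}

// Recursively extracts tuple elements; non-tuple values are leaves.
static void decomposeInto(const ToyValue &v, std::vector<ToyValue> &out) {
  if (!v.type.isTuple) {
    out.push_back(v);
    return;
  }
  for (std::size_t i = 0; i < v.type.elements.size(); ++i)
    decomposeInto({v.expr + "[" + std::to_string(i) + "]", v.type.elements[i]},
                  out);
}

// Returns an empty vector (meaning "skip materialization") unless the input
// is a tuple whose flattened types equal the requested types.
std::vector<ToyValue> decomposeTuple(const ToyValue &input,
                                     const std::vector<std::string> &requested) {
  if (!input.type.isTuple)
    return {};
  std::vector<std::string> flat;
  flattenTypes(input.type, flat);
  if (flat != requested)
    return {};
  std::vector<ToyValue> result;
  decomposeInto(input, result);
  assert(result.size() == flat.size());
  return result;
}
```

In this model, decomposing a value `%t` of type `tuple<i1, tuple<i32, f32>>` with requested types `i1, i32, f32` produces the paths `%t[0]`, `%t[1][0]`, and `%t[1][1]`, while a mismatched request or a non-tuple input produces no values.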


https://github.com/llvm/llvm-project/pull/114192

