[Mlir-commits] [mlir] 4e56361 - [mlir] Use GenericAdaptor to simplify 1:N type conversion API.

Ingo Müller llvmlistbot at llvm.org
Fri Mar 31 03:43:14 PDT 2023


Author: Ingo Müller
Date: 2023-03-31T10:43:09Z
New Revision: 4e563616a5fffa1204286a9aa03604a68a7db835

URL: https://github.com/llvm/llvm-project/commit/4e563616a5fffa1204286a9aa03604a68a7db835
DIFF: https://github.com/llvm/llvm-project/commit/4e563616a5fffa1204286a9aa03604a68a7db835.diff

LOG: [mlir] Use GenericAdaptor to simplify 1:N type conversion API.

For 1:N type conversion, there is a 1:N relationship between the
original operands and the converted operands. The same is true for the
results. The previous design passed an instance of a "mapping" class
into each pattern that helped with handling this 1:N correspondence.
However, this was still rather manual and, in particular, it required
the use of magic constants for the indices of the different operands.

This commit uses the GenericAdaptor class that is generated
for each op class in order to simplify this relationship further. The
GenericAdaptor allows wrapping around a list of arbitrary types for each
operand (via templating); for 1:N type conversion, this allows the
operand accessors of the adaptor class to return a ValueRange that
corresponds to the N values in the converted types. Patterns can thus
use the named accessors instead of magic constants, which eliminates a
common class of errors.

This commit further simplifies the API that patterns need to implement
by making the operand and result type mappings part of the adaptor.
Since many patterns only need one of the two (or even neither), this
reduces the number of unnecessary arguments in many cases.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D147225

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/OneToNTypeConversion.h
    mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
    mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
    mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index 5992fc4204935..80e61e7a817ca 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -215,22 +215,66 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
                             ArrayRef<StringRef> generatedNames = {})
       : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
                                 benefit, context, generatedNames) {}
+  /// Generic adaptor around the root op of this pattern using the converted
+  /// operands. Importantly, each operand is represented as a *range* of values,
+  /// namely the N values each original operand gets converted to. Concretely,
+  /// this makes the result type of the accessor functions of the adaptor class
+  /// be a `ValueRange`.
+  class OpAdaptor
+      : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
+  public:
+    using RangeT = ArrayRef<ValueRange>;
+    using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
+
+    OpAdaptor(const OneToNTypeMapping *operandMapping,
+              const OneToNTypeMapping *resultMapping,
+              const ValueRange *convertedOperands, RangeT values,
+              DictionaryAttr attrs = nullptr, RegionRange regions = {})
+        : BaseT(values, attrs, regions), operandMapping(operandMapping),
+          resultMapping(resultMapping), convertedOperands(convertedOperands) {}
+
+    /// Get the type mapping of the original operands to the converted operands.
+    const OneToNTypeMapping &getOperandMapping() const {
+      return *operandMapping;
+    }
+
+    /// Get the type mapping of the original results to the converted results.
+    const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
+
+    /// Get a flat range of all converted operands. Unlike `getOperands`, which
+    /// returns an `ArrayRef` with one `ValueRange` for each original operand,
+    /// this function returns a `ValueRange` that contains all converted
+    /// operands irrespectively of which operand they originated from.
+    ValueRange getFlatOperands() const { return *convertedOperands; }
+
+  private:
+    const OneToNTypeMapping *operandMapping;
+    const OneToNTypeMapping *resultMapping;
+    const ValueRange *convertedOperands;
+  };
 
   using OneToNConversionPattern::matchAndRewrite;
 
   /// Overload that derived classes have to override for their op type.
-  virtual LogicalResult matchAndRewrite(SourceOp op,
-                                        OneToNPatternRewriter &rewriter,
-                                        const OneToNTypeMapping &operandMapping,
-                                        const OneToNTypeMapping &resultMapping,
-                                        ValueRange convertedOperands) const = 0;
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const = 0;
 
   LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
                                 const OneToNTypeMapping &operandMapping,
                                 const OneToNTypeMapping &resultMapping,
                                 ValueRange convertedOperands) const final {
-    return matchAndRewrite(cast<SourceOp>(op), rewriter, operandMapping,
-                           resultMapping, convertedOperands);
+    // Wrap converted operands and type mappings into an adaptor.
+    SmallVector<ValueRange> valueRanges;
+    for (int64_t i = 0; i < op->getNumOperands(); i++) {
+      auto values = operandMapping.getConvertedValues(convertedOperands, i);
+      valueRanges.push_back(values);
+    }
+    OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
+                      valueRanges, op->getAttrDictionary(), op->getRegions());
+
+    // Call overload implemented by the derived class.
+    return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
   }
 };
 

diff  --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
index 5e8125ca94283..7005693241121 100644
--- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
@@ -27,21 +27,21 @@ class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
 public:
   using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
 
-  LogicalResult matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter,
-                                const OneToNTypeMapping &operandMapping,
-                                const OneToNTypeMapping &resultMapping,
-                                ValueRange convertedOperands) const override {
+  LogicalResult
+  matchAndRewrite(CallOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
 
     // Nothing to do if the op doesn't have any non-identity conversions for its
     // operands or results.
-    if (!operandMapping.hasNonIdentityConversion() &&
+    if (!adaptor.getOperandMapping().hasNonIdentityConversion() &&
         !resultMapping.hasNonIdentityConversion())
       return failure();
 
     // Create new CallOp.
     auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
-                                         convertedOperands);
+                                         adaptor.getFlatOperands());
     newOp->setAttrs(op->getAttrs());
 
     rewriter.replaceOp(op, newOp->getResults(), resultMapping);
@@ -54,10 +54,8 @@ class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
   using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(FuncOp op, OneToNPatternRewriter &rewriter,
-                  const OneToNTypeMapping & /*operandMapping*/,
-                  const OneToNTypeMapping & /*resultMapping*/,
-                  ValueRange /*convertedOperands*/) const override {
+  matchAndRewrite(FuncOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
     auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
 
     // Construct mapping for function arguments.
@@ -99,16 +97,16 @@ class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
 public:
   using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
 
-  LogicalResult matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter,
-                                const OneToNTypeMapping &operandMapping,
-                                const OneToNTypeMapping & /*resultMapping*/,
-                                ValueRange convertedOperands) const override {
+  LogicalResult
+  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
     // Nothing to do if there is no non-identity conversion.
-    if (!operandMapping.hasNonIdentityConversion())
+    if (!adaptor.getOperandMapping().hasNonIdentityConversion())
       return failure();
 
     // Convert operands.
-    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+    rewriter.updateRootInPlace(
+        op, [&] { op->setOperands(adaptor.getFlatOperands()); });
 
     return success();
   }

diff  --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
index 74207e6fbb647..9fd266bc44c68 100644
--- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
@@ -25,11 +25,10 @@ class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> {
   using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(IfOp op, OneToNPatternRewriter &rewriter,
-                  const OneToNTypeMapping & /*operandMapping*/,
-                  const OneToNTypeMapping &resultMapping,
-                  const ValueRange /*convertedOperands*/) const override {
+  matchAndRewrite(IfOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
 
     // Nothing to do if there is no non-identity conversion.
     if (!resultMapping.hasNonIdentityConversion())
@@ -62,12 +61,13 @@ class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> {
   using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(WhileOp op, OneToNPatternRewriter &rewriter,
-                  const OneToNTypeMapping &operandMapping,
-                  const OneToNTypeMapping &resultMapping,
-                  const ValueRange convertedOperands) const override {
+  matchAndRewrite(WhileOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
 
+    const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
     // Nothing to do if the op doesn't have any non-identity conversions for its
     // operands or results.
     if (!operandMapping.hasNonIdentityConversion() &&
@@ -77,8 +77,8 @@ class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> {
     // Create new WhileOp.
     TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
 
-    auto newOp =
-        rewriter.create<WhileOp>(loc, convertedResultTypes, convertedOperands);
+    auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes,
+                                          adaptor.getFlatOperands());
     newOp->setAttrs(op->getAttrs());
 
     // Update block signatures.
@@ -106,16 +106,15 @@ class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> {
   using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(YieldOp op, OneToNPatternRewriter &rewriter,
-                  const OneToNTypeMapping &operandMapping,
-                  const OneToNTypeMapping & /*resultMapping*/,
-                  const ValueRange convertedOperands) const override {
+  matchAndRewrite(YieldOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
     // Nothing to do if there is no non-identity conversion.
-    if (!operandMapping.hasNonIdentityConversion())
+    if (!adaptor.getOperandMapping().hasNonIdentityConversion())
       return failure();
 
     // Convert operands.
-    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+    rewriter.updateRootInPlace(
+        op, [&] { op->setOperands(adaptor.getFlatOperands()); });
 
     return success();
   }
@@ -127,16 +126,15 @@ class ConvertTypesInSCFConditionOp
   using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ConditionOp op, OneToNPatternRewriter &rewriter,
-                  const OneToNTypeMapping &operandMapping,
-                  const OneToNTypeMapping & /*resultMapping*/,
-                  const ValueRange convertedOperands) const override {
+  matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
     // Nothing to do if there is no non-identity conversion.
-    if (!operandMapping.hasNonIdentityConversion())
+    if (!adaptor.getOperandMapping().hasNonIdentityConversion())
       return failure();
 
     // Convert operands.
-    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+    rewriter.updateRootInPlace(
+        op, [&] { op->setOperands(adaptor.getFlatOperands()); });
 
     return success();
   }

diff  --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index c60c323a58d4f..c3f20989dbd6b 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -77,13 +77,12 @@ class ConvertMakeTupleOp
   using OneToNOpConversionPattern<
       ::test::MakeTupleOp>::OneToNOpConversionPattern;
 
-  LogicalResult matchAndRewrite(::test::MakeTupleOp op,
-                                OneToNPatternRewriter &rewriter,
-                                const OneToNTypeMapping &operandMapping,
-                                const OneToNTypeMapping &resultMapping,
-                                ValueRange convertedOperands) const override {
+  LogicalResult
+  matchAndRewrite(::test::MakeTupleOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
     // Simply replace the current op with the converted operands.
-    rewriter.replaceOp(op, convertedOperands, resultMapping);
+    rewriter.replaceOp(op, adaptor.getFlatOperands(),
+                       adaptor.getResultMapping());
     return success();
   }
 };
@@ -99,11 +98,9 @@ class ConvertGetTupleElementOp
   using OneToNOpConversionPattern<
       ::test::GetTupleElementOp>::OneToNOpConversionPattern;
 
-  LogicalResult matchAndRewrite(::test::GetTupleElementOp op,
-                                OneToNPatternRewriter &rewriter,
-                                const OneToNTypeMapping &operandMapping,
-                                const OneToNTypeMapping &resultMapping,
-                                ValueRange convertedOperands) const override {
+  LogicalResult
+  matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
     // Construct mapping for tuple element types.
     auto stateType = op->getOperand(0).getType().cast<TupleType>();
     TypeRange originalElementTypes = stateType.getTypes();
@@ -113,16 +110,17 @@ class ConvertGetTupleElementOp
       return failure();
 
     // Compute converted operands corresponding to original input tuple.
-    ValueRange convertedTuple =
-        operandMapping.getConvertedValues(convertedOperands, 0);
+    assert(adaptor.getOperands().size() == 1 &&
+           "expected 'get_tuple_element' to have one operand");
+    ValueRange convertedTuple = adaptor.getOperands()[0];
 
-    // Got those converted operands that correspond to the index-th element of
+    // Got those converted operands that correspond to the index-th element of
     // the original input tuple.
     size_t index = op.getIndex();
     ValueRange extractedElement =
         elementMapping.getConvertedValues(convertedTuple, index);
 
-    rewriter.replaceOp(op, extractedElement, resultMapping);
+    rewriter.replaceOp(op, extractedElement, adaptor.getResultMapping());
 
     return success();
   }


        


More information about the Mlir-commits mailing list