[Mlir-commits] [mlir] f4cd367 - [mlir][scf] refactor scf structuralOpConversion to better support 1:N type conversion

Peiming Liu llvmlistbot at llvm.org
Wed Nov 2 09:45:47 PDT 2022


Author: Peiming Liu
Date: 2022-11-02T16:45:39Z
New Revision: f4cd3674ea2e055a770fa993cbfd2356c38fc545

URL: https://github.com/llvm/llvm-project/commit/f4cd3674ea2e055a770fa993cbfd2356c38fc545
DIFF: https://github.com/llvm/llvm-project/commit/f4cd3674ea2e055a770fa993cbfd2356c38fc545.diff

LOG: [mlir][scf] refactor scf structuralOpConversion to better support 1:N type conversion

This patch moves the 1:N type mapping into its own classes to allow better code reuse in D137100.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D137099

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index c4c219617b782..a441b6c80b75b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -32,22 +32,82 @@ static void unpackUnrealizedConversionCast(Value v,
   unpacked.push_back(v);
 }
 
-class ConvertForOpTypes : public OpConversionPattern<ForOp> {
+// CRTP
+// A base class that takes care of 1:N type conversion, which maps the converted
+// op results (computed by the derived class) and materializes 1:N conversion.
+template <typename SourceOp, typename ConcretePattern>
+class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<SourceOp>::typeConverter;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+  using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+
+  //
+  // Derived classes should provide the following method which performs the
+  // actual conversion. It should return llvm::None upon conversion failure and
+  // return the converted operation upon success.
+  //
+  // Optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
+  //                                    ConversionPatternRewriter &rewriter,
+  //                                    TypeRange dstTypes) const;
+
   LogicalResult
-  matchAndRewrite(ForOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type> newResultTypes;
+    SmallVector<Type> dstTypes;
     SmallVector<unsigned> offsets;
     offsets.push_back(0);
     // Do the type conversion and record the offsets.
     for (Type type : op.getResultTypes()) {
-      if (failed(typeConverter->convertTypes(type, newResultTypes)))
-        return rewriter.notifyMatchFailure(op, "could not convert result");
-      offsets.push_back(newResultTypes.size());
+      if (failed(typeConverter->convertTypes(type, dstTypes)))
+        return rewriter.notifyMatchFailure(op, "could not convert result type");
+      offsets.push_back(dstTypes.size());
     }
 
+    // Calls the actual converter implementation to convert the operation.
+    Optional<SourceOp> newOp =
+        static_cast<const ConcretePattern *>(this)->convertSourceOp(
+            op, adaptor, rewriter, dstTypes);
+
+    if (!newOp)
+      return rewriter.notifyMatchFailure(op, "could not convert operation");
+
+    // Packs the return value.
+    SmallVector<Value> packedRets;
+    for (unsigned i = 1, e = offsets.size(); i < e; i++) {
+      unsigned start = offsets[i - 1], end = offsets[i];
+      unsigned len = end - start;
+      ValueRange mappedValue = newOp->getResults().slice(start, len);
+      if (len != 1) {
+        // 1 : N type conversion.
+        Type origType = op.getResultTypes()[i - 1];
+        Value mat = typeConverter->materializeSourceConversion(
+            rewriter, op.getLoc(), origType, mappedValue);
+        if (!mat) {
+          return rewriter.notifyMatchFailure(
+              op, "Failed to materialize 1:N type conversion");
+        }
+        packedRets.push_back(mat);
+      } else {
+        // 1 : 1 type conversion.
+        packedRets.push_back(mappedValue.front());
+      }
+    }
+
+    rewriter.replaceOp(op, packedRets);
+    return success();
+  }
+};
+
+class ConvertForOpTypes
+    : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
+public:
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+  // The callback required by CRTP.
+  Optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter,
+                                  TypeRange dstTypes) const {
     // Create a empty new op and inline the regions from the old op.
     //
     // This is a little bit tricky. We have two concerns here:
@@ -67,15 +127,15 @@ class ConvertForOpTypes : public OpConversionPattern<ForOp> {
 
     // convertRegionTypes already takes care of 1:N conversion.
     if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
-      return failure();
+      return llvm::None;
 
     // Unpacked the iteration arguments.
     SmallVector<Value> flatArgs;
     for (Value arg : adaptor.getInitArgs())
       unpackUnrealizedConversionCast(arg, flatArgs);
 
-    // We can not do clone as the number of result types after conversion might
-    // be different.
+    // We can not do clone as the number of result types after conversion
+    // might be different.
     ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
                                          adaptor.getUpperBound(),
                                          adaptor.getStep(), flatArgs);
@@ -89,29 +149,7 @@ class ConvertForOpTypes : public OpConversionPattern<ForOp> {
     rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                 newOp.getLoopBody().end());
 
-    // Pack the return value.
-    SmallVector<Value, 6> packedRets;
-    for (unsigned i = 1, e = offsets.size(); i < e; i++) {
-      unsigned start = offsets[i - 1], end = offsets[i];
-      unsigned len = end - start;
-      ValueRange mappedValue = newOp.getResults().slice(start, len);
-      if (len != 1) {
-        // 1 : N type conversion.
-        Type origType = op.getResultTypes()[i - 1];
-        Value mat = typeConverter->materializeSourceConversion(
-            rewriter, op.getLoc(), origType, mappedValue);
-        if (!mat)
-          return rewriter.notifyMatchFailure(
-              op, "Failed to materialize 1:N type conversion");
-        packedRets.push_back(mat);
-      } else {
-        // 1 : 1 type conversion.
-        packedRets.push_back(mappedValue.front());
-      }
-    }
-
-    rewriter.replaceOp(op, packedRets);
-    return success();
+    return newOp;
   }
 };
 } // namespace


        


More information about the Mlir-commits mailing list