[llvm-branch-commits] [mlir] 3e3e276 - [mlir][vector][NFC] Change UnrollVectorPattern to not be statically dependent on an op type

Thomas Raoux via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Dec 4 09:58:12 PST 2020


Author: Thomas Raoux
Date: 2020-12-04T09:53:01-08:00
New Revision: 3e3e276d22ca6917a721c4173b00b37850d8020c

URL: https://github.com/llvm/llvm-project/commit/3e3e276d22ca6917a721c4173b00b37850d8020c
DIFF: https://github.com/llvm/llvm-project/commit/3e3e276d22ca6917a721c4173b00b37850d8020c.diff

LOG: [mlir][vector][NFC] Change UnrollVectorPattern to not be statically dependent on an op type

Make UnrollVectorPattern inherit from RewritePattern instead of
OpRewritePattern so that we don't need to create one pattern per op type when
applying it to many different types of ops. Since we may want to apply the
pattern to all arithmetic ops, it is more convenient to filter dynamically
(via the `filterConstraint` option).

Also fix the spelling of `setFilterContraint` to `setFilterConstraint`.

Differential Revision: https://reviews.llvm.org/D92635

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index cc0b2841d3f1..c88aa7f5bc65 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -91,7 +91,7 @@ struct UnrollVectorOptions {
   /// Callback function that indicates whether vector unrolling should be
   /// attempted on the operation.
   FilterConstraintFnType filterConstraint = nullptr;
-  UnrollVectorOptions &setFilterContraint(FilterConstraintFnType constraint) {
+  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
     filterConstraint = constraint;
     return *this;
   }
@@ -117,21 +117,19 @@ struct UnrollVectorOptions {
 };
 /// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
 /// declaratively.
-template <typename OpTy>
-struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
-  using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
+struct UnrollVectorPattern : public RewritePattern {
+  using FilterConstraintType = std::function<LogicalResult(Operation *op)>;
   UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
-      : OpRewritePattern<OpTy>(context), options(options) {}
-  LogicalResult matchAndRewrite(OpTy op,
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {}
+  LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     if (options.filterConstraint && failed(options.filterConstraint(op)))
       return failure();
     if (!options.nativeShape) {
-      return op.emitError("vector unrolling expects the native shape or native"
-                          "shape call back function to be set");
+      return op->emitError("vector unrolling expects the native shape or native"
+                           "shape call back function to be set");
     }
-    auto unrollableVectorOp =
-        dyn_cast<VectorUnrollOpInterface>(op.getOperation());
+    auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
     if (!unrollableVectorOp)
       return failure();
     auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
@@ -139,12 +137,12 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
       return failure();
     Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
     if (!targetShape)
-      return op.emitError("failed to get target shape for vector unroll");
+      return op->emitError("failed to get target shape for vector unroll");
     auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
     if (!maybeShapeRatio ||
         llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
       return failure();
-    if (std::is_same<OpTy, TransferWriteOp>::value) {
+    if (isa<TransferWriteOp>(op)) {
       if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
         return failure();
       rewriter.eraseOp(op);

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 602bf8148cd8..99c336ef0565 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -27,14 +27,22 @@ struct TestVectorToVectorConversion
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto *ctx = &getContext();
-    patterns.insert<UnrollVectorPattern<AddFOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
-    patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
+    patterns.insert<UnrollVectorPattern>(
+        ctx, UnrollVectorOptions().setNativeShapeFn(getShape));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
+
+private:
+  // Return the target shape based on op type.
+  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
+    if (isa<AddFOp>(op))
+      return SmallVector<int64_t, 4>(2, 2);
+    if (isa<vector::ContractionOp>(op))
+      return SmallVector<int64_t, 4>(3, 2);
+    return llvm::None;
+  }
 };
 
 struct TestVectorSlicesConversion
@@ -120,8 +128,11 @@ struct TestVectorUnrollingPatterns
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
-    patterns.insert<UnrollVectorPattern<AddFOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
+    patterns.insert<UnrollVectorPattern>(
+        ctx, UnrollVectorOptions()
+                 .setNativeShape(ArrayRef<int64_t>{2, 2})
+                 .setFilterConstraint(
+                     [](Operation *op) { return success(isa<AddFOp>(op)); }));
 
     if (unrollBasedOnType) {
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
@@ -137,12 +148,19 @@ struct TestVectorUnrollingPatterns
         }
         return nativeShape;
       };
-      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-          ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
+      patterns.insert<UnrollVectorPattern>(
+          ctx, UnrollVectorOptions()
+                   .setNativeShapeFn(nativeShapeFn)
+                   .setFilterConstraint([](Operation *op) {
+                     return success(isa<ContractionOp>(op));
+                   }));
     } else {
-      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-          ctx,
-          UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
+      patterns.insert<UnrollVectorPattern>(
+          ctx, UnrollVectorOptions()
+                   .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
+                   .setFilterConstraint([](Operation *op) {
+                     return success(isa<ContractionOp>(op));
+                   }));
     }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
@@ -273,10 +291,14 @@ struct TestVectorTransferUnrollingPatterns
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
-    patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
-    patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
+    patterns.insert<UnrollVectorPattern>(
+        ctx,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{2, 2})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
+            }));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));


        


More information about the llvm-branch-commits mailing list