[Mlir-commits] [mlir] 62f5c46 - [mlir][Linalg] NFC - Expose more options to the CodegenStrategy

Nicolas Vasilache llvmlistbot at llvm.org
Fri Feb 19 06:10:44 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-19T14:01:44Z
New Revision: 62f5c46eecf8d356b76e840fb6cab09360f25f76

URL: https://github.com/llvm/llvm-project/commit/62f5c46eecf8d356b76e840fb6cab09360f25f76
DIFF: https://github.com/llvm/llvm-project/commit/62f5c46eecf8d356b76e840fb6cab09360f25f76.diff

LOG: [mlir][Linalg] NFC - Expose more options to the CodegenStrategy

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
    mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index fe481bb1f011..872e76387d84 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -66,8 +66,7 @@ void enqueue(OwningRewritePatternList &patternList, OptionsType options,
 
 /// Tiling transformation enqueues a particular stage-1 pattern for
 /// `Tile<LinalgOpType>` with the appropriate `options`.
-template <typename LinalgOpType>
-struct Tile : public Transformation {
+template <typename LinalgOpType> struct Tile : public Transformation {
   explicit Tile(linalg::LinalgTilingOptions options,
                 linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
       : Transformation(f), opName(LinalgOpType::getOperationName()),
@@ -93,8 +92,7 @@ struct Tile : public Transformation {
 
 /// Promotion transformation enqueues a particular stage-1 pattern for
 /// `Promote<LinalgOpType>` with the appropriate `options`.
-template <typename LinalgOpType>
-struct Promote : public Transformation {
+template <typename LinalgOpType> struct Promote : public Transformation {
   explicit Promote(
       linalg::LinalgPromotionOptions options,
       linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
@@ -150,6 +148,16 @@ struct Vectorize : public Transformation {
   linalg::LinalgVectorizationOptions options;
 };
 
+/// Options to control the application of late transformations.
+struct LateCodegenStrategyOptions { // Each flag gates one step of CodegenStrategy::transform.
+  bool enableLICM = true; // Walk loop-like ops and call moveLoopInvariantCode.
+  bool enableHoistRedundantVectorTransfers = true; // Call hoistRedundantVectorTransfers.
+  bool enableHoistRedundantVectorTransfersOnTensor = true; // Call hoistRedundantVectorTransfersOnTensor.
+  bool enableVectorTransferPartialRewrite = true; // Apply VectorTransferFullPartialRewriter.
+  bool enableVectorContractLowering = true; // Apply the vector.contract lowering patterns.
+  bool enableVectorToSCFConversion = true; // Apply the vector-to-SCF conversion patterns.
+};
+
 /// Codegen strategy controls how a Linalg op is progressively lowered.
 /// The application uses a 3-level staged patterns strategy which allows
 /// ordering transformations by using the Linalg `applyStagedPatterns`
@@ -283,10 +291,32 @@ struct CodegenStrategy {
     vectorToSCFOptions = options;
     return *this;
   }
-  /// Configure the post staged-patterns late vector.transfer to scf
-  /// conversion.
-  CodegenStrategy &setHoistInvariantCode(bool enableLICM) {
-    this->enableLICM = enableLICM;
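+  /// Configure the application of late transformations: each setter below
+  /// stores `val` in the matching `lateCodegenStrategyOptions` flag and
+  /// returns `*this`.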
+  CodegenStrategy &setEnableLICM(bool val) {
+    this->lateCodegenStrategyOptions.enableLICM = val;
+    return *this;
+  }
+  CodegenStrategy &setEnableHoistRedundantVectorTransfers(bool val) {
+    this->lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers = val;
+    return *this;
+  }
+  CodegenStrategy &setEnableHoistRedundantVectorTransfersOnTensor(bool val) {
+    this->lateCodegenStrategyOptions
+        .enableHoistRedundantVectorTransfersOnTensor = val;
+    return *this;
+  }
+  CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
+    this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
+    return *this;
+  }
+  CodegenStrategy &setEnableVectorContractLowering(bool val) {
+    this->lateCodegenStrategyOptions.enableVectorContractLowering = val;
+    return *this;
+  }
+  CodegenStrategy &setEnableVectorToSCFConversion(bool val) {
+    this->lateCodegenStrategyOptions.enableVectorToSCFConversion = val;
     return *this;
   }
 
@@ -300,7 +330,7 @@ struct CodegenStrategy {
   vector::VectorTransformsOptions vectorTransformsOptions;
   VectorTransferToSCFOptions vectorToSCFOptions;
   SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
-  bool enableLICM = true;
+  LateCodegenStrategyOptions lateCodegenStrategyOptions;
 };
 
 } // namespace linalg

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index e18a0b7ea985..6a6517bf40c7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -53,7 +53,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
     // Some of these may be too aggressive as a stage 3 that is applied on each
     // stage 1 application and may have to be split out to post staged patterns
     // application (in which case they could just be passes, TBD).
-    if (enableLICM) {
+    if (lateCodegenStrategyOptions.enableLICM) {
       op->walk([&](LoopLikeOpInterface loopLike) {
         LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
         if (failed(moveLoopInvariantCode(loopLike)))
@@ -62,8 +62,10 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
     }
     promoteSingleIterationLoops(cast<FuncOp>(op));
     hoistViewAllocOps(cast<FuncOp>(op));
-    hoistRedundantVectorTransfers(cast<FuncOp>(op));
-    hoistRedundantVectorTransfersOnTensor(cast<FuncOp>(op));
+    if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers)
+      hoistRedundantVectorTransfers(cast<FuncOp>(op));
+    if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfersOnTensor)
+      hoistRedundantVectorTransfersOnTensor(cast<FuncOp>(op));
     return success();
   };
   (void)linalg::applyStagedPatterns(
@@ -74,25 +76,31 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
   //===--------------------------------------------------------------------===//
 
   // Programmatic splitting of slow/fast path vector transfers.
-  OwningRewritePatternList patterns;
-  patterns.insert<vector::VectorTransferFullPartialRewriter>(
-      context, vectorTransformsOptions);
-  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+  if (lateCodegenStrategyOptions.enableVectorTransferPartialRewrite) {
+    OwningRewritePatternList patterns;
+    patterns.insert<vector::VectorTransferFullPartialRewriter>(
+        context, vectorTransformsOptions);
+    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+  }
 
   // Programmatic controlled lowering of vector.contract only.
-  OwningRewritePatternList vectorContractLoweringPatterns;
-  vectorContractLoweringPatterns
-      .insert<ContractionOpToOuterProductOpLowering,
-              ContractionOpToMatmulOpLowering, ContractionOpLowering>(
-          vectorTransformsOptions, context);
-  (void)applyPatternsAndFoldGreedily(func,
-                                     std::move(vectorContractLoweringPatterns));
+  if (lateCodegenStrategyOptions.enableVectorContractLowering) {
+    OwningRewritePatternList vectorContractLoweringPatterns;
+    vectorContractLoweringPatterns
+        .insert<ContractionOpToOuterProductOpLowering,
+                ContractionOpToMatmulOpLowering, ContractionOpLowering>(
+            vectorTransformsOptions, context);
+    (void)applyPatternsAndFoldGreedily(
+        func, std::move(vectorContractLoweringPatterns));
+  }
 
   // Programmatic controlled lowering of vector.transfer only.
-  OwningRewritePatternList vectorToLoopsPatterns;
-  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
-                                        vectorToSCFOptions);
-  (void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
+  if (lateCodegenStrategyOptions.enableVectorToSCFConversion) {
+    OwningRewritePatternList vectorToLoopsPatterns;
+    populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
+                                          vectorToSCFOptions);
+    (void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
+  }
 
   // Ensure we drop the marker in the end.
   func.walk([](LinalgOp op) {


        


More information about the Mlir-commits mailing list