[Mlir-commits] [mlir] 9a7d111 - [mlir][Linalg] NFC - Modernize transformation APIs.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jan 5 08:06:37 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-05T11:01:40-05:00
New Revision: 9a7d111f4fb65ad7343dcbd4f35ee608100634e8

URL: https://github.com/llvm/llvm-project/commit/9a7d111f4fb65ad7343dcbd4f35ee608100634e8
DIFF: https://github.com/llvm/llvm-project/commit/9a7d111f4fb65ad7343dcbd4f35ee608100634e8.diff

LOG: [mlir][Linalg] NFC - Modernize transformation APIs.

Differential Revision: https://reviews.llvm.org/D116665

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
    mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bb396ce5a5541..c1185c0a8ff70 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -44,12 +44,12 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
 //===----------------------------------------------------------------------===//
 using LinalgLoops = SmallVector<Operation *, 4>;
 
-/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops.
+/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops.
 void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
-/// Populates patterns for vectorizing low-D convolution ops. This is a step in
+/// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
 void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
@@ -91,7 +91,7 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
 /// canonicalizations of named ops into another named op.
 void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
 
-/// Populates the given list with patterns to bufferize linalg ops.
+/// Populate the given list with patterns to bufferize linalg ops.
 void populateLinalgBufferizePatterns(
     bufferization::BufferizeTypeConverter &converter,
     RewritePatternSet &patterns);
@@ -124,7 +124,7 @@ struct LinalgElementwiseFusionOptions {
     return *this;
   }
 
-  /// Function that allows the caller to control when to stop fusion. Once a
+  /// Function to allow the caller to control when to stop fusion. Once a
   /// producer is deemed fusable with the consumer (structurally), this callback
   /// can be used to abort the fusion based on non-structural constraints. This
   /// is the hook for cost models to control the amount of fusion done.
@@ -149,7 +149,7 @@ void populateElementwiseOpsFusionPatterns(
 /// more fusion opportunities.
 void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
 
-/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
+/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify
 /// the new ordering of the loop nest. The length of `interchangeVector`
@@ -157,7 +157,7 @@ void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
 /// An empty vector is interpreted as the identity permutation and the
 /// transformation returns early.
 ///
-/// Returns a struct containing the tiled loops in the specified order
+/// Return a struct containing the tiled loops in the specified order
 /// and the cloned op if successful, llvm::None otherwise.
 ///
 /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
@@ -237,7 +237,7 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                      const LinalgDependenceGraph &dependenceGraph,
                      const LinalgTilingOptions &tilingOptions);
 
-/// Interchanges the `iterator_types` and `iterator_maps` dimensions and adapts
+/// Interchange the `iterator_types` and `indexing_maps` dimensions and adapt
 /// the index accesses of `op`. This is an in-place transformation controlled by
 /// `interchangeVector`. An empty vector is interpreted as the identity
 /// permutation and the transformation returns early.
@@ -246,12 +246,15 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
 /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
 /// integers, in the range 0..`op.rank` without duplications
 /// (i.e. `[1,1,2]` is an invalid permutation).
-void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
-                          ArrayRef<unsigned> interchangeVector);
+FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
+                                          GenericOp genericOp,
+                                          ArrayRef<unsigned> interchangeVector);
 
-/// Creates a GenericOp from the given named operation `namedOp`. Assumes
-/// `namedOp` is not a GenericOp and has a region builder.
-GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
+/// Create a GenericOp from the given named operation `namedOp` and replace
+/// `namedOp` with the results of the new GenericOp.
+/// Return failure if `namedOp` is a GenericOp or lacks a region builder.
+FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
+                                       LinalgOp namedOp);
 
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
@@ -346,7 +349,7 @@ struct LinalgPromotionOptions {
   }
 };
 
-/// Creates a new buffer using the `allocationFn` provided. The size of this
+/// Create a new buffer using the `allocationFn` provided. The size of this
 /// buffer is the smallest constant bounding size along each dimension that can
 /// be computed for the size of the result of `subView`. Returns the allocated
 /// buffer as `fullLocalView` and the view that matches the size of the result
@@ -360,7 +363,7 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
                           const AllocBufferCallbackFn &allocationFn,
                           DataLayout &layout);
 
-/// Promotes the `subViews` into a new buffer allocated at the insertion point
+/// Promote the `subViews` into a new buffer allocated at the insertion point
 /// `b`. Promotion occurs in 3 steps:
 ///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
 ///   2. Take a full view on the buffer.
@@ -368,24 +371,23 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
 /// Infers statically sized buffers from subViews unless `dynamicBuffers` is
 /// true.
 ///
-/// Returns the modified linalg op (the modification happens in place) as well
+/// Return the modified linalg op (the modification happens in place) as well
 /// as all the copy ops created.
 FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                     const LinalgPromotionOptions &options);
 
 /// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
-                                SmallVectorImpl<Value> &newResults);
+LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
 
-/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
+/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
                                        LinalgOp linalgOp);
 
-/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
+/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
                                                LinalgOp linalgOp);
 
-/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
+/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
                                              LinalgOp linalgOp);
 
@@ -393,28 +395,10 @@ FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
 // Preconditions that ensure the corresponding transformation succeeds and can
 // be applied as a rewrite pattern.
 //===----------------------------------------------------------------------===//
-/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
-/// permutated according to `permutation`.
-LogicalResult
-interchangeGenericOpPrecondition(GenericOp genericOp,
-                                 ArrayRef<unsigned> interchangeVector);
-
-/// Generalize named operations to generic operations.
-LogicalResult generalizeNamedOpPrecondition(Operation *op);
-
-/// Promote std.subviews feeding linalg operations.
+/// Precondition for promoting memref.subviews feeding linalg-on-buffers ops.
 LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
 
-/// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
-
-/// Return success if `op` can be vectorized assuming it is static. This allows
-/// checking if an op will be vectorizable once all the dimensions are folded to
-/// static values.
-/// It is the same as `vectorizeLinalgOpPrecondition` for static shapes.
-LogicalResult vectorizeStaticLinalgOpPrecondition(LinalgOp op);
-
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
@@ -610,7 +594,7 @@ struct LinalgTilingOptions {
 RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
 void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 
-/// Base pattern that applied the tiling transformation specified by `options`.
+/// Base pattern that applies the tiling transformation specified by `options`.
 /// Abort and return failure in 2 cases:
 ///   1. if the tiling specification is invalid and tiling fails to occur.
 ///   2. if tiling occurs but `options.paddingValueComputationFunction` is set
@@ -812,9 +796,9 @@ struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
 };
 
 ///
-/// Linalg generic interchage pattern.
+/// Linalg generic interchange pattern.
 ///
-/// Apply the `interchange` transformation as a pattern.
+/// Apply the `interchangeGenericOp` transformation as a pattern.
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `interchange` for more details.
 struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
@@ -909,13 +893,11 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
 ///
 /// Linalg vectorization patterns.
 ///
-/// Apply the `vectorizeLinalgOp` transformation as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
-
 /// Empty for now, used for SFINAE purposes only.
 struct LinalgVectorizationOptions {};
 
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `vectorize` for more details.
 struct LinalgBaseVectorizationPattern : public RewritePattern {
   /// MatchAnyOpTag-based constructor with a mandatory `filter`.
   LinalgBaseVectorizationPattern(MLIRContext *context,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index a42ac8d81c4b9..721c47ca01308 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -29,7 +29,7 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
+static LogicalResult generalizeNamedOpPrecondition(Operation *op) {
   LinalgOp namedOp = dyn_cast<LinalgOp>(op);
   // Check if the operation is a LinalgOp but not a GenericOp.
   if (!namedOp || isa<GenericOp>(op))
@@ -40,8 +40,11 @@ LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
   return success();
 }
 
-GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
-                                          LinalgOp namedOp) {
+FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
+                                                     LinalgOp namedOp) {
+  if (failed(generalizeNamedOpPrecondition(namedOp)))
+    return rewriter.notifyMatchFailure(namedOp, "preconditions not met");
+
   SmallVector<Value> inputOperands = namedOp.getInputOperands();
   SmallVector<Value> outputOperands = namedOp.getOutputOperands();
   SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
@@ -58,6 +61,7 @@ GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
                                  outputOperands, indexingMaps, iterators);
   rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
                               genericOp.region().begin());
+  rewriter.replaceOp(namedOp, genericOp->getResults());
   return genericOp;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index 0a1f7bc4565fc..8d0d26caa5fb0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -30,8 +31,9 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
-    GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
+static LogicalResult
+interchangeGenericOpPrecondition(GenericOp genericOp,
+                                 ArrayRef<unsigned> interchangeVector) {
   // Interchange vector must be non-empty and match the number of loops.
   if (interchangeVector.empty() ||
       genericOp.getNumLoops() != interchangeVector.size())
@@ -43,31 +45,38 @@ LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
   return success();
 }
 
-void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
-                                        GenericOp genericOp,
-                                        ArrayRef<unsigned> interchangeVector) {
-  // 1. Compute the inverse permutation map.
+FailureOr<GenericOp>
+mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+                                   ArrayRef<unsigned> interchangeVector) {
+  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
+    return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
+
+  // 1. Compute the inverse permutation map, it must be non-null since the
+  // preconditions are satisfied.
   MLIRContext *context = genericOp.getContext();
   AffineMap permutationMap = inversePermutation(
       AffineMap::getPermutationMap(interchangeVector, context));
-  assert(permutationMap && "expected permutation to be invertible");
-  assert(interchangeVector.size() == genericOp.getNumLoops() &&
-         "expected interchange vector to have entry for every loop");
+  assert(permutationMap && "unexpected null map");
+
+  // Start a guarded inplace update.
+  rewriter.startRootUpdate(genericOp);
+  auto guard =
+      llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
 
   // 2. Compute the interchanged indexing maps.
-  SmallVector<Attribute, 4> newIndexingMaps;
+  SmallVector<AffineMap> newIndexingMaps;
   for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
     AffineMap m = genericOp.getTiedIndexingMap(opOperand);
     if (!permutationMap.isEmpty())
       m = m.compose(permutationMap);
-    newIndexingMaps.push_back(AffineMapAttr::get(m));
+    newIndexingMaps.push_back(m);
   }
   genericOp->setAttr(getIndexingMapsAttrName(),
-                     ArrayAttr::get(context, newIndexingMaps));
+                     rewriter.getAffineMapArrayAttr(newIndexingMaps));
 
   // 3. Compute the interchanged iterator types.
   ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
-  SmallVector<Attribute, 4> itTypesVector;
+  SmallVector<Attribute> itTypesVector;
   llvm::append_range(itTypesVector, itTypes);
   SmallVector<int64_t> permutation(interchangeVector.begin(),
                                    interchangeVector.end());
@@ -91,4 +100,6 @@ void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
           indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
     }
   }
+
+  return genericOp;
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
index bb38607d769ac..b331b6657fcea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -137,7 +137,7 @@ struct SimplifyDepthwiseConvQOp
 struct LinalgNamedOpConversionPass
     : public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
   LinalgNamedOpConversionPass() = default;
-  LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {}
+  LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
 
   void runOnOperation() override {
     Operation *op = getOperation();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8156c5d45744c..c2a3c2bda630a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -623,16 +623,14 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
     GenericOp genericOp, PatternRewriter &rewriter) const {
   if (failed(filter.checkAndNotify(rewriter, genericOp)))
     return failure();
-  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
+
+  FailureOr<GenericOp> transformedOp =
+      interchangeGenericOp(rewriter, genericOp, interchangeVector);
+  if (failed(transformedOp))
     return failure();
 
-  // TODO: figure out how this interplays with named ops. In particular this
-  // should break the named op property.
-  rewriter.updateRootInPlace(genericOp, [&]() {
-    interchangeGenericOp(rewriter, genericOp, interchangeVector);
-    // New filter if specified.
-    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
-  });
+  // New filter if specified.
+  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
   return success();
 }
 
@@ -652,12 +650,10 @@ LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   if (failed(filter.checkAndNotify(rewriter, op)))
     return failure();
-  if (failed(generalizeNamedOpPrecondition(op)))
+  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, op);
+  if (failed(genericOp))
     return failure();
-
-  GenericOp genericOp = generalizeNamedOp(rewriter, op);
-  rewriter.replaceOp(op, genericOp.getResults());
-  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
+  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
   return success();
 }
 
@@ -708,19 +704,13 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
 
 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
+  // TODO: Interface-based rewrite.
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
-    return failure();
-  SmallVector<Value> newResults;
-  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
+  if (failed(filter.checkAndNotify(rewriter, op)))
     return failure();
-  if (!newResults.empty())
-    rewriter.replaceOp(op, newResults);
-  else
-    rewriter.eraseOp(op);
-  return success();
+  return vectorize(rewriter, linalgOp);
 }
 
 LogicalResult mlir::linalg::applyStagedPatterns(
@@ -758,8 +748,8 @@ static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
 }
 
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
-/// with pad_val) and GenericOp (to copy contents).
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
+/// initialize with pad_val) and GenericOp (to copy contents).
 LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5fda632b2f860..4a597f64d72ff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -597,8 +597,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-LogicalResult
-mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
   if (isElementwise(op))
     return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
@@ -620,8 +619,7 @@ mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
   return success();
 }
 
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
-  auto linalgOp = cast<linalg::LinalgOp>(op);
+static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
   // All types must be static shape to go to vector.
   if (linalgOp.hasDynamicShape()) {
     LDBG("precondition failed: dynamic shape");
@@ -630,31 +628,32 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   return vectorizeStaticLinalgOpPrecondition(linalgOp);
 }
 
-LogicalResult
-mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
-                                SmallVectorImpl<Value> &newResults) {
-  if (failed(vectorizeLinalgOpPrecondition(op)))
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
+                                      LinalgOp linalgOp) {
+  if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
     return failure();
 
-  auto linalgOp = cast<LinalgOp>(op);
-
-  // TODO: isaConvolutionOpInterface that can also infer from generic features.
-  // But we will still need stride/dilation attributes that will be annoying to
-  // reverse-engineer...
-  if (auto convOp = dyn_cast<ConvolutionOpInterface>(op)) {
-    FailureOr<Operation *> resultOrFail = vectorizeConvolution(b, convOp);
-    if (failed(resultOrFail))
+  SmallVector<Value> results;
+  // TODO: isaConvolutionOpInterface that can also infer from generic
+  // features. Will require stride/dilation attributes inference.
+  if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
+    LDBG("Vectorize as a conv: " << linalgOp);
+    FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
+    if (failed(convOr))
+      return failure();
+    llvm::append_range(results, (*convOr)->getResults());
+  } else {
+    LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
+    if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
       return failure();
-    Operation *newOp = *resultOrFail;
-    llvm::append_range(newResults, newOp->getResults());
-    return success();
   }
 
-  LDBG(""
-       << "Vectorize linalg op as a generic by broadcasting to "
-          "maximal common shape: "
-       << *op);
-  return vectorizeAsLinalgGeneric(b, linalgOp, newResults);
+  if (!results.empty())
+    rewriter.replaceOp(linalgOp, results);
+  else
+    rewriter.eraseOp(linalgOp);
+
+  return success();
 }
 
 //----------------------------------------------------------------------------//
@@ -666,8 +665,9 @@ static int64_t getIntFromAttr(Attribute attr) {
   return attr.cast<IntegerAttr>().getInt();
 }
 
-/// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
-/// are converted to ConstantIndexOps. Other attribute types are not supported.
+/// Given an ArrayRef of OpFoldResults, return a vector of Values.
+/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
+/// not supported.
 static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
                                            ArrayRef<OpFoldResult> ofrs) {
   SmallVector<Value> result;
@@ -691,9 +691,9 @@ struct GenericPadTensorOpVectorizationPattern
   GenericPadTensorOpVectorizationPattern(MLIRContext *context,
                                          PatternBenefit benefit = 1)
       : GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {}
-  /// Vectorize the copying of a PadTensorOp's source. This is possible if each
-  /// dimension size is statically know in the source type or the result type
-  /// (or both).
+  /// Vectorize the copying of a PadTensorOp's source. This is possible if
+  /// each dimension size is statically known in the source type or the result
+  /// type (or both).
   static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
                                         PadTensorOp padOp, Value dest) {
     auto sourceType = padOp.getSourceType();
@@ -718,13 +718,14 @@ struct GenericPadTensorOpVectorizationPattern
     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
       if (!sourceType.isDynamicDim(i)) {
         vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are out-of-
-        // bounds.
+        // Source shape is statically known: Neither read nor write are
+        // out-of-bounds.
         readInBounds.push_back(true);
         writeInBounds.push_back(true);
       } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is. Vectorize
-        // with size of result shape. This may be larger than the source size.
+        // Source shape is not statically known, but result shape is.
+        // Vectorize with size of result shape. This may be larger than the
+        // source size.
         vecShape.push_back(resultType.getDimSize(i));
         // Read may be out-of-bounds because the result size could be larger
         // than the source size.
@@ -749,8 +750,8 @@ struct GenericPadTensorOpVectorizationPattern
         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
         ArrayRef<bool>{readInBounds});
 
-    // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
-    // tensor, write directly to the FillOp's operand.
+    // If `dest` is a FillOp and the TransferWriteOp would overwrite the
+    // entire tensor, write directly to the FillOp's operand.
     if (llvm::equal(vecShape, resultType.getShape()) &&
         llvm::all_of(writeInBounds, [](bool b) { return b; }))
       if (auto fill = dest.getDefiningOp<FillOp>())
@@ -766,8 +767,8 @@ struct GenericPadTensorOpVectorizationPattern
   }
 };
 
-/// Base pattern for rewriting PadTensorOps whose result is consumed by a given
-/// operation type OpTy.
+/// Base pattern for rewriting PadTensorOps whose result is consumed by a
+/// given operation type OpTy.
 template <typename OpTy>
 struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
@@ -837,10 +838,10 @@ struct PadTensorOpVectorizationWithTransferReadPattern
 };
 
 /// Rewrite use of PadTensorOp result in TransferWriteOp.
-/// This pattern rewrites TransferWriteOps that write to a padded tensor value,
-/// where the same amount of padding is immediately removed again after the
-/// write. In such cases, the TransferWriteOp can write to the non-padded tensor
-/// value and apply out-of-bounds masking. E.g.:
+/// This pattern rewrites TransferWriteOps that write to a padded tensor
+/// value, where the same amount of padding is immediately removed again after
+/// the write. In such cases, the TransferWriteOp can write to the non-padded
+/// tensor value and apply out-of-bounds masking. E.g.:
 /// ```
 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
 ///     : tensor<...> to tensor<?x?xf32>
@@ -854,17 +855,19 @@ struct PadTensorOpVectorizationWithTransferReadPattern
 /// ```
 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
 ///     : tensor<...> to tensor<?x?xf32>
-/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
+/// %r = vector.transfer_write %vec, %0[...]
+///     : vector<17x5xf32>, tensor<?x?xf32>
 /// ```
 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
-/// TransferWriteOp to the same size as the input of the TensorPadOp (or an even
-/// smaller size). Otherwise, %r's new (dynamic) dimensions would differ from
-/// %r's old dimensions.
+/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
+/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
+/// from %r's old dimensions.
 ///
 /// This rewrite is possible if:
 /// - Low padding is static 0.
 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
-///   ExtractSliceOp trims the same amount of padding that was added beforehand.
+///   ExtractSliceOp trims the same amount of padding that was added
+///   beforehand.
 /// - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferWritePattern
     : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
@@ -922,8 +925,8 @@ struct PadTensorOpVectorizationWithTransferWritePattern
   /// sizes may turn out to be equal at runtime.
   bool hasSameTensorSize(Value beforePadding,
                          tensor::ExtractSliceOp afterTrimming) const {
-    // If the input to PadTensorOp is a CastOp, try with with both CastOp result
-    // and CastOp operand.
+    // If the input to PadTensorOp is a CastOp, try with both CastOp
+    // result and CastOp operand.
     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
       if (hasSameTensorSize(castOp.source(), afterTrimming))
         return true;
@@ -950,8 +953,9 @@ struct PadTensorOpVectorizationWithTransferWritePattern
     if (t1.getNumDynamicDims() == 0)
       return true;
 
-    // All dynamic sizes must be the same. The only supported case at the moment
-    // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
+    // All dynamic sizes must be the same. The only supported case at the
+    // moment is when `beforePadding` is an ExtractSliceOp (or a cast
+    // thereof).
 
     // Apart from CastOp, only ExtractSliceOp is supported.
     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
@@ -1062,7 +1066,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
     // InsertSliceOp.
     rewriter.setInsertionPoint(insertOp);
 
-    // Generate TransferReadOp: Read entire source tensor and add high padding.
+    // Generate TransferReadOp: Read entire source tensor and add high
+    // padding.
     SmallVector<Value> readIndices(
         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
     auto read = rewriter.create<vector::TransferReadOp>(
@@ -1224,9 +1229,9 @@ void mlir::linalg::populateConvVectorizationPatterns(
 // Forwarding patterns
 //----------------------------------------------------------------------------//
 
-/// Check whether there is any interleaved use of any `values` between `firstOp`
-/// and `secondOp`. Conservatively return `true` if any op or value is in a
-/// different block.
+/// Check whether there is any interleaved use of any `values` between
+/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
+/// is in a different block.
 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                     ValueRange values) {
   if (firstOp->getBlock() != secondOp->getBlock() ||
@@ -1252,7 +1257,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
   return false;
 }
 
-/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
+/// Return the unique subview use of `v` if it is indeed unique, null
+/// otherwise.
 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
   memref::SubViewOp subViewOp;
   for (auto &u : v.getUses()) {
@@ -1307,7 +1313,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
     return failure();
   LDBG("with copy " << *copyOp);
 
-  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
+  // Find the fill into `viewOrAlloc` without interleaved uses before the
+  // copy.
   FillOp maybeFillOp;
   for (auto &u : viewOrAlloc.getUses()) {
     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
@@ -1468,7 +1475,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   /// ```
   /// kw is always unrolled.
-  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
+  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
+  /// > 1.
   FailureOr<Operation *> conv() {
     if (!valid)
       return failure();
@@ -1483,7 +1491,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
 
     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
-    // When strideW == 1, we can batch the contiguous loads and avoid unrolling
+    // When strideW == 1, we can batch the contiguous loads and avoid
+    // unrolling
     int64_t wSizeStep = strideW == 1 ? wSize : 1;
 
     Type lhsEltType = lhsShapedType.getElementType();
@@ -1500,7 +1509,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
     VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
 
-    // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, 0].
+    // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
+    // 0].
     Value lhs = builder.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
@@ -1591,7 +1601,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
   ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   /// ```
   /// kw is always unrolled.
-  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
+  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
+  /// > 1.
   FailureOr<Operation *> dilatedConv() {
     if (!valid)
       return failure();
@@ -1605,7 +1616,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
 
     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
-    // When strideW == 1, we can batch the contiguous loads and avoid unrolling
+    // When strideW == 1, we can batch the contiguous loads and avoid
+    // unrolling
     int64_t wSizeStep = strideW == 1 ? wSize : 1;
 
     Type lhsEltType = lhsShapedType.getElementType();
@@ -1621,7 +1633,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
     VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
 
-    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
+    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
+    // 0].
     Value lhs = builder.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c} @ [0, 0].


        


More information about the Mlir-commits mailing list