[Mlir-commits] [mlir] efdd4c1 - [mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jan 18 01:30:51 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-18T09:30:46Z
New Revision: efdd4c169d307b73db2d64552ea7698d5ee6feff

URL: https://github.com/llvm/llvm-project/commit/efdd4c169d307b73db2d64552ea7698d5ee6feff
DIFF: https://github.com/llvm/llvm-project/commit/efdd4c169d307b73db2d64552ea7698d5ee6feff.diff

LOG: [mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface

Differential revision: https://reviews.llvm.org/D117323

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6f7e4a50a3b93..db3e7197a0aac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -43,8 +43,9 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
+/// Try to vectorize `convOp` as a convolution.
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
+                                                   LinalgOp convOp);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
@@ -636,13 +637,12 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
   SmallVector<Value> results;
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. Will require stride/dilation attributes inference.
-  if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
-    LDBG("Vectorize as a conv: " << linalgOp);
-    FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
-    if (failed(convOr))
-      return failure();
+  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
+  if (succeeded(convOr)) {
     llvm::append_range(results, (*convOr)->getResults());
   } else {
+    if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+      return failure();
     LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
     if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
       return failure();
@@ -1640,40 +1640,39 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
 };
 } // namespace
 
-/// Helper function to vectorize a `linalgOp` with convolution semantics.
+/// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
-  // TODO: these are legitimately part of ConvolutionOpInterface.
-  auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
-  auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
+  // The ConvolutionOpInterface gives us guarantees of existence for
+  // strides/dilations. However, we do not need to rely on those, we can simply
+  // use them if present, otherwise use the default and let the generic conv.
+  // matcher in the ConvGenerator succeed or fail.
+  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
+  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-  LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
-  Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
+  Conv1DNwcGenerator e(b, op, stride, dilation);
   auto res = e.generateConv();
   if (succeeded(res))
     return res;
   return e.generateDilatedConv();
 }
 
-struct VectorizeConvolution
-    : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
+struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
+  LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Operation *> resultOrFail =
-        vectorizeConvolution(rewriter, convOp);
+    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
     if (failed(resultOrFail))
       return failure();
     Operation *newOp = *resultOrFail;
     if (newOp->getNumResults() == 0) {
-      rewriter.eraseOp(convOp.getOperation());
+      rewriter.eraseOp(op.getOperation());
       return success();
     }
     assert(newOp->getNumResults() == 1 && "expected single result");
-    rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
+    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
     return success();
   }
 };


        


More information about the Mlir-commits mailing list