[Mlir-commits] [mlir] efdd4c1 - [mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jan 18 01:30:51 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-18T09:30:46Z
New Revision: efdd4c169d307b73db2d64552ea7698d5ee6feff
URL: https://github.com/llvm/llvm-project/commit/efdd4c169d307b73db2d64552ea7698d5ee6feff
DIFF: https://github.com/llvm/llvm-project/commit/efdd4c169d307b73db2d64552ea7698d5ee6feff.diff
LOG: [mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface
Differential revision: https://reviews.llvm.org/D117323
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6f7e4a50a3b93..db3e7197a0aac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -43,8 +43,9 @@ using namespace mlir::linalg;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
+/// Try to vectorize `convOp` as a convolution.
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
+ LinalgOp convOp);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
@@ -636,13 +637,12 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
SmallVector<Value> results;
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
- if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
- LDBG("Vectorize as a conv: " << linalgOp);
- FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
- if (failed(convOr))
- return failure();
+ FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
+ if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
} else {
+ if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+ return failure();
LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
return failure();
@@ -1640,40 +1640,39 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
};
} // namespace
-/// Helper function to vectorize a `linalgOp` with convolution semantics.
+/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
- // TODO: these are legitimately part of ConvolutionOpInterface.
- auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
- auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
+ // The ConvolutionOpInterface gives us guarantees of existence for
+ // strides/dilations. However, we do not need to rely on those, we can simply
+ // use them if present, otherwise use the default and let the generic conv.
+ // matcher in the ConvGenerator succeed or fail.
+ auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
+ auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
- LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
- Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
+ Conv1DNwcGenerator e(b, op, stride, dilation);
auto res = e.generateConv();
if (succeeded(res))
return res;
return e.generateDilatedConv();
}
-struct VectorizeConvolution
- : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
+struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
+ LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
- FailureOr<Operation *> resultOrFail =
- vectorizeConvolution(rewriter, convOp);
+ FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
if (failed(resultOrFail))
return failure();
Operation *newOp = *resultOrFail;
if (newOp->getNumResults() == 0) {
- rewriter.eraseOp(convOp.getOperation());
+ rewriter.eraseOp(op.getOperation());
return success();
}
assert(newOp->getNumResults() == 1 && "expected single result");
- rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
+ rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
return success();
}
};
More information about the Mlir-commits
mailing list