[Mlir-commits] [mlir] c8f5735 - [mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jan 17 09:01:40 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-17T17:01:36Z
New Revision: c8f5735301993c363c16ce5ddda6f1f6cb968090
URL: https://github.com/llvm/llvm-project/commit/c8f5735301993c363c16ce5ddda6f1f6cb968090
DIFF: https://github.com/llvm/llvm-project/commit/c8f5735301993c363c16ce5ddda6f1f6cb968090.diff
LOG: [mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface
Differential Revision: https://reviews.llvm.org/D117323
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 814474405715d..cbf0304d8c585 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -46,11 +46,6 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;
-/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops.
-void populateConvVectorizationPatterns(
- MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
- ArrayRef<int64_t> tileSizes);
-
/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops, it assume high-D convolution ops
/// were decomposed previously.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 86eaed9a136cb..3daf243ce4723 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -43,8 +43,9 @@ using namespace mlir::linalg;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
+/// Try to vectorize `convOp` as a convolution.
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
+ LinalgOp convOp);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
@@ -636,13 +637,12 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
SmallVector<Value> results;
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
- if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
- LDBG("Vectorize as a conv: " << linalgOp);
- FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
- if (failed(convOr))
- return failure();
+ FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
+ if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
} else {
+ if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+ return failure();
LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
return failure();
@@ -1098,134 +1098,6 @@ void mlir::linalg::populatePadTensorOpVectorizationPatterns(
patterns.getContext(), baseBenefit.getBenefit() + 1);
}
-// TODO: cleanup all the convolution vectorization patterns.
-template <class ConvOp, int N>
-LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
- ConvOp op, PatternRewriter &rewriter) const {
- Location loc = op.getLoc();
- MLIRContext *context = op.getContext();
-
- OpOperand *input = op.getInputOperand(0);
- OpOperand *kernel = op.getInputOperand(1);
- OpOperand *output = op.getOutputOperand(0);
- ArrayRef<int64_t> inShape = op.getShape(input);
- ArrayRef<int64_t> kShape = op.getShape(kernel);
-
- if (llvm::any_of(inShape, ShapedType::isDynamic) ||
- llvm::any_of(kShape, ShapedType::isDynamic))
- return failure();
-
- SmallVector<AffineExpr, 4> mapping;
- SmallVector<int64_t, 4> vectorDims;
- // Fail to apply when the size of not vectorized dimension is not 1.
- for (unsigned i = 0; i < N; i++) {
- if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
- return failure();
-
- if (mask[i] && inShape[i] != kShape[i])
- return failure();
-
- if (mask[i]) {
- mapping.push_back(getAffineDimExpr(i, context));
- vectorDims.push_back(inShape[i]);
- }
- }
-
- int64_t rank = op.getRank(input);
- int64_t numDims = mapping.size();
- Type elemType = getElementTypeOrSelf(input->get());
-
- auto map = AffineMap::get(rank, 0, mapping, context);
- SmallVector<Value, 4> zeros(rank,
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
- auto vecType = VectorType::get(vectorDims, elemType);
-
- auto inputVec = rewriter.create<vector::TransferReadOp>(
- loc, vecType, input->get(), zeros, map);
- auto kernelVec = rewriter.create<vector::TransferReadOp>(
- loc, vecType, kernel->get(), zeros, map);
-
- auto acc = rewriter.create<arith::ConstantOp>(loc, elemType,
- rewriter.getZeroAttr(elemType));
-
- std::array<AffineMap, 3> indexingMaps{
- AffineMap::getMultiDimIdentityMap(numDims, context),
- AffineMap::getMultiDimIdentityMap(numDims, context),
- AffineMap::get(numDims, 0, {}, context)};
-
- std::vector<StringRef> iteratorTypes(numDims, "reduction");
-
- auto result = rewriter.create<vector::ContractionOp>(
- loc, inputVec, kernelVec, acc,
- rewriter.getAffineMapArrayAttr(indexingMaps),
- rewriter.getStrArrayAttr(iteratorTypes));
-
- rewriter.create<memref::StoreOp>(loc, result, output->get(),
- ValueRange(zeros));
- rewriter.eraseOp(op);
- return success();
-}
-
-/// Inserts tiling, promotion and vectorization pattern for ConvOp
-/// conversion into corresponding pattern lists.
-template <typename ConvOp, unsigned N>
-static void populateVectorizationPatterns(
- RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
- RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
- auto *context = tilingPatterns.getContext();
- if (tileSizes.size() < N)
- return;
-
- constexpr static StringRef kTiledMarker = "TILED";
- constexpr static StringRef kPromotedMarker = "PROMOTED";
- tilingPatterns.add<LinalgTilingPattern>(
- ConvOp::getOperationName(), context,
- LinalgTilingOptions().setTileSizes(tileSizes),
- LinalgTransformationFilter(ArrayRef<StringAttr>{},
- StringAttr::get(context, kTiledMarker)));
-
- promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
- context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
- LinalgTransformationFilter(StringAttr::get(context, kTiledMarker),
- StringAttr::get(context, kPromotedMarker)));
-
- SmallVector<bool, 4> mask(N);
- int offset = tileSizes.size() - N;
- std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
- [](int64_t i) -> bool { return i > 1; });
-
- vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
-}
-
-void mlir::linalg::populateConvVectorizationPatterns(
- MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
- ArrayRef<int64_t> tileSizes) {
- RewritePatternSet tiling(context);
- RewritePatternSet promotion(context);
- RewritePatternSet vectorization(context);
- populateVectorizationPatterns<Conv1DOp, 1>(tiling, promotion, vectorization,
- tileSizes);
-
- populateVectorizationPatterns<Conv2DOp, 2>(tiling, promotion, vectorization,
- tileSizes);
-
- populateVectorizationPatterns<Conv3DOp, 3>(tiling, promotion, vectorization,
- tileSizes);
-
- populateVectorizationPatterns<Conv1DNwcWcfOp, 3>(tiling, promotion,
- vectorization, tileSizes);
-
- populateVectorizationPatterns<Conv2DNhwcHwcfOp, 4>(tiling, promotion,
- vectorization, tileSizes);
-
- populateVectorizationPatterns<Conv3DNdhwcDhwcfOp, 5>(
- tiling, promotion, vectorization, tileSizes);
-
- patterns.push_back(std::move(tiling));
- patterns.push_back(std::move(promotion));
- patterns.push_back(std::move(vectorization));
-}
-
//----------------------------------------------------------------------------//
// Forwarding patterns
//----------------------------------------------------------------------------//
@@ -1754,40 +1626,39 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
};
} // namespace
-/// Helper function to vectorize a `linalgOp` with convolution semantics.
+/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *>
-vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
- // TODO: these are legitimately part of ConvolutionOpInterface.
- auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
- auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
+ // The ConvolutionOpInterface gives us guarantees of existence for
+ // strides/dilations. However, we do not need to rely on those, we can simply
+ // use them if present, otherwise use the default and let the generic conv.
+ // matcher in the ConvGenerator succeed or fail.
+ auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
+ auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
- LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
- Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
+ Conv1DNwcGenerator e(b, op, stride, dilation);
auto res = e.generateConv();
if (succeeded(res))
return res;
return e.generateDilatedConv();
}
-struct VectorizeConvolution
- : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
+struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
+ LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
- FailureOr<Operation *> resultOrFail =
- vectorizeConvolution(rewriter, convOp);
+ FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
if (failed(resultOrFail))
return failure();
Operation *newOp = *resultOrFail;
if (newOp->getNumResults() == 0) {
- rewriter.eraseOp(convOp.getOperation());
+ rewriter.eraseOp(op.getOperation());
return success();
}
assert(newOp->getNumResults() == 1 && "expected single result");
- rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
+ rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
return success();
}
};
diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
deleted file mode 100644
index 6b3a7d010a9d7..0000000000000
--- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse -split-input-file
-// | FileCheck %s
-
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-// CHECK-DAG: #[[$map3:.*]] = affine_map<(d0, d1)[s0] -> (3, -d0 - d1 + s0)>
-// CHECK-DAG: #[[$map4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-
-// CHECK-LABEL: @conv_1d
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
-// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
-// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32
-func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
-// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[v0:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?xf32>
-// CHECK: %[[v1:.*]] = memref.dim %[[arg2]], %[[c0]] : memref<?xf32>
-// CHECK: %[[v2:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf32>
-// CHECK: %[[v3:.*]] = memref.alloc(%[[c12]]) : memref<?xi8>
-// CHECK: %[[v4:.*]] = memref.alloc(%[[c12]]) : memref<?xi8>
-// CHECK: %[[v5:.*]] = memref.alloc(%[[c4]]) : memref<?xi8>
-// CHECK: %[[v6:.*]] = memref.view %[[v3]][%[[c0]]][] : memref<?xi8> to memref<3xf32>
-// CHECK: %[[v7:.*]] = memref.view %[[v4]][%[[c0]]][] : memref<?xi8> to memref<3xf32>
-// CHECK: %[[v8:.*]] = memref.view %[[v5]][%[[c0]]][] : memref<?xi8> to memref<1xf32>
-// CHECK: scf.for %[[arg3:.*]] = %[[c0]] to %[[v1]] step %[[c1]] {
-// CHECK: %[[v9:.*]] = affine.min #[[$map0]](%[[arg3]])[%[[v1]]]
-// CHECK: %[[v10:.*]] = subview %[[arg2]][%[[arg3]]] [%[[v9]]] [1] : memref<?xf32> to memref<?xf32, #[[$map1]]>
-// CHECK: %[[v11:.*]] = subview %[[v8]][0] [%[[v9]]] [1] : memref<1xf32> to memref<?xf32>
-// CHECK: scf.for %[[arg4:.*]] = %[[c0]] to %[[v0]] step %[[c3]] {
-// CHECK: %[[v12:.*]] = affine.apply #[[$map2]](%[[arg3]], %[[arg4]])
-// CHECK: %[[v13:.*]] = affine.min #[[$map3]](%[[arg3]], %[[arg4]])[%[[v2]]]
-// CHECK: %[[v14:.*]] = subview %arg0[%12] [%13] [1] : memref<?xf32> to memref<?xf32, #[[$map1]]>
-// CHECK: %[[v15:.*]] = affine.min #[[$map4]](%arg4)[%0]
-// CHECK: %[[v16:.*]] = subview %[[arg1]][%[[arg4]]] [%[[v15]]] [1] : memref<?xf32> to memref<?xf32, #[[$map1]]>
-// CHECK: %[[v17:.*]] = subview %[[v6]][0] [%[[v13]]] [1] : memref<3xf32> to memref<?xf32>
-// CHECK: %[[v19:.*]] = vector.transfer_read %[[v6]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32>
-// CHECK: %[[v20:.*]] = vector.transfer_read %[[v7]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32>
-// CHECK: %[[v21:.*]] = arith.mulf %[[v19]], %[[v20]] : vector<3xf32>
-// CHECK: %[[v22:.*]] = vector.reduction "add", %[[v21]], %[[cst]] : vector<3xf32> into f32
-// CHECK: store %[[v22]], %[[v8]][%[[c0]]] : memref<1xf32>
-// CHECK: scf.for %[[arg5:.*]] = %[[c0]] to %[[v9]] step %[[c1]] {
-// CHECK: %[[v23:.*]] = load %[[v11]][%[[arg5]]] : memref<?xf32>
-// CHECK: store %[[v23]], %[[v10]][%[[arg5]]] : memref<?xf32, #[[$map1]]>
- linalg.conv_1d ins(%arg0, %arg1 : memref<?xf32>, memref<?xf32>)
- outs(%arg2 : memref<?xf32>)
- return
-}
-
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index fad6ec91f7c5e..7c9ad470eacf9 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,7 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRLinalgTestPasses
TestComprehensiveBufferize.cpp
- TestConvVectorization.cpp
TestLinalgCodegenStrategy.cpp
TestLinalgDistribution.cpp
TestLinalgElementwiseFusion.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
deleted file mode 100644
index 9c8f138743dec..0000000000000
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ /dev/null
@@ -1,143 +0,0 @@
-//===- TestConvVectorization.cpp - Vectorization of Conv ops --------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/SCF/Transforms.h"
-#include "mlir/Dialect/Vector/VectorTransforms.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/LoopUtils.h"
-#include "mlir/Transforms/Passes.h"
-
-using namespace mlir;
-using namespace vector;
-
-namespace {
-/// A pass converting MLIR Linalg ops into Vector ops.
-class TestConvVectorization
- : public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
-public:
- StringRef getArgument() const final { return "test-conv-vectorization"; }
- StringRef getDescription() const final {
- return "Test vectorization of convolutions";
- }
- TestConvVectorization() = default;
- TestConvVectorization(const TestConvVectorization &) {}
- explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) {
- tileSizes = tileSizesParam;
- }
-
- void runOnOperation() override;
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<VectorDialect>();
- registry.insert<linalg::LinalgDialect>();
- registry.insert<memref::MemRefDialect>();
- registry.insert<scf::SCFDialect>();
- registry.insert<AffineDialect>();
- registry.insert<StandardOpsDialect>();
- }
-
- ListOption<int64_t> tileSizes{
- *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."),
- llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-};
-} // namespace
-
-void TestConvVectorization::runOnOperation() {
- MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
-
- ConversionTarget target(*context);
- target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect,
- VectorDialect>();
- target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
- target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
-
- SmallVector<RewritePatternSet, 4> stage1Patterns;
- linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes);
- SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
- llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
-
- RewritePatternSet stage2Patterns =
- linalg::getLinalgTilingCanonicalizationPatterns(context);
- scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns);
-
- auto stage3Transforms = [](Operation *op) {
- PassManager pm(op->getContext());
- pm.addPass(createLoopInvariantCodeMotionPass());
- if (failed(pm.run(cast<ModuleOp>(op))))
- llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
- op->walk([](FuncOp func) {
- promoteSingleIterationLoops(func);
- linalg::hoistRedundantVectorTransfers(func);
- });
- return success();
- };
-
- (void)linalg::applyStagedPatterns(module, frozenStage1Patterns,
- std::move(stage2Patterns),
- stage3Transforms);
-
- //===--------------------------------------------------------------------===//
- // Post staged patterns transforms
- //===--------------------------------------------------------------------===//
-
- VectorTransformsOptions vectorTransformOptions{
- VectorContractLowering::Dot, VectorMultiReductionLowering::InnerParallel,
- VectorTransposeLowering::EltWise};
-
- RewritePatternSet vectorTransferPatterns(context);
- // Pattern is not applied: rank-reducing vector transfer is not yet supported
- // (see: splitFullAndPartialTransferPrecondition in VectorTransforms.cpp).
- vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
- context, vectorTransformOptions);
- (void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
-
- // Programmatic controlled lowering of linalg.copy and linalg.fill.
- PassManager pm(context);
- pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
- if (failed(pm.run(module)))
- llvm_unreachable("Unexpected failure in linalg to loops pass.");
-
- // Programmatic controlled lowering of vector.contract only.
- RewritePatternSet vectorContractLoweringPatterns(context);
- populateVectorBroadcastLoweringPatterns(vectorContractLoweringPatterns);
- populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
- vectorTransformOptions);
- populateVectorMaskOpLoweringPatterns(vectorContractLoweringPatterns);
- populateVectorShapeCastLoweringPatterns(vectorContractLoweringPatterns);
- populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
- vectorTransformOptions);
- (void)applyPatternsAndFoldGreedily(module,
- std::move(vectorContractLoweringPatterns));
-
- // Programmatic controlled lowering of vector.transfer only.
- RewritePatternSet vectorToLoopsPatterns(context);
- populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
- VectorTransferToSCFOptions());
- (void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
-
- // Ensure we drop the marker in the end.
- module.walk([](linalg::LinalgOp op) {
- op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
- });
-}
-
-namespace mlir {
-namespace test {
-void registerTestConvVectorization() {
- PassRegistration<TestConvVectorization>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 9a54cf257ce20..5b09cb8671eb1 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -66,7 +66,6 @@ void registerTestBuiltinAttributeInterfaces();
void registerTestCallGraphPass();
void registerTestComprehensiveFunctionBufferize();
void registerTestConstantFold();
-void registerTestConvVectorization();
void registerTestGpuSerializeToCubinPass();
void registerTestGpuSerializeToHsacoPass();
void registerTestDataLayoutQuery();
@@ -162,7 +161,6 @@ void registerTestPasses() {
mlir::test::registerTestGpuSerializeToHsacoPass();
#endif
mlir::test::registerTestComprehensiveFunctionBufferize();
- mlir::test::registerTestConvVectorization();
mlir::test::registerTestDecomposeCallGraphTypes();
mlir::test::registerTestDataLayoutQuery();
mlir::test::registerTestDominancePass();
More information about the Mlir-commits
mailing list