[Mlir-commits] [mlir] f40a579 - Revert "[mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface"

Mehdi Amini llvmlistbot at llvm.org
Mon Jan 17 11:38:14 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-17T19:38:07Z
New Revision: f40a579bea9c079fc6d4b048fa3ddb8ac8c97337

URL: https://github.com/llvm/llvm-project/commit/f40a579bea9c079fc6d4b048fa3ddb8ac8c97337
DIFF: https://github.com/llvm/llvm-project/commit/f40a579bea9c079fc6d4b048fa3ddb8ac8c97337.diff

LOG: Revert "[mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface"

This reverts commit c8f5735301993c363c16ce5ddda6f1f6cb968090.

The integration tests are broken.

Added: 
    mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
    mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8e7ea21fb8f81..bee8608051f34 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -49,6 +49,11 @@ using LinalgLoops = SmallVector<Operation *, 4>;
 void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                      const LinalgTilingOptions &options);
 
+/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops.
+void populateConvVectorizationPatterns(
+    MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
+    ArrayRef<int64_t> tileSizes);
+
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index db3e7197a0aac..c06082cae9e75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -43,9 +43,8 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
-/// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
-                                                   LinalgOp convOp);
+static FailureOr<Operation *>
+vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
@@ -637,12 +636,13 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
   SmallVector<Value> results;
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. Will require stride/dilation attributes inference.
-  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
-  if (succeeded(convOr)) {
+  if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
+    LDBG("Vectorize as a conv: " << linalgOp);
+    FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
+    if (failed(convOr))
+      return failure();
     llvm::append_range(results, (*convOr)->getResults());
   } else {
-    if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
-      return failure();
     LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
     if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
       return failure();
@@ -1098,6 +1098,134 @@ void mlir::linalg::populatePadTensorOpVectorizationPatterns(
       patterns.getContext(), baseBenefit.getBenefit() + 1);
 }
 
+// TODO: clean up all the convolution vectorization patterns.
+template <class ConvOp, int N>
+LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
+    ConvOp op, PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MLIRContext *context = op.getContext();
+
+  OpOperand *input = op.getInputOperand(0);
+  OpOperand *kernel = op.getInputOperand(1);
+  OpOperand *output = op.getOutputOperand(0);
+  ArrayRef<int64_t> inShape = op.getShape(input);
+  ArrayRef<int64_t> kShape = op.getShape(kernel);
+
+  if (llvm::any_of(inShape, ShapedType::isDynamic) ||
+      llvm::any_of(kShape, ShapedType::isDynamic))
+    return failure(); // Only fully static input/kernel shapes are handled.
+
+  SmallVector<AffineExpr, 4> mapping;
+  SmallVector<int64_t, 4> vectorDims;
+  // Fail to apply when the size of a non-vectorized dimension is not 1.
+  for (unsigned i = 0; i < N; i++) {
+    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
+      return failure();
+
+    if (mask[i] && inShape[i] != kShape[i])
+      return failure(); // Vectorized dims need equal input/kernel sizes.
+
+    if (mask[i]) {
+      mapping.push_back(getAffineDimExpr(i, context));
+      vectorDims.push_back(inShape[i]);
+    }
+  }
+
+  int64_t rank = op.getRank(input);
+  int64_t numDims = mapping.size();
+  Type elemType = getElementTypeOrSelf(input->get());
+
+  auto map = AffineMap::get(rank, 0, mapping, context);
+  SmallVector<Value, 4> zeros(rank,
+                              rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  auto vecType = VectorType::get(vectorDims, elemType);
+
+  auto inputVec = rewriter.create<vector::TransferReadOp>(
+      loc, vecType, input->get(), zeros, map);
+  auto kernelVec = rewriter.create<vector::TransferReadOp>(
+      loc, vecType, kernel->get(), zeros, map);
+
+  auto acc = rewriter.create<arith::ConstantOp>(loc, elemType,
+                                                rewriter.getZeroAttr(elemType));
+
+  std::array<AffineMap, 3> indexingMaps{
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::get(numDims, 0, {}, context)}; // Accumulator map: no results.
+
+  std::vector<StringRef> iteratorTypes(numDims, "reduction");
+
+  auto result = rewriter.create<vector::ContractionOp>(
+      loc, inputVec, kernelVec, acc,
+      rewriter.getAffineMapArrayAttr(indexingMaps),
+      rewriter.getStrArrayAttr(iteratorTypes));
+
+  rewriter.create<memref::StoreOp>(loc, result, output->get(),
+                                   ValueRange(zeros));
+  rewriter.eraseOp(op);
+  return success();
+}
+
+/// Inserts tiling, promotion and vectorization patterns for ConvOp
+/// conversion into the corresponding pattern lists.
+template <typename ConvOp, unsigned N>
+static void populateVectorizationPatterns(
+    RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
+    RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
+  auto *context = tilingPatterns.getContext();
+  if (tileSizes.size() < N)
+    return; // Fewer than N tile sizes: no patterns are added for this op.
+
+  constexpr static StringRef kTiledMarker = "TILED";
+  constexpr static StringRef kPromotedMarker = "PROMOTED";
+  tilingPatterns.add<LinalgTilingPattern>(
+      ConvOp::getOperationName(), context,
+      LinalgTilingOptions().setTileSizes(tileSizes),
+      LinalgTransformationFilter(ArrayRef<StringAttr>{},
+                                 StringAttr::get(context, kTiledMarker)));
+
+  promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
+      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
+      LinalgTransformationFilter(StringAttr::get(context, kTiledMarker),
+                                 StringAttr::get(context, kPromotedMarker)));
+
+  SmallVector<bool, 4> mask(N);
+  int offset = tileSizes.size() - N; // Only the trailing N sizes feed the mask.
+  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
+                 [](int64_t i) -> bool { return i > 1; });
+
+  vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
+}
+
+void mlir::linalg::populateConvVectorizationPatterns(
+    MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
+    ArrayRef<int64_t> tileSizes) {
+  RewritePatternSet tiling(context);
+  RewritePatternSet promotion(context);
+  RewritePatternSet vectorization(context);
+  populateVectorizationPatterns<Conv1DOp, 1>(tiling, promotion, vectorization,
+                                             tileSizes);
+
+  populateVectorizationPatterns<Conv2DOp, 2>(tiling, promotion, vectorization,
+                                             tileSizes);
+
+  populateVectorizationPatterns<Conv3DOp, 3>(tiling, promotion, vectorization,
+                                             tileSizes);
+
+  populateVectorizationPatterns<Conv1DNwcWcfOp, 3>(tiling, promotion,
+                                                   vectorization, tileSizes);
+
+  populateVectorizationPatterns<Conv2DNhwcHwcfOp, 4>(tiling, promotion,
+                                                     vectorization, tileSizes);
+
+  populateVectorizationPatterns<Conv3DNdhwcDhwcfOp, 5>(
+      tiling, promotion, vectorization, tileSizes);
+
+  patterns.push_back(std::move(tiling)); // Order: tiling, promotion, vectorization.
+  patterns.push_back(std::move(promotion));
+  patterns.push_back(std::move(vectorization));
+}
+
 //----------------------------------------------------------------------------//
 // Forwarding patterns
 //----------------------------------------------------------------------------//
@@ -1640,39 +1768,40 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
 };
 } // namespace
 
-/// Helper function to vectorize a LinalgOp with convolution semantics.
+/// Helper function to vectorize a `linalgOp` with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
-  // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can simply
-  // use them if present, otherwise use the default and let the generic conv.
-  // matcher in the ConvGenerator succeed or fail.
-  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
-  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
+static FailureOr<Operation *>
+vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
+  // TODO: these are legitimately part of ConvolutionOpInterface.
+  auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
+  auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-  Conv1DNwcGenerator e(b, op, stride, dilation);
+  LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
+  Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
   auto res = e.generateConv();
   if (succeeded(res))
     return res;
   return e.generateDilatedConv();
 }
 
-struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
+struct VectorizeConvolution
+    : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
   using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(LinalgOp op,
+  LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
+    FailureOr<Operation *> resultOrFail =
+        vectorizeConvolution(rewriter, convOp);
     if (failed(resultOrFail))
       return failure();
     Operation *newOp = *resultOrFail;
     if (newOp->getNumResults() == 0) {
-      rewriter.eraseOp(op.getOperation());
+      rewriter.eraseOp(convOp.getOperation());
       return success();
     }
     assert(newOp->getNumResults() == 1 && "expected single result");
-    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
+    rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
new file mode 100644
index 0000000000000..6b3a7d010a9d7
--- /dev/null
+++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse -split-input-file
+// | FileCheck %s
+
+// CHECK-DAG:  #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)>
+// CHECK-DAG:  #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG:  #[[$map2:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG:  #[[$map3:.*]] = affine_map<(d0, d1)[s0] -> (3, -d0 - d1 + s0)>
+// CHECK-DAG:  #[[$map4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+
+// CHECK-LABEL: @conv_1d
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32
+func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+//   CHECK-DAG:   %[[c12:.*]] = arith.constant 12 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   %[[v0:.*]] = memref.dim %[[arg1]], %[[c0]] : memref<?xf32>
+//       CHECK:   %[[v1:.*]] = memref.dim %[[arg2]], %[[c0]] : memref<?xf32>
+//       CHECK:   %[[v2:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf32>
+//       CHECK:   %[[v3:.*]] = memref.alloc(%[[c12]]) : memref<?xi8>
+//       CHECK:   %[[v4:.*]] = memref.alloc(%[[c12]]) : memref<?xi8>
+//       CHECK:   %[[v5:.*]] = memref.alloc(%[[c4]]) : memref<?xi8>
+//       CHECK:   %[[v6:.*]] = memref.view %[[v3]][%[[c0]]][] : memref<?xi8> to memref<3xf32>
+//       CHECK:   %[[v7:.*]] = memref.view %[[v4]][%[[c0]]][] : memref<?xi8> to memref<3xf32>
+//       CHECK:   %[[v8:.*]] = memref.view %[[v5]][%[[c0]]][] : memref<?xi8> to memref<1xf32>
+//       CHECK:   scf.for %[[arg3:.*]] = %[[c0]] to %[[v1]] step %[[c1]] {
+//       CHECK:     %[[v9:.*]] = affine.min #[[$map0]](%[[arg3]])[%[[v1]]]
+//       CHECK:     %[[v10:.*]] = subview %[[arg2]][%[[arg3]]] [%[[v9]]] [1]  : memref<?xf32> to memref<?xf32, #[[$map1]]>
+//       CHECK:     %[[v11:.*]] = subview %[[v8]][0] [%[[v9]]] [1]  : memref<1xf32> to memref<?xf32>
+//       CHECK:     scf.for %[[arg4:.*]] = %[[c0]] to %[[v0]] step %[[c3]] {
+//       CHECK:       %[[v12:.*]] = affine.apply #[[$map2]](%[[arg3]], %[[arg4]])
+//       CHECK:       %[[v13:.*]] = affine.min #[[$map3]](%[[arg3]], %[[arg4]])[%[[v2]]]
+//       CHECK:       %[[v14:.*]] = subview %arg0[%12] [%13] [1]  : memref<?xf32> to memref<?xf32, #[[$map1]]>
+//       CHECK:       %[[v15:.*]] = affine.min #[[$map4]](%arg4)[%0]
+//       CHECK:       %[[v16:.*]] = subview %[[arg1]][%[[arg4]]] [%[[v15]]] [1]  : memref<?xf32> to memref<?xf32, #[[$map1]]>
+//       CHECK:       %[[v17:.*]] = subview %[[v6]][0] [%[[v13]]] [1]  : memref<3xf32> to memref<?xf32>
+//       CHECK:       %[[v19:.*]] = vector.transfer_read %[[v6]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32>
+//       CHECK:       %[[v20:.*]] = vector.transfer_read %[[v7]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32>
+//       CHECK:       %[[v21:.*]] = arith.mulf %[[v19]], %[[v20]] : vector<3xf32>
+//       CHECK:       %[[v22:.*]] = vector.reduction "add", %[[v21]], %[[cst]] : vector<3xf32> into f32
+//       CHECK:       store %[[v22]], %[[v8]][%[[c0]]] : memref<1xf32>
+//       CHECK:       scf.for %[[arg5:.*]] = %[[c0]] to %[[v9]] step %[[c1]] {
+//       CHECK:         %[[v23:.*]] = load %[[v11]][%[[arg5]]] : memref<?xf32>
+//       CHECK:         store %[[v23]], %[[v10]][%[[arg5]]] : memref<?xf32, #[[$map1]]>
+  linalg.conv_1d ins(%arg0, %arg1 : memref<?xf32>, memref<?xf32>)
+                outs(%arg2 : memref<?xf32>)
+  return
+}
+

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 7c9ad470eacf9..fad6ec91f7c5e 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
   TestComprehensiveBufferize.cpp
+  TestConvVectorization.cpp
   TestLinalgCodegenStrategy.cpp
   TestLinalgDistribution.cpp
   TestLinalgElementwiseFusion.cpp

diff  --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
new file mode 100644
index 0000000000000..9c8f138743dec
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -0,0 +1,143 @@
+//===- TestConvVectorization.cpp - Vectorization of Conv ops --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace vector;
+
+namespace {
+/// Test pass (-test-conv-vectorization) that vectorizes Linalg convolution ops.
+class TestConvVectorization
+    : public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
+public:
+  StringRef getArgument() const final { return "test-conv-vectorization"; }
+  StringRef getDescription() const final {
+    return "Test vectorization of convolutions";
+  }
+  TestConvVectorization() = default;
+  TestConvVectorization(const TestConvVectorization &) {} // `tileSizes` is not copied.
+  explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) {
+    tileSizes = tileSizesParam;
+  }
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect>();
+    registry.insert<linalg::LinalgDialect>();
+    registry.insert<memref::MemRefDialect>();
+    registry.insert<scf::SCFDialect>();
+    registry.insert<AffineDialect>();
+    registry.insert<StandardOpsDialect>();
+  }
+
+  ListOption<int64_t> tileSizes{
+      *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+};
+} // namespace
+
+void TestConvVectorization::runOnOperation() {
+  MLIRContext *context = &getContext();
+  ModuleOp module = getOperation();
+
+  ConversionTarget target(*context); // NOTE(review): configured but never used below.
+  target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect,
+                         VectorDialect>();
+  target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
+  target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
+
+  SmallVector<RewritePatternSet, 4> stage1Patterns;
+  linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes);
+  SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
+  llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
+
+  RewritePatternSet stage2Patterns =
+      linalg::getLinalgTilingCanonicalizationPatterns(context);
+  scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns);
+
+  auto stage3Transforms = [](Operation *op) {
+    PassManager pm(op->getContext());
+    pm.addPass(createLoopInvariantCodeMotionPass());
+    if (failed(pm.run(cast<ModuleOp>(op))))
+      llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
+    op->walk([](FuncOp func) {
+      promoteSingleIterationLoops(func);
+      linalg::hoistRedundantVectorTransfers(func);
+    });
+    return success();
+  };
+
+  (void)linalg::applyStagedPatterns(module, frozenStage1Patterns,
+                                    std::move(stage2Patterns),
+                                    stage3Transforms);
+
+  //===--------------------------------------------------------------------===//
+  // Post staged patterns transforms
+  //===--------------------------------------------------------------------===//
+
+  VectorTransformsOptions vectorTransformOptions{
+      VectorContractLowering::Dot, VectorMultiReductionLowering::InnerParallel,
+      VectorTransposeLowering::EltWise};
+
+  RewritePatternSet vectorTransferPatterns(context);
+  // Pattern is not applied: rank-reducing vector transfer is not yet supported
+  // (see: splitFullAndPartialTransferPrecondition in VectorTransforms.cpp).
+  vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
+      context, vectorTransformOptions);
+  (void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
+
+  // Programmatically controlled lowering of linalg.copy and linalg.fill.
+  PassManager pm(context);
+  pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
+  if (failed(pm.run(module)))
+    llvm_unreachable("Unexpected failure in linalg to loops pass.");
+
+  // Programmatically controlled lowering of vector.contract only.
+  RewritePatternSet vectorContractLoweringPatterns(context);
+  populateVectorBroadcastLoweringPatterns(vectorContractLoweringPatterns);
+  populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
+                                         vectorTransformOptions);
+  populateVectorMaskOpLoweringPatterns(vectorContractLoweringPatterns);
+  populateVectorShapeCastLoweringPatterns(vectorContractLoweringPatterns);
+  populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
+                                          vectorTransformOptions);
+  (void)applyPatternsAndFoldGreedily(module,
+                                     std::move(vectorContractLoweringPatterns));
+
+  // Programmatically controlled lowering of vector.transfer only.
+  RewritePatternSet vectorToLoopsPatterns(context);
+  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
+                                        VectorTransferToSCFOptions());
+  (void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
+
+  // Ensure we drop the marker at the end.
+  module.walk([](linalg::LinalgOp op) {
+    op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
+  });
+}
+
+namespace mlir {
+namespace test {
+void registerTestConvVectorization() {
+  PassRegistration<TestConvVectorization>(); // Registers the test pass.
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 5b09cb8671eb1..9a54cf257ce20 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -66,6 +66,7 @@ void registerTestBuiltinAttributeInterfaces();
 void registerTestCallGraphPass();
 void registerTestComprehensiveFunctionBufferize();
 void registerTestConstantFold();
+void registerTestConvVectorization();
 void registerTestGpuSerializeToCubinPass();
 void registerTestGpuSerializeToHsacoPass();
 void registerTestDataLayoutQuery();
@@ -161,6 +162,7 @@ void registerTestPasses() {
   mlir::test::registerTestGpuSerializeToHsacoPass();
 #endif
   mlir::test::registerTestComprehensiveFunctionBufferize();
+  mlir::test::registerTestConvVectorization();
   mlir::test::registerTestDecomposeCallGraphTypes();
   mlir::test::registerTestDataLayoutQuery();
   mlir::test::registerTestDominancePass();


        


More information about the Mlir-commits mailing list