[Mlir-commits] [mlir] 8ed2e8e - [mlir][linalg] Retire Linalg ConvOp.

Tobias Gysi llvmlistbot at llvm.org
Thu Oct 7 23:57:30 PDT 2021


Author: Tobias Gysi
Date: 2021-10-08T06:56:37Z
New Revision: 8ed2e8e04ff42eb4d8009999ae1fd341a30bf6c0

URL: https://github.com/llvm/llvm-project/commit/8ed2e8e04ff42eb4d8009999ae1fd341a30bf6c0
DIFF: https://github.com/llvm/llvm-project/commit/8ed2e8e04ff42eb4d8009999ae1fd341a30bf6c0.diff

LOG: [mlir][linalg] Retire Linalg ConvOp.

The convolution op is one of the remaining hard-coded Linalg operations that have no region attached. It became obsolete with the introduction of the OpDSL convolution operations. Removing it lets us delete specialized code and tests that are not needed for the OpDSL counterparts, which rely on the standard code paths.

Tests that were only needed for the specialized implementations are removed. The tiling and fusion tests are replaced by variants that use linalg.conv_2d.
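
For illustration, a sketch of the retired op and of the OpDSL-generated op used by the rewritten tests (both forms adapted from the fusion test changes further below; SSA value names are placeholders, and note that linalg.conv_2d operates on rank-2 memrefs, i.e. without batch and feature dimensions):

    // Retired hard-coded named op (buffer semantics, no region attached):
    linalg.conv(%filter, %input, %output) {dilations = [1, 1], strides = [1, 1]} :
      memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>

    // OpDSL-based named op used by the updated tiling and fusion tests:
    linalg.conv_2d ins(%input, %filter : memref<?x?xf32>, memref<?x?xf32>)
                   outs(%output : memref<?x?xf32>)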

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D111233

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/affine.mlir
    mlir/test/Dialect/Linalg/fusion-pattern.mlir
    mlir/test/Dialect/Linalg/fusion.mlir
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/tile-conv.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp

Removed: 
    mlir/test/Dialect/Linalg/tile-conv-padding.mlir
    mlir/test/Dialect/Linalg/tile-simple-conv.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 2d65106eef4e8..c8eb1db039839 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -34,7 +34,6 @@
 namespace mlir {
 namespace linalg {
 
-class ConvOp;
 class LinalgOp;
 
 // TOFO: allow an extra ValueRange to specify an indexing and allow
@@ -81,14 +80,6 @@ std::string generateLibraryCallName(Operation *op);
 SmallVector<AffineExpr, 4> makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                               MLIRContext *context);
 
-/// Builds the indexing expressions for a ConvOp/PoolingOp `op`. Returns the
-/// vector of AffineMaps representing:
-///   `stride[i] * outputDims[i] + dilation[i] * windowDims[i] - pad_low[i]`
-template <typename PoolingOp>
-extern SmallVector<AffineExpr, 4>
-weightedPoolingInputIndex(PoolingOp op, ArrayRef<AffineExpr> outputDims,
-                          ArrayRef<AffineExpr> windowDims);
-
 /// Returns `maybeMap.get()` if `maybeMap` is set, otherwise returns the
 /// symbol-less identity map of `rank`.
 AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index bf8d90020889d..a30684c117f4c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -222,201 +222,6 @@ def FillOp : LinalgStructured_Op<"fill", []> {
   let hasFolder = 1;
 }
 
-/// A base class for pooling operation such as conv. The arguments must contain
-/// optional arguments `strides`, `dilations` and `padding` with following type:
-///   OptionalAttr<I64ArrayAttr>:$strides
-///   OptionalAttr<I64ArrayAttr>:$dilations
-///   OptionalAttr<I64ElementsAttr>:$padding
-/// `strides` denotes the step of each window along the dimension.
-class PoolingBase_Op<string mnemonic, list<OpTrait> props>
-  : LinalgStructured_Op<mnemonic, props> {
-  let description = [{
-    Performs an N-D pooling operation similarly to the description in the TF
-    documentation:
-    https://www.tensorflow.org/api_docs/python/tf/nn/pool
-
-    Different from the description, this operation doesn't perform on batch and
-    channel. It only takes tensors of rank `N`.
-
-    ```
-      output[x[0], ..., x[N-1]] =
-        REDUCE_{z[0], ..., z[N-1]}
-          input[
-                x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
-                ...
-                x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1]
-                ],
-    ```
-
-    The required optional arguments are:
-      - strides: an i64 array specifying the stride (i.e. step) for window
-        loops.
-      - dilations: an i64 array specifying the filter upsampling/input
-        downsampling rate
-      - padding: an i64 array of pairs (low, high) specifying the number of
-        elements to pad along a dimension.
-
-    If strides or dilations attributes are missing then the default value is
-    one for each of the input dimensions. Similarly, padding values are zero
-    for both low and high in each of the dimensions, if not specified.
-  }];
-
-  code commonUtils = structuredOpsDecls # [{
-    int64_t getStride(unsigned i) {
-      assert(i < getNumWindowLoops());
-      if (!strides().hasValue()) return 1;
-      return strides()->getValue()[i]
-        .cast<IntegerAttr>().getValue().getSExtValue();
-    }
-
-    int64_t getDilation(unsigned i) {
-      assert(i < getNumWindowLoops());
-      if (!dilations().hasValue()) return 1;
-      return dilations()->getValue()[i]
-        .cast<IntegerAttr>().getValue().getSExtValue();
-    }
-
-    int64_t getLowPad(unsigned i) {
-      assert(i < getNumWindowLoops());
-      if (!padding().hasValue()) return 0;
-      return padding().getValue().getValue<int64_t>({i, 0});
-    }
-
-    int64_t getHighPad(unsigned i) {
-      assert(i < getNumWindowLoops());
-      if (!padding().hasValue()) return 0;
-      return padding().getValue().getValue<int64_t>({i, 1});
-    }
-
-    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
-    getRegionBuilder() {
-      return nullptr;
-    }
-  }];
-}
-
-// Only support buffer semantics.
-def ConvOp : PoolingBase_Op<"conv", []> {
-  let description = [{
-    Generic n-D convolution as described in the TF documentation:
-    https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution
-
-    ```
-      output[b, x[0], ..., x[N-1], k] =
-      sum_{z[0], ..., z[N-1], q}
-          filter[z[0], ..., z[N-1], q, k] *
-          padded_input[b,
-                       x[0] * strides[0] + dilation_rate[0] * z[0],
-                       ...,
-                       x[N-1] * strides[N-1] + dilation_rate[N-1] * z[N-1],
-                       q]
-    ```
-  }];
-
-  // Following the TF source of truth above, strides, dilations and padding are
-  // integer attributes of the same rank as the number of window dimensions.
-  // The padding attribute specifies the amount of zero padding to be applied to
-  // the base area, which is a n-d array of (low, high) padding. Each pair has
-  // the low padding as the first element and the high padding as the second
-  // element. Using padding is equivalent to inserting those same zero values
-  // into the input before doing the convolution.
-  let arguments = (ins AnyStridedMemRef:$filter, AnyStridedMemRef:$input,
-                   AnyStridedMemRef:$output,
-                   OptionalAttr<I64ArrayAttr>:$strides,
-                   OptionalAttr<I64ArrayAttr>:$dilations,
-                   OptionalAttr<I64ElementsAttr>:$padding);
-
-  let extraClassDeclaration = commonUtils # [{
-    ValueRange inputs() { return getOperands().slice(0, 2); }
-    ValueRange outputs() { return getOperands().take_back(); }
-
-    // TODO: extend to support more than 1 dimensions and potentially grouping
-    // too.
-    unsigned getNumBatchDimensions() { return 1; }
-
-    unsigned getNumInputFeatureDimensions() { return 1; }
-
-    unsigned getNumOutputFeatureDimensions() { return 1; }
-
-    unsigned getNumSpatialDimensions() {
-      return getRank(getOutputOperand(0)) - getNumBatchDimensions() -
-             getNumOutputFeatureDimensions();
-    }
-
-    ArrayAttr iterator_types() {
-      // Outer parallel loops are always the number of output dimensions; i.e.
-      // [b, xs, q] in the TF notation above.
-      int64_t nPar = getRank(getOutputOperand(0));
-      unsigned nRed = getNumInputFeatureDimensions();
-      // Window loops are a special kind of reduction that is never tiled or
-      // parallelized across; i.e. [zs] in the TF notation above whose number
-      // match `xs` (i.e. 1 window loop per "image" dimension).
-      // This may evolve in the future.
-      // Conditionally check nPar is large enough for cases of ill-formed op:
-      // this avoids overflows before hitting the verifier.
-      assert(nPar > getNumBatchDimensions() + getNumInputFeatureDimensions() &&
-             "expected at least one window dimension (i.e. memref ranks greater "
-             "than 2). See 'func @conv_rank_limit' in "
-             "mlir/test/Dialect/Linalg/invalid.mlir");
-      unsigned nWin =
-        nPar - getNumBatchDimensions() - getNumInputFeatureDimensions();
-      SmallVector<StringRef, 8> iters(nPar, getParallelIteratorTypeName());
-      iters.reserve(nPar + nRed + nWin);
-      iters.append(nRed, getReductionIteratorTypeName());
-      iters.append(nWin, getWindowIteratorTypeName());
-      return Builder(getContext()).getStrArrayAttr(iters);
-    }
-
-    //   F(z0, ..., zN-1, q, k) *
-    //     I(b, x0 + z0 - pad_low_0, ..., xN-1 + zN-1 - pad_low_N-1, q)
-    //   ->  O(b, x0, ..., xN-1, k)
-    // for N equal to `nWindow`. If there is no padding attribute, it will be
-    // ignored.
-    ArrayAttr indexing_maps() {
-      MLIRContext *context = getContext();
-      auto nWin = getNumWindowLoops();
-      assert(nWin > 0 && "expected at least one window dimension (i.e. memref "
-                         "ranks greater than 2)");
-      unsigned idx = 0;
-      // In the following, AffineDimExprs are indexed in loop order:
-      //   [ b, xs, k,           q,                     zs]
-      //    parallels     non-window reductions     windows
-      //
-      // Parallel dims are exactly the dimensions indexing `output`:
-      //     output[b, x[0], ..., x[N-1], k]; i.e.
-      //  * batch dimensions (bs with #bs = 1 for now)
-      //  * "image" dimensions (xs with #xs = #zs = output_rank - #bs - #ks)
-      //  * output filter dimensions (ks with #ks = 1 for now)
-      auto bs = makeAffineDimExprs(getNumBatchDimensions(), idx, context);
-      auto xs = makeAffineDimExprs(nWin, idx, context);
-      auto ks = makeAffineDimExprs(
-        getNumOutputFeatureDimensions(), idx, context);
-      // Non-window reduction dim: sum_{z[0], ..., z[N-1], q}
-      auto qs = makeAffineDimExprs(
-        getNumInputFeatureDimensions(), idx, context);
-      // Window reduction dims: sum_{z[0], ..., z[N-1], q}
-      auto zs = makeAffineDimExprs(nWin, idx, context);
-      // Construct the weighedSum expression.
-      auto ws = weightedPoolingInputIndex(*this, xs, zs);
-      return Builder(getContext()).getAffineMapArrayAttr({
-        // filter[z[0], ..., z[N-1], q, k]
-        AffineMap::get(idx, 0, concat(concat(zs, qs), ks), context),
-        // input[b,
-        //       x[0]*s[0] + d[0]*z[0] - pad_low[0],
-        //       ...
-        //       x[N-1]*s[N-1] + d[N-1]*z[N-1] - pad_low[N-1],
-        //       q]
-        AffineMap::get(idx, 0, concat(concat(bs, ws), qs), context),
-        // output[b, x[0], ..., x[N-1], k]
-        AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)});
-    }
-  }];
-
-  let verifier = [{ return ::verify(*this); }];
-
-  let hasFolder = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // Generic Linalg ops.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b2bd87086ac56..60b6e1a2a13a8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -970,12 +970,6 @@ void populateLinalgNamedOpsGeneralizationPatterns(
     RewritePatternSet &patterns,
     LinalgTransformationFilter filter = LinalgTransformationFilter());
 
-/// Populates `patterns` with patterns to convert linalg.conv ops to
-/// linalg.generic ops.
-void populateLinalgConvGeneralizationPatterns(
-    RewritePatternSet &patterns,
-    LinalgTransformationFilter filter = LinalgTransformationFilter());
-
 /// Linalg distribution patterns
 //
 /// Populates `patterns` with patterns to distribute linalg.tiled_loop.

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2cba281acf8f0..44d470cac5117 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2610,52 +2610,6 @@ static LogicalResult verify(IndexOp op) {
 
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
-template <typename LinalgPoolingOp>
-static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
-                                            ArrayRef<Attribute> attrs,
-                                            bool isStride) {
-  auto strideOrDilation = isStride ? "stride" : "dilation";
-  if (attrs.size() != op.getNumWindowLoops())
-    return op.emitOpError("expects num ")
-           << strideOrDilation
-           << "s equal to number of window dimensions: " << attrs.size()
-           << " vs " << op.getNumWindowLoops();
-  return success();
-}
-
-void ConvOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  effects.emplace_back(MemoryEffects::Read::get(), input(),
-                       SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Read::get(), filter(),
-                       SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Write::get(), output(),
-                       SideEffects::DefaultResource::get());
-}
-
-static LogicalResult verify(ConvOp op) {
-  auto oType = op.output().getType().cast<MemRefType>();
-  auto fType = op.filter().getType().cast<MemRefType>();
-  auto iType = op.input().getType().cast<MemRefType>();
-  if (oType.getElementType() != iType.getElementType() ||
-      oType.getElementType() != fType.getElementType())
-    return op.emitOpError("expects memref elemental types to match");
-  if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
-    return op.emitOpError("expects memref ranks to match");
-  if (auto strides = op.strides()) {
-    if (failed(verifyStrideOrDilation(op, strides->getValue(),
-                                      /*isStride=*/true)))
-      return failure();
-  }
-  if (auto dilations = op.dilations()) {
-    if (failed(verifyStrideOrDilation(op, dilations->getValue(),
-                                      /*isStride=*/false)))
-      return failure();
-  }
-  return success();
-}
-
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
 
 #define GET_OP_CLASSES
@@ -2701,31 +2655,6 @@ mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
   return res;
 }
 
-template <typename PoolingOp>
-SmallVector<AffineExpr, 4>
-mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
-                                        ArrayRef<AffineExpr> outputDims,
-                                        ArrayRef<AffineExpr> windowDims) {
-  assert(outputDims.size() == windowDims.size());
-  SmallVector<AffineExpr, 4> res;
-  res.reserve(outputDims.size());
-  for (unsigned i = 0, e = outputDims.size(); i < e; ++i) {
-    // TODO: add a level of indirection to linalg.generic.
-    auto expr = op.getStride(i) * outputDims[i] +
-                op.getDilation(i) * windowDims[i] - op.getLowPad(i);
-    res.push_back(expr);
-  }
-  return res;
-}
-
-#define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE)                      \
-  template SmallVector<AffineExpr, 4>                                          \
-  mlir::linalg::weightedPoolingInputIndex<OP_TYPE>(                            \
-      OP_TYPE op, ArrayRef<AffineExpr> outputDims,                             \
-      ArrayRef<AffineExpr> windowDims);
-
-INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp)
-
 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
                                                 ArrayRef<AffineExpr> b) {
   auto rangeA = llvm::make_range(a.begin(), a.end());
@@ -3180,7 +3109,6 @@ struct SimplifyDepthwiseConvQOp
     return foldMemRefCast(*this);                                              \
   }
 
-LINALGOP_FOLDERS(ConvOp)
 LINALGOP_FOLDERS(CopyOp)
 LINALGOP_FOLDERS(FillOp)
 LINALGOP_FOLDERS(GenericOp)

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index adce2c74b78c2..db7d11ae58c75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -324,16 +324,6 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                << *producer.getOperation());
     return false;
   }
-  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
-    // TODO: add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      return false;
-  }
-  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
-    // TODO: add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      return false;
-  }
   return true;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index afa78b8bc3845..7740baafe21a5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -75,49 +75,6 @@ GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
 
 namespace {
 
-/// Base class for all linalg generalization patterns. A subclass must provide
-/// the following method:
-///   GenericOp createGenericOp(RootOp, PatternRewriter &)
-/// for creating the generic op.
-// TODO: remove this pattern after migrating all manually-written named ops
-// into auto-generated ones.
-template <typename ConcretePattern, typename RootOp>
-struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
-  LinalgGeneralizationPattern(MLIRContext *context,
-                              LinalgTransformationFilter marker,
-                              PatternBenefit benefit = 1)
-      : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
-
-  LogicalResult matchAndRewrite(RootOp rootOp,
-                                PatternRewriter &rewriter) const override {
-    auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation());
-    if (!linalgOp)
-      return failure();
-    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
-      return failure();
-
-    auto *pattern = static_cast<const ConcretePattern *>(this);
-    GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
-    if (!genericOp)
-      return failure();
-
-    rewriter.replaceOp(rootOp, genericOp.getResults());
-    marker.replaceLinalgTransformationFilter(rewriter,
-                                             genericOp.getOperation());
-    return success();
-  }
-
-private:
-  LinalgTransformationFilter marker;
-};
-
-struct GeneralizeConvOp
-    : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> {
-  using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
-
-  GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
-};
-
 struct LinalgGeneralizationPass
     : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
   void runOnFunction() override;
@@ -128,34 +85,10 @@ struct LinalgGeneralizationPass
 void LinalgGeneralizationPass::runOnFunction() {
   FuncOp func = getFunction();
   RewritePatternSet patterns(&getContext());
-  populateLinalgConvGeneralizationPatterns(patterns);
   populateLinalgNamedOpsGeneralizationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
 }
 
-GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp,
-                                            OpBuilder &builder) const {
-  SmallVector<AffineMap> indexingMaps = convOp.getIndexingMaps();
-  auto iterators =
-      llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
-  SmallVector<Value> inputBuffers = convOp.getInputBufferOperands();
-  SmallVector<Value> outputBuffers = convOp.getOutputBufferOperands();
-  return builder.create<GenericOp>(
-      convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(), inputBuffers,
-      outputBuffers, indexingMaps, iterators,
-      [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
-        Value mul =
-            bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
-        Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
-        bodyBuilder.create<YieldOp>(bodyLoc, add);
-      });
-}
-
-void mlir::linalg::populateLinalgConvGeneralizationPatterns(
-    RewritePatternSet &patterns, LinalgTransformationFilter marker) {
-  patterns.add<GeneralizeConvOp>(patterns.getContext(), marker);
-}
-
 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
     RewritePatternSet &patterns, LinalgTransformationFilter marker) {
   patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 854166d0ef679..f027d0d5f388c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -160,116 +160,6 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
                                                 indexing, outputBuffers);
 }
 
-// Create a padded view into the given `input` tensor using the 'indices'
-// to access the tensor. `skipPadding` lists the dimensions for which no padding
-// is needed e.g. the non-spatial dimensions for convolutions.
-Value getPaddedInput(OpBuilder &b, Location loc, Value input,
-                     ArrayRef<Value> indices, ArrayRef<int> skipPadding,
-                     Value padValue) {
-  Value zeroIndex = b.create<ConstantIndexOp>(loc, 0);
-  SmallVector<Value> conds;
-  SmallVector<Value> clampedImIdx;
-  for (auto iter : llvm::enumerate(indices)) {
-    int idx = iter.index();
-    auto dim = iter.value();
-    if (is_contained(skipPadding, idx)) {
-      clampedImIdx.push_back(dim);
-      continue;
-    }
-
-    Value leftOutOfBound =
-        b.create<CmpIOp>(loc, CmpIPredicate::slt, dim, zeroIndex);
-    if (conds.empty())
-      conds.push_back(leftOutOfBound);
-    else
-      conds.push_back(b.create<OrOp>(loc, conds.back(), leftOutOfBound));
-    Value rightBound = createOrFoldDimOp(b, loc, input, idx);
-    Value rightOutOfBound =
-        b.create<CmpIOp>(loc, CmpIPredicate::sge, dim, rightBound);
-    conds.push_back(b.create<OrOp>(loc, conds.back(), rightOutOfBound));
-
-    // When padding is involved, the indices will only be shifted to negative,
-    // so having a max op is enough.
-    MLIRContext *ctx = input.getContext();
-    AffineExpr m = getAffineDimExpr(/*position=*/0, ctx),
-               zero = getAffineConstantExpr(0, ctx);
-    AffineMap maxMap =
-        AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>>{{m, zero}})
-            .front();
-    clampedImIdx.push_back(b.create<AffineMaxOp>(loc, maxMap, ValueRange{dim}));
-  }
-
-  Value readInput = b.create<memref::LoadOp>(loc, input, clampedImIdx);
-  if (conds.empty())
-    return readInput;
-
-  return b.create<SelectOp>(loc, conds.back(), padValue, readInput);
-}
-
-namespace {
-
-/// The padding value for a given Op depends on the semantics of the Op.
-/// The identity value for ConvOp is 0.
-template <typename OpType> Attribute getPadValueAttr(Type type) {
-  llvm_unreachable("Unexpected op type for getPadValueAttr");
-  return {};
-}
-
-template <> Attribute getPadValueAttr<ConvOp>(Type type) {
-  return OpBuilder(type.getContext()).getZeroAttr(type);
-}
-
-} // namespace
-
-/// Returns true is `convOp` has a non-zero padding.
-static bool hasPadding(ConvOp convOp) {
-  for (unsigned i = 0, e = convOp.getNumSpatialDimensions(); i < e; ++i) {
-    if (convOp.getLowPad(i) > 0 || convOp.getHighPad(i) > 0)
-      return true;
-  }
-  return false;
-}
-
-template <typename LoadOpTy, typename StoreOpTy>
-static void emitScalarImplementation(OpBuilder &b, Location loc,
-                                     ArrayRef<Value> allIvs, ConvOp convOp) {
-  assert(convOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
-  auto maps = llvm::to_vector<8>(
-      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
-  SmallVector<Value> fIdx(makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
-  SmallVector<Value> imIdx(makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
-  SmallVector<Value> oIdx(makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
-
-  Value filter = convOp.filter(), output = convOp.output();
-
-  // Emit scalar form. Padded conv involves an affine.max in the memory access
-  // which is not allowed by affine.load. Override to use an MemRefIndexedValue
-  // when there is non-zero padding.
-  if (hasPadding(convOp)) {
-    Type type = convOp.input().getType().cast<MemRefType>().getElementType();
-    Value padValue =
-        b.create<ConstantOp>(loc, type, getPadValueAttr<ConvOp>(type));
-    Value paddedInput =
-        getPaddedInput(b, loc, convOp.input(), imIdx,
-                       /* Only need to pad the window dimensions */
-                       {0, static_cast<int>(imIdx.size()) - 1}, padValue);
-    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
-    Value mulVal = ArithBuilder(b, loc).mul(filterVal, paddedInput);
-    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
-    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
-    b.create<StoreOpTy>(loc, addVal, output, oIdx);
-  } else {
-    Value inputVal = b.create<LoadOpTy>(loc, convOp.input(), imIdx);
-    Value filterVal = b.create<LoadOpTy>(loc, filter, fIdx);
-    Value mulVal = ArithBuilder(b, loc).mul(filterVal, inputVal);
-    Value outputVal = b.create<LoadOpTy>(loc, output, oIdx);
-    Value addVal = ArithBuilder(b, loc).add(mulVal, outputVal);
-    b.create<StoreOpTy>(loc, addVal, output, oIdx);
-  }
-}
-
 /// Replace the index operations in the body of the loop nest by the matching
 /// induction variables.
 static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
@@ -328,11 +218,7 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
         assert(operandValuesToUse == linalgOp->getOperands() &&
                "expect operands are captured and not passed by loop argument");
         allIvs.append(ivs.begin(), ivs.end());
-        llvm::TypeSwitch<Operation *>(linalgOp)
-            .Case<ConvOp, LinalgOp>([&](auto op) {
-              emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, op);
-            })
-            .Default([&](Operation *op) { assert(false && "unexpected op"); });
+        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
         return scf::ValueVector{};
       });
  // Number of loop ops might be different from the number of ivs since some

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 2ec183fec8e29..1f96c59e18e6d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -304,12 +304,6 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
                 LinalgOpInstancePromotionOptions options, DataLayout &layout) {
   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
 
-  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
-    // TODO: add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      return {};
-  }
-
   // 1. Promote the specified views and use them in the new op.
   auto promotedBuffersAndViews = promoteSubViews(b, options, layout);
   if (!promotedBuffersAndViews ||

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index cb5c3b0ce47de..59874a99cf1e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -162,13 +162,6 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
   if (llvm::all_of(tileSizes, isZero))
     return llvm::None;
 
-  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
-    // For conv op only support tiling along batch dimension (which is the first
-    // loop).
-    if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), isZero))
-      return llvm::None;
-  }
-
   // 1. Build the tiled loop ranges.
   auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
   AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 76279ef2d3ca8..e6df2dbf9b1ac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1114,8 +1114,6 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
   return success();
 }
 
-using ConvOpConst = ConvOpVectorization<Conv1DOp, 1>;
-
 /// Inserts tiling, promotion and vectorization pattern for ConvOp
 /// conversion into corresponding pattern lists.
 template <typename ConvOp, unsigned N>

diff  --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
index 74dce8a006d13..aefa4a3a7da5a 100644
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -3,12 +3,6 @@
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // RUN: mlir-opt %s -convert-linalg-to-affine-loops -convert-linalg-to-llvm -o=/dev/null 2>&1
 
-// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
-
-// CHECK-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
-
-// CHECK-DAG: #[[$clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
-
 func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -20,85 +14,6 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   return
 }
 
-// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>,
-// CHECK-SAME: [[M:arg[0-9]+]]: index
-// CHECK-SAME: [[N:arg[0-9]+]]: index
-// CHECK-SAME: [[K:arg[0-9]+]]: index
-//       CHECK: %[[A:.*]] = memref.view %{{.*}} : memref<?xi8> to memref<?x?xf32>
-//       CHECK: %[[B:.*]] = memref.view %{{.*}} : memref<?xi8> to memref<?x?xf32>
-//       CHECK: %[[C:.*]] = memref.view %{{.*}} : memref<?xi8> to memref<?x?xf32>
-//       CHECK: affine.for
-//       CHECK:   affine.for
-//       CHECK:     affine.for
-//   CHECK-DAG:       %[[a:.*]] = affine.load %[[A]]{{.*}} : memref<?x?xf32>
-//   CHECK-DAG:       %[[b:.*]] = affine.load %[[B]]{{.*}} : memref<?x?xf32>
-//   CHECK-DAG:       %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
-//   CHECK-DAG:       %[[c:.*]] = affine.load %[[C]]{{.*}} : memref<?x?xf32>
-//   CHECK-DAG:       %[[res:.*]] = addf %[[c]], %[[inc]] : f32
-//       CHECK:       affine.store %[[res]], %[[C]]{{.*}} : memref<?x?xf32>
-
-func @conv_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.conv(%arg0, %arg1, %arg2) {strides = [2]}: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
-  return
-}
-
-// CHECK-LABEL: func @conv_view3(
-//  CHECK: %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>) {
-//       CHECK:   %[[Z0:.*]] = memref.dim %arg0, %c0 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   %[[Q:.*]] = memref.dim %arg0, %c1 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   %[[K:.*]] = memref.dim %arg0, %c2 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   %[[B:.*]] = memref.dim %arg1, %c0 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   %[[X0:.*]] = memref.dim %arg2, %c1 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   affine.for {{.*}}0 to %[[B]] {
-//       CHECK:     affine.for {{.*}}0 to %[[X0]] {
-//       CHECK:       affine.for {{.*}}0 to %[[K]] {
-//       CHECK:         affine.for {{.*}}0 to %[[Q]] {
-//       CHECK:           affine.for {{.*}}0 to %[[Z0]] {
-//       CHECK:            %[[SUM:.*]] = affine.apply #[[$stride2Dilation1]]{{.*}}
-//       No padding needed here; only affine loads.
-//       CHECK-NEXT:       affine.load
-//       CHECK-NEXT:       affine.load
-
-func @conv_padding(%arg0: memref<?x?x?x?xf32>,
-                   %arg1: memref<?x?x?x?xf32>,
-                   %arg2: memref<?x?x?x?xf32>) {
-  linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1],
-                                    padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
-                                    strides = [1, 1]} :
-    memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
-  return
-}
-// CHECK-LABEL: func @conv_padding
-//       CHECK: %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>) {
-//       CHECK:   %[[ZERO:.*]] = constant 0.000000e+00 : f32
-//       CHECK:   %[[Z0:.*]] = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
-//       CHECK:   %[[Z1:.*]] = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
-//       CHECK:   %[[Q:.*]] =  memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
-//       CHECK:   %[[K:.*]] =  memref.dim %arg0, %c3 : memref<?x?x?x?xf32>
-//       CHECK:   %[[B:.*]] =  memref.dim %arg1, %c0 : memref<?x?x?x?xf32>
-//       CHECK:   %[[X0:.*]] = memref.dim %arg2, %c1 : memref<?x?x?x?xf32>
-//       CHECK:   %[[X1:.*]] = memref.dim %arg2, %c2 : memref<?x?x?x?xf32>
-//       CHECK:   affine.for {{.*}}0 to %[[B]] {
-//       CHECK:     affine.for {{.*}}0 to %[[X0]] {
-//       CHECK:       affine.for {{.*}}0 to %[[X1]] {
-//       CHECK:         affine.for {{.*}}0 to %[[K]] {
-//       CHECK:           affine.for {{.*}}0 to %[[Q]] {
-//       CHECK:             affine.for {{.*}}0 to %[[Z0]] {
-//       CHECK:               affine.for {{.*}}0 to %[[Z1]] {
-//       CHECK:                 %[[SUM0:.*]] = affine.apply #{{.*}}
-//       CHECK:                 %[[SUM1:.*]] = affine.apply #{{.*}}
-//       CHECK:                 %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[SUM0]])
-//       CHECK:                 %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[SUM1]])
-// Padded conv involves an affine.max in the memory access and this is not
-// allowed by affine.load. Use memref.load in such cases.
-//       CHECK:                 memref.load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref<?x?x?x?xf32>
-//       CHECK:                 select {{.*}} : f32
-//       CHECK:                 affine.load
-//       CHECK:                 mulf {{.*}} : f32
-//       CHECK:                 affine.load
-//       CHECK:                 addf {{.*}} : f32
-//       CHECK:                 affine.store
-
 //----------------------------------------------------------------------------//
 // Named ops to loops.
 //----------------------------------------------------------------------------//

diff  --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index 84f228b712d1d..c7396ae32e0d3 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -441,14 +441,12 @@ module {
 // -----
 
 module {
-  func @basic_conv_fusion(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>,
-                          %arg2: memref<?x?x?x?xf32>) {
+  func @basic_conv_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                          %arg2: memref<?x?xf32>) {
     %cst = constant 0.000000e+00 : f32
-    linalg.fill(%cst, %arg2) : f32, memref<?x?x?x?xf32>
-    linalg.conv(%arg0, %arg1, %arg2) {
-      dilations = [1, 1], strides = [1, 1],
-      __internal_linalg_transform__ = "basic_fusion"} :
-      memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+    linalg.fill(%cst, %arg2) : f32, memref<?x?xf32>
+    linalg.conv_2d {__internal_linalg_transform__ = "basic_fusion"}
+      ins(%arg1, %arg0 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
     return
   }
 }
@@ -459,8 +457,8 @@ module {
 // CHECK-SAME:  {
 //      CHECK:    linalg.fill
 // CHECK-SAME:      __internal_linalg_transform__ = "after_basic_fusion_producer"
-//      CHECK:    linalg.conv
+//      CHECK:    linalg.conv_2d
 // CHECK-SAME:      __internal_linalg_transform__ = "after_basic_fusion"
 //      CHECK:  }
-//      CHECK:  linalg.conv
+//      CHECK:  linalg.conv_2d
 // CHECK-SAME:    __internal_linalg_transform__ = "after_basic_fusion_original"

diff  --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index a69db827889d3..dbc941cc0bfa2 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -672,44 +672,33 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
 
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
-#map1 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
-#map2 = affine_map<()[s0] -> (s0 + 3)>
 
-func @fill_and_conv(%arg0: memref<?x?x?x?xf32>, %arg1: memref<2x3x1x1xf32>, %arg2: memref<?x?x?x?xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  linalg.fill(%cst, %arg2) : f32, memref<?x?x?x?xf32>
+#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+#map1 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+#map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+#map3 = affine_map<(d0)[s0, s1] -> (s0 + 1, -d0 + s0 + s1)>
+#map4 = affine_map<(d0)[s0, s1] -> (s0 + 2, -d0 + s0 + s1)>
 
-  %c4 = constant 4 : index
-  %c1 = constant 1 : index
-  %c0 = constant 0 : index
+func @fill_and_conv(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  %cst = constant 0.000000e+00 : f32
   %c2 = constant 2 : index
   %c3 = constant 3 : index
-  %4 = memref.dim %arg1, %c0 : memref<2x3x1x1xf32>
-  %5 = memref.dim %arg1, %c1 : memref<2x3x1x1xf32>
-  %6 = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
-  %7 = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
-  %8 = memref.dim %arg0, %c3 : memref<?x?x?x?xf32>
-  %9 = memref.dim %arg2, %c0 : memref<?x?x?x?xf32>
-  %10 = memref.dim %arg2, %c1 : memref<?x?x?x?xf32>
-  %11 = memref.dim %arg2, %c2 : memref<?x?x?x?xf32>
-  %12 = memref.dim %arg2, %c3 : memref<?x?x?x?xf32>
-  %13 = linalg.range %c0 : %6 : %c2 : !linalg.range
-  %14 = linalg.range %c0 : %10 : %c3 : !linalg.range
-  scf.for %arg3 = %c0 to %6 step %c2 {
-    scf.for %arg4 = %c0 to %10 step %c3 {
-      %15 = affine.min #map0(%c2, %c1, %arg3)
-      %16 = affine.apply #map2()[%7]
-      %17 = affine.min #map0(%16, %c4, %arg4)
-      %18 = memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
-      %19 = memref.dim %arg0, %c3 : memref<?x?x?x?xf32>
-      %20 = memref.subview %arg0[%arg3, %arg4, %c0, %c0] [%15, %17, %18, %19] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map1>
-      %21 = affine.min #map0(%c2, %c1, %arg3)
-      %22 = affine.min #map0(%c3, %c4, %arg4)
-      %23 = memref.dim %arg2, %c2 : memref<?x?x?x?xf32>
-      %24 = memref.dim %arg2, %c3 : memref<?x?x?x?xf32>
-      %25 = memref.subview %arg2[%arg3, %arg4, %c0, %c0] [%21, %22, %23, %24] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map1>
-      linalg.conv(%arg1, %20, %25) {dilations = [1, 1], strides = [1, 1]} : memref<2x3x1x1xf32>, memref<?x?x?x?xf32, #map1>, memref<?x?x?x?xf32, #map1>
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  linalg.fill(%cst, %arg0) : f32, memref<?x?xf32>
+  %2 = memref.dim %arg1, %c0 : memref<?x?xf32>
+  %3 = memref.dim %arg1, %c1 : memref<?x?xf32>
+  %4 = memref.dim %arg2, %c0 : memref<?x?xf32>
+  %5 = memref.dim %arg2, %c1 : memref<?x?xf32>
+  scf.for %arg3 = %c0 to %4 step %c2 {
+    scf.for %arg4 = %c0 to %5 step %c3 {
+      %6 = affine.min #map3(%arg3)[%2, %4]
+      %7 = affine.min #map4(%arg4)[%3, %5]
+      %8 = memref.subview %arg0[%arg3, %arg4] [%6, %7] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
+      %9 = affine.min #map0(%arg3)[%4]
+      %10 = affine.min #map1(%arg4)[%5]
+      %11 = memref.subview %arg2[%arg3, %arg4] [%9, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
+      linalg.conv_2d ins(%8, %arg1 : memref<?x?xf32, #map2>, memref<?x?xf32>) outs(%11 : memref<?x?xf32, #map2>)
     }
   }
   return
@@ -718,7 +707,7 @@ func @fill_and_conv(%arg0: memref<?x?x?x?xf32>, %arg1: memref<2x3x1x1xf32>, %arg
 // CHECK: scf.for
 // CHECK:   scf.for
 // CHECK:     linalg.fill
-// CHECK:     linalg.conv
+// CHECK:     linalg.conv_2d
 
 // -----
 

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index abfeda6b738db..10926e7ac65be 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1,32 +1,5 @@
 // RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
 
-func @generalize_conv(%input : memref<1x449x562x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
-  linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x449x562x3xf32>, memref<1x112x112x32xf32>
-  return
-}
-
-// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d4, d3)>
-// CHECK:  #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 4 + d5 * 2, d2 * 5 + d6 * 3, d4)>
-// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-
-// CHECK: func @generalize_conv
-// CHECK-SAME:  %[[INPUT:.+]]: memref<1x449x562x3xf32>
-// CHECK-SAME: %[[FILTER:.+]]: memref<3x3x3x32xf32>
-// CHECK-SAME: %[[OUTPUT:.+]]: memref<1x112x112x32xf32>
-
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[FILTER_MAP]], #[[INPUT_MAP]], #[[OUTPUT_MAP]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "window", "window"]
-// CHECK-SAME:  ins(%[[FILTER]], %[[INPUT]]
-// CHECK-SAME: outs(%[[OUTPUT]]
-
-// CHECK: ^{{.*}}(%[[FILTER_ARG:.+]]: f32, %[[INPUT_ARG:.+]]: f32, %[[OUTPUT_ARG:.+]]: f32)
-// CHECK:   %[[MUL:.+]] = mulf %[[FILTER_ARG]], %[[INPUT_ARG]]
-// CHECK:   %[[ADD:.+]] = addf %[[MUL]], %[[OUTPUT_ARG]]
-// CHECK:   linalg.yield %[[ADD]]
-
-// -----
-
 func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) {
   linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>)
                outs(%C: memref<16x32xf32>)

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 4ba52527b27a3..fb74790eff05a 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -7,24 +7,12 @@
 // CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
-// CHECK-DAG: #[[$strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
-// CHECK-DAG: #[[$clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
-
 // CHECK-DAG: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0  + d1)>
-// CHECK-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
-// CHECK-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
-// CHECK-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
 
 // CHECKPARALLEL-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECKPARALLEL-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECKPARALLEL-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
-// CHECKPARALLEL-DAG: #[[$strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
-// CHECKPARALLEL-DAG: #[[$clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
-
 // CHECKPARALLEL-DAG: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0  + d1)>
-// CHECKPARALLEL-DAG: #[[$stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
-// CHECKPARALLEL-DAG: #[[$stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
-// CHECKPARALLEL-DAG: #[[$stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
 
 func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = constant 0 : index
@@ -265,163 +253,6 @@ func @copy_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1:
 //       CHECKPARALLEL:     %[[L:.*]] = memref.load {{.*}} : memref<?x?x?xf32, #[[$strided3D]]>
 //       CHECKPARALLEL:     store %[[L]], {{.*}} : memref<?x?x?xf32, #[[$strided3D]]>
 
-func @conv_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.conv(%arg0, %arg1, %arg2) {strides = [2]}: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
-  return
-}
-// CHECK-LABEL: func @conv_view3(
-//       CHECK: %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>) {
-//       CHECK:   %[[Z0:.*]] = memref.dim %arg0, %c0 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   %[[Q:.*]] = memref.dim %arg0, %c1 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   %[[K:.*]] = memref.dim %arg0, %c2 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   %[[B:.*]] = memref.dim %arg1, %c0 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   %[[X0:.*]] = memref.dim %arg2, %c1 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:   scf.for {{.*}} to %[[B]]
-//       CHECK:     scf.for {{.*}} to %[[X0]]
-//       CHECK:       scf.for {{.*}} to %[[K]]
-//       CHECK:         scf.for {{.*}} to %[[Q]]
-//       CHECK:           scf.for {{.*}} to %[[Z0]]
-//       CHECK:             %[[SUM:.*]] = affine.apply #[[$stride2Dilation1]]
-//       CHECK:             memref.load %{{.*}}[%{{.*}}, %[[SUM]], %{{.*}}] : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:             memref.load {{.*}} : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:             mulf
-//       CHECK:             memref.load {{.*}} : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECK:             addf
-//       CHECK:             store %{{.*}}, {{.*}} : memref<?x?x?xf32, #[[$strided3D]]>
-
-// CHECKPARALLEL-LABEL: func @conv_view3(
-//       CHECKPARALLEL: %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>) {
-//       CHECKPARALLEL:   %[[Z0:.*]] = memref.dim %arg0, %c0 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECKPARALLEL:   %[[Q:.*]] = memref.dim %arg0, %c1 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECKPARALLEL:   %[[K:.*]] = memref.dim %arg0, %c2 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECKPARALLEL:   %[[B:.*]] = memref.dim %arg1, %c0 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECKPARALLEL:   %[[X0:.*]] = memref.dim %arg2, %c1 : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECKPARALLEL:   scf.parallel (%{{.*}}, %{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}, %{{.*}}) to (%[[B]], %[[X0]], %[[K]]) step (%{{.*}}, %{{.*}}, %{{.*}}) {
-//       CHECKPARALLEL:     scf.for {{.*}} to %[[Q]]
-//       CHECKPARALLEL:       scf.for {{.*}} to %[[Z0]]
-//       CHECKPARALLEL:         %[[SUM:.*]] = affine.apply #[[$stride2Dilation1]]
-//       CHECKPARALLEL:         memref.load %{{.*}}[%{{.*}}, %[[SUM]], %{{.*}}] : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECKPARALLEL:         memref.load {{.*}} : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECKPARALLEL:         mulf
-//       CHECKPARALLEL:         memref.load {{.*}} : memref<?x?x?xf32, #[[$strided3D]]>
-//       CHECKPARALLEL:         addf
-//       CHECKPARALLEL:         store %{{.*}}, {{.*}} : memref<?x?x?xf32, #[[$strided3D]]>
-
-func @conv_view4(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg1: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg2: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>) {
-  linalg.conv(%arg0, %arg1, %arg2) {dilations = [4, 5], strides = [2, 3]} : memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>
-  return
-}
-// CHECK-LABEL: func @conv_view4(
-//       CHECK: %{{.*}}: memref<?x?x?x?xf32, #[[$strided4D]]>, %{{.*}}: memref<?x?x?x?xf32, #[[$strided4D]]>, %{{.*}}: memref<?x?x?x?xf32, #[[$strided4D]]>) {
-//       CHECK:   %[[Z0:.*]] = memref.dim %arg0, %c0 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:   %[[Z1:.*]] = memref.dim %arg0, %c1 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:   %[[Q:.*]] = memref.dim %arg0, %c2 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:   %[[K:.*]] = memref.dim %arg0, %c3 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:   %[[B:.*]] = memref.dim %arg1, %c0 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:   %[[X0:.*]] = memref.dim %arg2, %c1 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:   %[[X1:.*]] = memref.dim %arg2, %c2 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:   scf.for {{.*}} to %[[B]]
-//       CHECK:     scf.for {{.*}} to %[[X0]]
-//       CHECK:       scf.for {{.*}} to %[[X1]]
-//       CHECK:         scf.for {{.*}} to %[[K]]
-//       CHECK:           scf.for {{.*}} to %[[Q]]
-//       CHECK:             scf.for {{.*}} to %[[Z0]]
-//       CHECK:               scf.for {{.*}} to %[[Z1]]
-//       CHECK:                 %[[SUM0:.*]] = affine.apply #[[$stride2Dilation4]]
-//       CHECK:                 %[[SUM1:.*]] = affine.apply #[[$stride3Dilation5]]
-//       CHECK:                 memref.load %{{.*}}[%{{.*}}, %[[SUM0]], %[[SUM1]], %{{.*}}] : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:                 memref.load {{.*}} : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:                 mulf
-//       CHECK:                 memref.load {{.*}} : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECK:                 addf
-//       CHECK:                 store %{{.*}}, {{.*}} : memref<?x?x?x?xf32, #[[$strided4D]]>
-
-// CHECKPARALLEL-LABEL: func @conv_view4(
-//       CHECKPARALLEL: %{{.*}}: memref<?x?x?x?xf32, #[[$strided4D]]>, %{{.*}}: memref<?x?x?x?xf32, #[[$strided4D]]>, %{{.*}}: memref<?x?x?x?xf32, #[[$strided4D]]>) {
-//       CHECKPARALLEL:   %[[Z0:.*]] = memref.dim %arg0, %c0 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:   %[[Z1:.*]] = memref.dim %arg0, %c1 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:   %[[Q:.*]] = memref.dim %arg0, %c2 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:   %[[K:.*]] = memref.dim %arg0, %c3 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:   %[[B:.*]] = memref.dim %arg1, %c0 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:   %[[X0:.*]] = memref.dim %arg2, %c1 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:   %[[X1:.*]] = memref.dim %arg2, %c2 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:   scf.parallel (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[B]], %[[X0]], %[[X1]], %[[K]]) step (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {
-//       CHECKPARALLEL:     scf.for {{.*}} to %[[Q]]
-//       CHECKPARALLEL:       scf.for {{.*}} to %[[Z0]]
-//       CHECKPARALLEL:         scf.for {{.*}} to %[[Z1]]
-//       CHECKPARALLEL:           %[[SUM0:.*]] = affine.apply #[[$stride2Dilation4]]
-//       CHECKPARALLEL:           %[[SUM1:.*]] = affine.apply #[[$stride3Dilation5]]
-//       CHECKPARALLEL:           memref.load %{{.*}}[%{{.*}}, %[[SUM0]], %[[SUM1]], %{{.*}}] : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:           memref.load {{.*}} : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:           mulf
-//       CHECKPARALLEL:           memref.load {{.*}} : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       CHECKPARALLEL:           addf
-//       CHECKPARALLEL:           store %{{.*}}, {{.*}} : memref<?x?x?x?xf32, #[[$strided4D]]>
-
-func @conv_padding(%arg0: memref<?x?x?x?xf32>,
-                   %arg1: memref<?x?x?x?xf32>,
-                   %arg2: memref<?x?x?x?xf32>) {
-  linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1],
-                                    padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
-                                    strides = [1, 1]} :
-    memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
-  return
-}
-// CHECK-LABEL: func @conv_padding
-//       CHECK: %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>) {
-//       CHECK:   %[[ZERO:.*]] = constant 0.000000e+00 : f32
-//       CHECK:   %[[Z0:.*]] = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
-//       CHECK:   %[[Z1:.*]] = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
-//       CHECK:   %[[Q:.*]] =  memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
-//       CHECK:   %[[K:.*]] =  memref.dim %arg0, %c3 : memref<?x?x?x?xf32>
-//       CHECK:   %[[B:.*]] =  memref.dim %arg1, %c0 : memref<?x?x?x?xf32>
-//       CHECK:   %[[X0:.*]] = memref.dim %arg2, %c1 : memref<?x?x?x?xf32>
-//       CHECK:   %[[X1:.*]] = memref.dim %arg2, %c2 : memref<?x?x?x?xf32>
-//       CHECK:   scf.for {{.*}} to %[[B]]
-//       CHECK:     scf.for {{.*}} to %[[X0]]
-//       CHECK:       scf.for {{.*}} to %[[X1]]
-//       CHECK:         scf.for {{.*}} to %[[K]]
-//       CHECK:           scf.for {{.*}} to %[[Q]]
-//       CHECK:             scf.for {{.*}} to %[[Z0]]
-//       CHECK:               scf.for {{.*}} to %[[Z1]]
-//       CHECK:                 %[[SUM0:.*]] = affine.apply #{{.*}}
-//       CHECK:                 %[[SUM1:.*]] = affine.apply #{{.*}}
-//       CHECK:                 %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[SUM0]])
-//       CHECK:                 %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[SUM1]])
-//       CHECK:                 memref.load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref<?x?x?x?xf32>
-//       CHECK:                 select %{{.*}},
-//       CHECK:                 memref.load {{.*}} : memref<?x?x?x?xf32>
-//       CHECK:                 mulf
-//       CHECK:                 memref.load {{.*}} : memref<?x?x?x?xf32>
-//       CHECK:                 addf
-//       CHECK:                 store %{{.*}}, {{.*}} : memref<?x?x?x?xf32>
-
-// CHECKPARALLEL-LABEL: func @conv_padding
-//       CHECKPARALLEL: %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>) {
-//       CHECKPARALLEL:   %[[ZERO:.*]] = constant 0.000000e+00 : f32
-//       CHECKPARALLEL:   %[[Z0:.*]] = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:   %[[Z1:.*]] = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:   %[[Q:.*]] =  memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:   %[[K:.*]] =  memref.dim %arg0, %c3 : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:   %[[B:.*]] =  memref.dim %arg1, %c0 : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:   %[[X0:.*]] = memref.dim %arg2, %c1 : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:   %[[X1:.*]] = memref.dim %arg2, %c2 : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:   scf.parallel (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) to (%[[B]], %[[X0]], %[[X1]], %[[K]]) step (%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {
-//       CHECKPARALLEL:     scf.for {{.*}} to %[[Q]]
-//       CHECKPARALLEL:       scf.for {{.*}} to %[[Z0]]
-//       CHECKPARALLEL:         scf.for {{.*}} to %[[Z1]]
-//       CHECKPARALLEL:           %[[SUM0:.*]] = affine.apply #{{.*}}
-//       CHECKPARALLEL:           %[[SUM1:.*]] = affine.apply #{{.*}}
-//       CHECKPARALLEL:           %[[IDX:.*]] = affine.max #[[$clampMinMap]](%[[SUM0]])
-//       CHECKPARALLEL:           %[[IDY:.*]] = affine.max #[[$clampMinMap]](%[[SUM1]])
-//       CHECKPARALLEL:           memref.load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:           select %{{.*}},
-//       CHECKPARALLEL:           memref.load {{.*}} : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:           mulf
-//       CHECKPARALLEL:           memref.load {{.*}} : memref<?x?x?x?xf32>
-//       CHECKPARALLEL:           addf
-//       CHECKPARALLEL:           store %{{.*}}, {{.*}} : memref<?x?x?x?xf32>
-
 #accesses = [
   affine_map<(i, j, k) -> (i, j)>,
   affine_map<(i, j, k) -> (i, j, k)>,

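Note: the updated tests further down in this patch exercise linalg.conv_2d instead. As a
minimal, self-contained sketch of that form (illustrative only, not part of the patch; the
op is used here on plain 2-D buffers with no stride/dilation/padding attributes):

func @conv_2d_sketch(%input: memref<?x?xf32>, %filter: memref<?x?xf32>,
                     %output: memref<?x?xf32>) {
  // Unpadded, unit-stride 2-D convolution: ins takes (input, filter),
  // outs takes the output buffer that is accumulated into.
  linalg.conv_2d ins(%input, %filter : memref<?x?xf32>, memref<?x?xf32>)
                 outs(%output : memref<?x?xf32>)
  return
}
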
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index ad6935dc7b8ba..3a7be8b5a4c54 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -14,7 +14,6 @@
 // CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 // CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
-// CHECK-DAG: #[[$strided6D:.*]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)>
 
 func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
                   %pad_value: f32) -> tensor<6x?x?x?xf32> {
@@ -211,64 +210,6 @@ func @copy_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
 
 // -----
 
-
-func @conv_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-                 %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-                 %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-  linalg.conv(%arg0, %arg1, %arg2) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-                                     memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
-                                     memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
-  return
-}
-// CHECK-LABEL: func @conv_view3(
-//       CHECK:   linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) :
-//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]>,
-//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]>,
-//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]>
-
-// -----
-
-
-func @conv_view6(%arg0: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>,
-                 %arg1: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>,
-                 %arg2: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>) {
-  linalg.conv(%arg0, %arg1, %arg2) {dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} :
-    memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>,
-    memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>,
-    memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>
-  return
-}
-// CHECK-LABEL: func @conv_view6(
-//       CHECK:   linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) {
-//  CHECK-SAME:     dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} :
-//  CHECK-SAME:     memref<?x?x?x?x?x?xf32, #[[$strided6D]]>,
-//  CHECK-SAME:     memref<?x?x?x?x?x?xf32, #[[$strided6D]]>,
-//  CHECK-SAME:     memref<?x?x?x?x?x?xf32, #[[$strided6D]]>
-
-// -----
-
-func @conv_padding(%arg0: memref<?x?x?x?xf32>,
-                   %arg1: memref<?x?x?x?xf32>,
-                   %arg2: memref<?x?x?x?xf32>) {
-  linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1],
-                                    padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
-                                    strides = [1, 1]} :
-    memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
-  return
-}
-
-// CHECK-LABEL: func @conv_padding(
-//       CHECK:   linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) {
-//  CHECK-SAME:     dilations = [1, 1],
-//  CHECK-SAME:     padding = dense<[
-//  CHECK-SAME:                      [0, 1], [1, 1]]> : tensor<2x2xi64>,
-//  CHECK-SAME:     strides = [1, 1]} :
-//  CHECK-SAME:     memref<?x?x?x?xf32>,
-//  CHECK-SAME:     memref<?x?x?x?xf32>,
-//  CHECK-SAME:     memref<?x?x?x?xf32>
-
-// -----
-
 #accesses_0 = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> ()>,

diff --git a/mlir/test/Dialect/Linalg/tile-conv-padding.mlir b/mlir/test/Dialect/Linalg/tile-conv-padding.mlir
deleted file mode 100644
index d2bbc2d20e7e1..0000000000000
--- a/mlir/test/Dialect/Linalg/tile-conv-padding.mlir
+++ /dev/null
@@ -1,36 +0,0 @@
-// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,3,0,0,4" | FileCheck %s -check-prefix=TILE-23004
-// RUN: mlir-opt %s -linalg-tile="tile-sizes=2" | FileCheck %s -check-prefix=TILE-20000
-
-// TILE-23004-DAG: #[[$strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
-// TILE-20000-DAG: #[[$strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
-// TILE-20000-DAG: #[[$minmap:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-
-func @conv_padding(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg1: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg2: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>) {
-  linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [30, 40]} : memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>
-  return
-}
-// TILE-23004-LABEL: func @conv_padding(
-//  TILE-23004-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[$strided4D]]>
-//  TILE-23004-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[$strided4D]]>
-//  TILE-23004-SAME:   %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[$strided4D]]>)
-//       TILE-23004:         linalg.conv(%[[ARG0]], %[[ARG1]], %[[ARG2]])
-
-// TILE-20000-LABEL: func @conv_padding(
-//  TILE-20000-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[$strided4D]]>
-//  TILE-20000-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[$strided4D]]>
-//  TILE-20000-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[$strided4D]]>)
-//   TILE-20000-DAG:   %[[C0:.*]] = constant 0 : index
-//   TILE-20000-DAG:   %[[C2:.*]] = constant 2 : index
-//       TILE-20000:   %[[B:.*]] = memref.dim %[[ARG1]], %c0
-//       TILE-20000:   scf.for %[[ivI:.*]] = %[[C0]] to %[[B]] step %[[C2]] {
-//       TILE-20000:     %[[EXTENT:.*]] = affine.min #[[$minmap]](%[[ivI]])[%[[B]]]
-//       TILE-20000:     %[[DIM11:.*]] = memref.dim %[[ARG1]], %c1
-//       TILE-20000:     %[[DIM12:.*]] = memref.dim %[[ARG1]], %c2
-//       TILE-20000:     %[[DIM13:.*]] = memref.dim %[[ARG1]], %c3
-//       TILE-20000:     %[[SUBVIEW1:.*]] = memref.subview %[[ARG1]][%[[ivI]], 0, 0, 0] [%[[EXTENT]], %[[DIM11]], %[[DIM12]], %[[DIM13]]]
-//       TILE-20000:     %[[EXTENT:.*]] = affine.min #[[$minmap]](%[[ivI]])[%[[B]]]
-//       TILE-20000:     %[[DIM21:.*]] = memref.dim %[[ARG2]], %c1
-//       TILE-20000:     %[[DIM22:.*]] = memref.dim %[[ARG2]], %c2
-//       TILE-20000:     %[[DIM23:.*]] = memref.dim %[[ARG2]], %c3
-//       TILE-20000:     %[[SUBVIEW2:.*]] = memref.subview %[[ARG2]][%[[ivI]], 0, 0, 0] [%[[EXTENT]], %[[DIM21]], %[[DIM22]], %[[DIM23]]]
-//       TILE-20000:     linalg.conv(%[[ARG0]], %[[SUBVIEW1]], %[[SUBVIEW2]])

diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index b4ee26a15ba5c..21fef9648bdc3 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -1,45 +1,35 @@
-// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,3,0,0,4" | FileCheck %s -check-prefix=TILE-23004
+// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,3" | FileCheck %s
 
-// TILE-23004-DAG: #[[$D0x30pS0x10:.*]] = affine_map<(d0) -> (d0 * 30)>
-// TILE-23004-DAG: #[[$S0x10p90D0x30pS1:.*]] = affine_map<(d0)[s0, s1] -> (s0 * 10 + 51, d0 * -30 + s0 * 10 + s1 * 30 - 39)>
-// TILE-23004-DAG: #[[$strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
-// TILE-23004-DAG: #[[$bound_map_2:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-// TILE-23004-DAG: #[[$bound_map_3:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-// TILE-23004-DAG: #[[$bound_map_4:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+//  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0, s1] -> (s0 + 1, -d0 + s0 + s1 - 1)>
+//  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0, s1] -> (s0 + 2, -d0 + s0 + s1 - 1)>
+//  CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
 
-func @conv(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg1: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg2: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>) {
-  linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], strides = [30, 40]} : memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>
+func @conv(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref<?x?xf32>) {
+  linalg.conv_2d ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
   return
 }
-//       TILE-23004: func @conv(
-//  TILE-23004-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[$strided4D]]>
-//  TILE-23004-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[$strided4D]]>
-//  TILE-23004-SAME:   %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[$strided4D]]>)
-//   TILE-23004-DAG:   %[[C0:.*]] = constant 0 : index
-//   TILE-23004-DAG:   %[[C2:.*]] = constant 2 : index
-//   TILE-23004-DAG:   %[[C3:.*]] = constant 3 : index
-//   TILE-23004-DAG:   %[[C4:.*]] = constant 4 : index
-//       TILE-23004:   %[[Z0:.*]] = memref.dim %[[ARG0]], %c0 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:   %[[Q:.*]] = memref.dim %[[ARG0]], %c2 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:   %[[B:.*]] = memref.dim %[[ARG1]], %c0 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:   %[[X0:.*]] = memref.dim %[[ARG2]], %c1 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:   scf.for %[[ivI:.*]] = %{{.*}} to %[[B]] step %{{.*}} {
-//       TILE-23004:     scf.for %[[ivJ:.*]] = %{{.*}} to %[[X0]] step %{{.*}} {
-//       TILE-23004:       scf.for %[[ivK:.*]] = %{{.*}} to %[[Q]] step %{{.*}} {
-//       TILE-23004:         %[[Z0_1:.*]] = memref.dim %[[ARG0]], %c0 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:         %[[Z1:.*]] = memref.dim %[[ARG0]], %c1 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:         %[[szK:.*]] = affine.min #[[$bound_map_4]](%[[ivK]])[%[[Q]]]
-//       TILE-23004:         %[[K:.*]] = memref.dim %[[ARG0]], %c3 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:         %[[FilterView:.*]] = memref.subview %{{.*}}[0, 0, %[[ivK]], 0] [%[[Z0_1]], %[[Z1]], %[[szK]], %[[K]]] [1, 1, 1, 1] : memref<?x?x?x?xf32, #[[$strided4D]]> to memref<?x?x?x?xf32, #[[$strided4D]]>
-//
-//       TILE-23004:         %[[J1:.*]] = affine.apply #[[$D0x30pS0x10]](%[[ivJ]])
-//       TILE-23004:         %[[I1pStep:.*]] = affine.min #[[$S0x10p90D0x30pS1]](%[[ivJ]])[%[[Z0]], %[[X0]]]
-//       TILE-23004:         %[[SZ2:.*]] = memref.dim %[[ARG1]], %c2 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:         %[[sz3:.*]] = affine.min #[[$bound_map_4]](%[[ivK]])[%[[Q]]]
-//       TILE-23004:         %[[InputView:.*]] = memref.subview %{{.*}}[%[[ivI]], %[[J1]], 0, %[[ivK]]] [%{{.*}}, %{{.*}}, %[[SZ2]], %[[sz3]]] [1, 1, 1, 1] : memref<?x?x?x?xf32, #[[$strided4D]]> to memref<?x?x?x?xf32, #[[$strided4D]]>
-//
-//       TILE-23004:         %[[X0:.*]] = memref.dim %[[ARG2]], %c2 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:         %[[X1:.*]] = memref.dim %[[ARG2]], %c3 : memref<?x?x?x?xf32, #[[$strided4D]]>
-//       TILE-23004:         %[[OutputView:.*]] = memref.subview %{{.*}}[%[[ivI]], %[[ivJ]], 0, 0] [%{{.*}}, %{{.*}}, %[[X0]], %[[X1]]] [1, 1, 1, 1] : memref<?x?x?x?xf32, #[[$strided4D]]> to memref<?x?x?x?xf32, #[[$strided4D]]>
-//
-//       TILE-23004:         linalg.conv(%[[FilterView]], %[[InputView]], %[[OutputView]]) {dilations = [10, 20], strides = [30, 40]} : memref<?x?x?x?xf32, #[[$strided4D]]>, memref<?x?x?x?xf32, #[[$strided4D]]>, memref<?x?x?x?xf32, #[[$strided4D]]>
+
+//       CHECK: func @conv
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = constant 2 : index
+//   CHECK-DAG:   %[[C3:.*]] = constant 3 : index
+//   CHECK-DAG:   %[[T0:.*]] = memref.dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[T1:.*]] = memref.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[T2:.*]] = memref.dim %[[ARG2]], %[[C0]]
+//   CHECK-DAG:   %[[T3:.*]] = memref.dim %[[ARG2]], %[[C1]]
+//       CHECK:   scf.for %[[ARG3:.*]] = %[[C0]] to %[[T2]] step %[[C2]]
+//       CHECK:     scf.for %[[ARG4:.*]] = %[[C0]] to %[[T3]] step %[[C3]]
+//       CHECK:       %[[T4:.*]] = affine.min #[[MAP0]](%[[ARG3]])[%[[T0]], %[[T2]]]
+//       CHECK:       %[[T5:.*]] = affine.min #[[MAP1]](%[[ARG4]])[%[[T1]], %[[T3]]]
+//       CHECK:       %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]]] [%[[T4]], %[[T5]]]
+//       CHECK:       %[[T6:.*]] = affine.min #[[MAP2]](%[[ARG3]])[%[[T2]]
+//       CHECK:       %[[T7:.*]] = affine.min #[[MAP3]](%[[ARG4]])[%[[T3]]]
+//       CHECK:       %[[SV2:.*]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]] [%[[T6]], %[[T7]]]
+//       CHECK:       linalg.conv_2d
+//  CHECK-SAME:         ins(%[[SV1]], %[[ARG1]]
+//  CHECK-SAME:         outs(%[[SV2]]

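As used in tile-conv.mlir above, linalg.conv_2d carries no padding or stride attributes, so
statically shaped uses follow the usual unpadded, unit-stride size relation (output size =
input size - filter size + 1 per spatial dimension). A small illustrative example with
arbitrarily chosen sizes, not taken from this patch:

func @conv_2d_static(%input: memref<6x6xf32>, %filter: memref<3x3xf32>,
                     %output: memref<4x4xf32>) {
  // 6 - 3 + 1 = 4: a 6x6 input convolved with a 3x3 filter fills a 4x4 output.
  linalg.conv_2d ins(%input, %filter : memref<6x6xf32>, memref<3x3xf32>)
                 outs(%output : memref<4x4xf32>)
  return
}
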
diff --git a/mlir/test/Dialect/Linalg/tile-simple-conv.mlir b/mlir/test/Dialect/Linalg/tile-simple-conv.mlir
deleted file mode 100644
index 9d3e0a5cd745c..0000000000000
--- a/mlir/test/Dialect/Linalg/tile-simple-conv.mlir
+++ /dev/null
@@ -1,43 +0,0 @@
-// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,3,4" | FileCheck %s
-
-//  CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-//  CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0, s1] -> (s0 + 2, -d0 + s0 + s1 - 1)>
-//  CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0, s1] -> (s0 + 3, -d0 + s0 + s1 - 1)>
-//  CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-//  CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-
-func @conv(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>, %arg2 : memref<?x?x?x?xf32>) {
-  linalg.conv(%arg0, %arg1, %arg2) : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
-  return
-}
-
-//       CHECK: func @conv
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32>
-//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32>
-//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
-//   CHECK-DAG:   %[[C1:.*]] = constant 1 : index
-//   CHECK-DAG:   %[[C2:.*]] = constant 2 : index
-//   CHECK-DAG:   %[[C3:.*]] = constant 3 : index
-//   CHECK-DAG:   %[[C4:.*]] = constant 4 : index
-//       CHECK:   %[[T0:.*]] = memref.dim %[[ARG0]], %[[C0]]
-//       CHECK:   %[[T1:.*]] = memref.dim %[[ARG0]], %[[C1]]
-//       CHECK:   %[[T2:.*]] = memref.dim %[[ARG1]], %[[C0]]
-//       CHECK:   %[[T3:.*]] = memref.dim %[[ARG2]], %[[C1]]
-//       CHECK:   %[[T4:.*]] = memref.dim %[[ARG2]], %[[C2]]
-//       CHECK:   scf.for %[[ARG3:.*]] = %[[C0]] to %[[T2]] step %[[C2]]
-//       CHECK:     scf.for %[[ARG4:.*]] = %[[C0]] to %[[T3]] step %[[C3]]
-//       CHECK:       scf.for %[[ARG5:.*]] = %[[C0]] to %[[T4]] step %[[C4]]
-//       CHECK:         %[[T6:.*]] = affine.min #[[MAP0]](%[[ARG3]])[%[[T2]]]
-//       CHECK:         %[[T8:.*]] = affine.min #[[MAP1]](%[[ARG4]])[%[[T0]], %[[T3]]]
-//       CHECK:         %[[T10:.*]] = affine.min #[[MAP2]](%[[ARG5]])[%[[T1]], %[[T4]]]
-//       CHECK:         %[[T11:.*]] = memref.dim %[[ARG1]], %[[C3]]
-//       CHECK:         %[[SV1:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]], 0]
-//  CHECK-SAME:                                        [%[[T6]], %[[T8]], %[[T10]], %[[T11]]]
-//       CHECK:         %[[T14:.*]] = affine.min #[[MAP0]](%[[ARG3]])[%[[T2]]
-//       CHECK:         %[[T16:.*]] = affine.min #[[MAP4]](%[[ARG4]])[%[[T3]]]
-//       CHECK:         %[[T18:.*]] = affine.min #[[MAP5]](%[[ARG5]])[%[[T4]]
-//       CHECK:         %[[T19:.*]] = memref.dim %[[ARG2]], %[[C3]]
-//       CHECK:         %[[SV2:.*]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]], %[[ARG5]], 0]
-//  CHECK-SAME:                                        [%[[T14]], %[[T16]], %[[T18]], %[[T19]]]
-//       CHECK:         linalg.conv(%[[ARG0]], %[[SV1]], %[[SV2]])

diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 7b1ef87ab702d..069a17384f1f7 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -26,7 +26,7 @@ static void fillFusionPatterns(MLIRContext *context,
                                const LinalgDependenceGraph &dependenceGraph,
                                RewritePatternSet &patterns) {
   patterns.add<LinalgTileAndFusePattern<MatmulOp>,
-               LinalgTileAndFusePattern<ConvOp>>(
+               LinalgTileAndFusePattern<Conv2DOp>>(
       context, dependenceGraph,
       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
       LinalgFusionOptions().setIndicesToFuse({2}),


        

