[Mlir-commits] [mlir] 6bb7d24 - [mlir][Linalg] Add a first vectorization pattern for conv1d in NWCxWCF format.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Oct 20 07:03:02 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-20T13:54:18Z
New Revision: 6bb7d2474fe4d3a68e2d1efefaa0bc8a244737bb

URL: https://github.com/llvm/llvm-project/commit/6bb7d2474fe4d3a68e2d1efefaa0bc8a244737bb
DIFF: https://github.com/llvm/llvm-project/commit/6bb7d2474fe4d3a68e2d1efefaa0bc8a244737bb.diff

LOG: [mlir][Linalg] Add a first vectorization pattern for conv1d in NWCxWCF format.

This revision uses the newly refactored StructuredGenerator to create a simple vectorization for conv1d_nwc_wcf.

Note that the pattern is not specific to the op and is technically not even specific to the ConvolutionOpInterface (modulo minor details related to dilations and strides).

The overall design follows the same ideas as the lowering of vector::ContractionOp -> vector::OuterProduct: it seeks to be minimally complex, composable and extensible while avoiding inference analysis. Instead, we metaprogram the maps/indexings we expect and we match against them.

This is just a first stab and still needs to be evaluated for performance.
Other tradeoffs are possible that should be explored.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D111894

Added: 
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 73c36be72404..f9b82b371904 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -41,11 +41,15 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
 //===----------------------------------------------------------------------===//
 using LinalgLoops = SmallVector<Operation *, 4>;
 
-/// Populates patterns for vectorization of all ConvN-D ops.
+/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops.
 void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
+/// Populates patterns that vectorize ops implementing ConvolutionOpInterface.
+void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 /// Populate patterns that convert `ElementwiseMappable` ops to linalg
 /// parallel loops.
 void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index e318bd5cf3e1..bdd5909faebd 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -25,7 +25,7 @@
 
 namespace mlir {
 
-class PatternRewriter;
+class OpBuilder;
 
 /// Tests whether the given maps describe a row major matmul. The test is
 /// permutation-invariant. Note that this only checks the affine maps from an
@@ -161,8 +161,8 @@ class StructuredGenerator {
     Win() : IteratorType(getWindowIteratorTypeName()) {}
   };
 
-  StructuredGenerator(PatternRewriter &rewriter, StructuredOpInterface op)
-      : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
+  StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
+      : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
         iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
 
   bool iters(ArrayRef<IteratorType> its) {
@@ -181,7 +181,7 @@ class StructuredGenerator {
   }
 
 protected:
-  PatternRewriter &rewriter;
+  OpBuilder &builder;
   MLIRContext *ctx;
   Location loc;
   ArrayAttr iterators;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f2641e20cdf0..1b05cd6c378d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -44,6 +45,12 @@ using llvm::dbgs;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
+// Forward declarations: defined below, used by vectorizeLinalgOp.
+static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
+                                          SmallVectorImpl<Value> &newResults);
+static FailureOr<Operation *>
+vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
+
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
 template <typename OpType>
@@ -147,7 +154,7 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
   unsigned outputPos =
       outputOperand->getOperandNumber() - linalgOp.getNumInputs();
-  // Only single combiner operatios are supported for now.
+  // Only single combiner operations are supported for now.
   SmallVector<Operation *, 4> combinerOps;
   if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
       combinerOps.size() != 1)
@@ -575,6 +582,11 @@ LogicalResult vectorizeAsLinalgGeneric(
   return success();
 }
 
+/// Helper function to vectorize a `linalgOp` with contraction semantics in a
+/// generic fashion.
+/// This helper is needed atm because the truly generic implementation requires
+/// good vector.multi_reduce folding patterns that are currently NYI.
+// TODO: drop reliance on a specific pattern.
 static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
                                           SmallVectorImpl<Value> &newResults) {
   assert(isaContractionOpInterface(linalgOp) &&
@@ -664,6 +676,11 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
     return success();
   if (isaContractionOpInterface(linalgOp))
     return success();
+  // TODO: isaConvolutionOpInterface that can also infer from generic features.
+  // But we will still need stride/dilation attributes that will be annoying to
+  // reverse-engineer...
+  if (isa<ConvolutionOpInterface>(op))
+    return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
@@ -688,6 +705,18 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
   if (isaContractionOpInterface(linalgOp))
     return vectorizeContraction(b, linalgOp, newResults);
 
+  // TODO: isaConvolutionOpInterface that can also infer from generic features.
+  // But we will still need stride/dilation attributes that will be annoying to
+  // reverse-engineer...
+  if (auto convOp = dyn_cast<ConvolutionOpInterface>(op)) {
+    FailureOr<Operation *> resultOrFail = vectorizeConvolution(b, convOp);
+    if (failed(resultOrFail))
+      return failure();
+    Operation *newOp = *resultOrFail;
+    llvm::append_range(newResults, newOp->getResults());
+    return success();
+  }
+
   LDBG(""
        << "Vectorize linalg op as a generic by broadcasting to "
           "maximal common shape: "
@@ -1421,3 +1450,225 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(

   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// Convolution vectorization patterns
+//===----------------------------------------------------------------------===//
+namespace {
+/// Generate a vector implementation for:
+/// ```
+///   Op def: (     n,     w,     f,    kw,    c  )
+///    Iters: ({Par(), Par(), Par(), Red(), Red()})
+///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
+/// ```
+/// w and kw are unrolled. Only static shapes with no zero-sized dimension are
+/// supported, since every unrolled slice becomes a fixed-size vector.
+/// TODO: do not unroll w (resp. kw) when strideW (resp. dilationW) is > 1.
+struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
+  Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
+                           int dilationW)
+      : StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
+        strideW(strideW), dilationW(dilationW) {
+    // Determine whether `linalgOp` can be generated with this generator.
+    // All rejections happen here, before any IR is created.
+    if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
+      return;
+    lhsShaped = linalgOp.inputs()[0];
+    rhsShaped = linalgOp.inputs()[1];
+    resShaped = linalgOp.outputs()[0];
+    lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
+    rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
+    resShapedType = resShaped.getType().dyn_cast<ShapedType>();
+    if (!lhsShapedType || !rhsShapedType || !resShapedType)
+      return;
+    if (lhsShapedType.getRank() != 3 || rhsShapedType.getRank() != 3 ||
+        resShapedType.getRank() != 3)
+      return;
+    // Slices are materialized as fixed-size vectors: VectorType requires
+    // static, strictly positive sizes, and an empty w/kw loop would not
+    // produce any vector.transfer_write to return.
+    if (!hasStaticNonEmptyShape(lhsShapedType) ||
+        !hasStaticNonEmptyShape(rhsShapedType) ||
+        !hasStaticNonEmptyShape(resShapedType))
+      return;
+
+    // Check for reduction `add` preceded by `mul`.
+    Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0));
+    if (!reduceOp)
+      return;
+    llvm::Optional<vector::CombiningKind> maybeKind;
+    maybeKind = getKindForOp(reduceOp);
+    if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
+      return;
+    // The first op of the body must be the multiplication.
+    maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front()));
+    if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
+      return;
+
+    // The op is now known to be valid.
+    valid = true;
+  }
+
+  /// Generate a vector implementation for:
+  /// ```
+  ///   Op def: (     n,     w,     f,    kw,    c  )
+  ///    Iters: ({Par(), Par(), Par(), Red(), Red()})
+  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
+  /// ```
+  /// w and kw are unrolled: each (kw, w) pair reads {n, 1, c} and {1, c, f}
+  /// slices, contracts them into the {n, 1, f} result slice and writes it back.
+  /// Returns the last vector.transfer_write, or failure if the op was rejected.
+  /// TODO: do not unroll w (resp. kw) when strideW (resp. dilationW) is > 1.
+  FailureOr<Operation *> conv() {
+    if (!valid)
+      return failure();
+
+    int64_t nSize = lhsShapedType.getShape()[0];
+    int64_t wSize = resShapedType.getShape()[1];
+    int64_t cSize = lhsShapedType.getShape()[2];
+    int64_t kwSize = rhsShapedType.getShape()[0];
+    int64_t fSize = rhsShapedType.getShape()[2];
+
+    // Slice types and the contraction iterators/dims are the same for every
+    // unrolled iteration: build them once.
+    VectorType lhsType =
+        VectorType::get({nSize, 1, cSize}, lhsShapedType.getElementType());
+    VectorType rhsType =
+        VectorType::get({1, cSize, fSize}, rhsShapedType.getElementType());
+    VectorType resType =
+        VectorType::get({nSize, 1, fSize}, resShapedType.getElementType());
+    StringRef par = Par().strRef, red = Red().strRef;
+    AffineExpr n, one, f, c;
+    bindDims(ctx, n, one, f, c);
+
+    vector::TransferWriteOp write;
+    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+
+    // Unroll along kw and read slices of lhs and rhs.
+    // Alternatively we could preload both 3-d slices and extract smaller slices
+    // iteratively without touching memory. But this will quickly spill.
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      // Read rhs slice of size {1, c, f} @ [kw, 0, 0].
+      Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw);
+      Value rhs = builder.create<vector::TransferReadOp>(
+          loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero});
+
+      for (int64_t w = 0; w < wSize; ++w) {
+        // Read lhs slice of size {n, 1, c} @ [0, sw * w + dw * kw, 0], with
+        // sw = strideW and dw = dilationW.
+        Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>(
+            loc, strideW * w + dilationW * kw);
+        Value lhs = builder.create<vector::TransferReadOp>(
+            loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero});
+
+        // Read res slice: {n, 1, f} @ [0, w, 0].
+        Value wVal = builder.create<arith::ConstantIndexOp>(loc, w);
+        // When operating on tensors, reading from the updated value is required
+        // for vector.transfer_read/write hoisting to function as expected.
+        Value res = builder.create<vector::TransferReadOp>(
+            loc, resType, resShaped, ValueRange{zero, wVal, zero});
+
+        // Compute contraction: I{n, 1, c} * F{1, c, f} -> O{n, 1, f}
+        // clang-format off
+        res = builder.create<vector::ContractionOp>(
+          loc, lhs, rhs, res,
+          /*indexingMaps=*/MapList{{n, one, c}, {one, c, f}, {n, one, f}},
+          /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
+        // clang-format on
+
+        // Write back res slice: {n, 1, f} @ [0, w, 0].
+        write = builder.create<vector::TransferWriteOp>(
+            loc, res, resShaped, ValueRange{zero, wVal, zero});
+        // On tensors the write yields the updated tensor; chain it so that the
+        // next read of this slice observes the update.
+        if (write.getNumResults() == 1)
+          resShaped = write->getResult(0);
+      }
+    }
+
+    // The constructor rejects zero-sized dimensions, so both loops ran at
+    // least once.
+    assert(write && "expected at least one vector.transfer_write");
+    return write.getOperation();
+  }
+
+  /// Entry point: matches the iterator types and the canonical layout
+  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
+  /// and generates the vector form. Layouts requiring a transposition are not
+  /// supported yet and fail.
+  FailureOr<Operation *> generateConv() {
+    AffineExpr n, w, f, kw, c;
+    bindDims(ctx, n, w, f, kw, c);
+
+    if (!iters({Par(), Par(), Par(), Red(), Red()}))
+      return failure();
+
+    // No transposition needed.
+    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
+                /*rhsIndex*/ {kw, c, f},
+                /*resIndex*/ {n, w, f}}))
+      return conv();
+    return failure();
+  }
+
+private:
+  /// Returns true if every dimension of `t` is static and none of them is 0.
+  static bool hasStaticNonEmptyShape(ShapedType t) {
+    return t.hasStaticShape() && t.getNumElements() != 0;
+  }
+
+  bool valid;
+  int strideW, dilationW;
+  Value lhsShaped, rhsShaped, resShaped;
+  ShapedType lhsShapedType, rhsShapedType, resShapedType;
+};
+} // namespace
+
+/// Helper function to vectorize a `linalgOp` with convolution semantics.
+/// Stride and dilation come from the first element of the optional
+/// "strides"/"dilations" attributes and default to 1 when absent.
+// TODO: extend the generic vectorization to support windows and drop this.
+static FailureOr<Operation *>
+vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
+  // TODO: these are legitimately part of ConvolutionOpInterface.
+  auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
+  auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
+  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+  LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
+  Conv1D_NWC_WCF_Generator e(b, linalgOp, stride, dilation);
+  return e.generateConv();
+}
+
+namespace {
+/// Rewrites an op implementing ConvolutionOpInterface into vector ops using
+/// `vectorizeConvolution`. Unsupported ops are left untouched.
+struct VectorizeConvolution
+    : public OpInterfaceRewritePattern<ConvolutionOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Operation *> resultOrFail =
+        vectorizeConvolution(rewriter, convOp);
+    if (failed(resultOrFail))
+      return failure();
+    Operation *newOp = *resultOrFail;
+    // On memrefs the final vector.transfer_write has no result: the writes
+    // already updated the output buffer, so the conv op is simply erased.
+    if (newOp->getNumResults() == 0) {
+      rewriter.eraseOp(convOp.getOperation());
+      return success();
+    }
+    // On tensors the final write produces the updated output tensor.
+    assert(newOp->getNumResults() == 1 && "expected single result");
+    rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateConvolutionVectorizationPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
+}

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index d1b9835452b7..7f73e367356b 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1257,36 +1257,34 @@ struct Red : public IteratorType {
 struct UnrolledOuterProductGenerator
     : public StructuredGenerator<vector::ContractionOp> {
 
-  UnrolledOuterProductGenerator(PatternRewriter &rewriter,
-                                vector::ContractionOp op)
-      : StructuredGenerator<vector::ContractionOp>(rewriter, op),
+  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
+      : StructuredGenerator<vector::ContractionOp>(builder, op),
         kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
         lhsType(op.getLhsType()) {}
 
   Value t(Value v) {
     static constexpr std::array<int64_t, 2> perm = {1, 0};
-    return rewriter.create<vector::TransposeOp>(loc, v, perm);
+    return builder.create<vector::TransposeOp>(loc, v, perm);
   }
 
-  LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
+  Value outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
     assert(reductionSize > 0);
     for (int64_t k = 0; k < reductionSize; ++k) {
-      Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
-      Value b = rewriter.create<vector::ExtractOp>(loc, rhs, k);
-      res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), a, b,
-                                                    res, kind);
+      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
+      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
+      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
+                                                   res, kind);
     }
-    rewriter.replaceOp(op, res);
-    return success();
+    return res;
   }
 
   /// Two outer parallel, one inner reduction (matmat flavor).
-  LogicalResult matmat() {
+  FailureOr<Value> matmat() {
     if (!iters({Par(), Par(), Red()}))
       return failure();
     // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
-    bindDims(rewriter.getContext(), m, n, k);
+    bindDims(builder.getContext(), m, n, k);
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
       return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
@@ -1318,11 +1316,11 @@ struct UnrolledOuterProductGenerator
   }
 
   /// One outer parallel, one inner reduction (matvec flavor)
-  LogicalResult matvec() {
+  FailureOr<Value> matvec() {
     if (!iters({Par(), Red()}))
       return failure();
     AffineExpr m, k;
-    bindDims(rewriter.getContext(), m, k);
+    bindDims(builder.getContext(), m, k);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
@@ -1342,11 +1340,11 @@ struct UnrolledOuterProductGenerator
   //
   // One outer reduction, one inner parallel (tmatvec flavor)
   //
-  LogicalResult tmatvec() {
+  FailureOr<Value> tmatvec() {
     if (!iters({Red(), Par()}))
       return failure();
     AffineExpr k, m;
-    bindDims(rewriter.getContext(), k, m);
+    bindDims(builder.getContext(), k, m);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
@@ -1399,12 +1397,21 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
     return failure();
 
   UnrolledOuterProductGenerator e(rewriter, op);
-  if (succeeded(e.matmat()))
+  FailureOr<Value> matmatRes = e.matmat();
+  if (succeeded(matmatRes)) {
+    rewriter.replaceOp(op, *matmatRes);
     return success();
-  if (succeeded(e.matvec()))
+  }
+  FailureOr<Value> matvecRes = e.matvec();
+  if (succeeded(matvecRes)) {
+    rewriter.replaceOp(op, *matvecRes);
     return success();
-  if (succeeded(e.tmatvec()))
+  }
+  FailureOr<Value> tmatvecRes = e.tmatvec();
+  if (succeeded(tmatvecRes)) {
+    rewriter.replaceOp(op, *tmatvecRes);
     return success();
+  }
 
   return failure();
 }

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
new file mode 100644
index 000000000000..a7c6f47cda7f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-linalg-to-vector-patterns %s | FileCheck %s
+
+func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
+    outs(%output : memref<4x2x8xf32>)
+  return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+//      CHECK: func @conv1d_nwc_4x2x8_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<1x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
+
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// w == 0, kw == 0
+//      CHECK:   %[[V_FILTER:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_INPUT0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_0:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[CONTRACT0:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+//      CHECK:   vector.transfer_write %[[CONTRACT0]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 1, kw == 0
+//      CHECK:   %[[V_INPUT3:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_1:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[CONTRACT1:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT3]], %[[V_FILTER]], %[[V_OUTPUT_1]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+//      CHECK:   vector.transfer_write %[[CONTRACT1]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+
+// -----
+
+func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf32>, %output: memref<4x2x8xf32>) {
+  linalg.conv_1d_nwc_wcf
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x6x3xf32>, memref<2x3x8xf32>)
+    outs(%output : memref<4x2x8xf32>)
+  return
+}
+
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+//      CHECK: func @conv1d_nwc_4x2x8_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x6x3xf32>, %[[FILTER:.+]]: memref<2x3x8xf32>, %[[OUTPUT:.+]]: memref<4x2x8xf32>)
+
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//  CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// w == 0, kw == 0
+//      CHECK:   %[[V_FILTER_A:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_INPUT0_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_0_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[CONTRACT0_A:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT0_A]], %[[V_FILTER_A]], %[[V_OUTPUT_0_A]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+//      CHECK:   vector.transfer_write %[[CONTRACT0_A]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 1, kw == 0
+//      CHECK:   %[[V_INPUT3_A:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C3]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_1_A:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[CONTRACT1_A:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT3_A]], %[[V_FILTER_A]], %[[V_OUTPUT_1_A]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+//      CHECK:   vector.transfer_write %[[CONTRACT1_A]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]
+
+/// w == 0, kw == 1
+//      CHECK:   %[[V_FILTER_B:.+]]   = vector.transfer_read %[[FILTER]][%[[C1]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_INPUT0_B:.+]]   = vector.transfer_read  %[[INPUT]][%[[C0]], %[[C2]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_0_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[CONTRACT0_B:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT0_B]], %[[V_FILTER_B]], %[[V_OUTPUT_0_B]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+//      CHECK:   vector.transfer_write %[[CONTRACT0_B]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+/// w == 1, kw == 1
+//      CHECK:     %[[V_INPUT3_B:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C5]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[V_OUTPUT_1_B:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]], %[[F0]]
+//      CHECK:   %[[CONTRACT1_B:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:     %[[V_INPUT3_B]], %[[V_FILTER_B]], %[[V_OUTPUT_1_B]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<1x3x8xf32> into vector<4x1x8xf32>
+//      CHECK:   vector.transfer_write %[[CONTRACT1_B]], %[[OUTPUT]][%[[C0]], %[[C1]], %[[C0]]]

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 3e336f341535..96b761c2cfc4 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -553,6 +553,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
       LinalgTransformationFilter()
           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
   populatePadTensorOpVectorizationPatterns(patterns);
+  populateConvolutionVectorizationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 


        


More information about the Mlir-commits mailing list