[Mlir-commits] [mlir] 67b37f5 - [mlir] Conv ops vectorization pass

Jakub Lichman llvmlistbot at llvm.org
Tue Sep 8 01:48:21 PDT 2020


Author: Jakub Lichman
Date: 2020-09-08T08:47:42Z
New Revision: 67b37f571cc27d5684125f694d719b114ad72a18

URL: https://github.com/llvm/llvm-project/commit/67b37f571cc27d5684125f694d719b114ad72a18
DIFF: https://github.com/llvm/llvm-project/commit/67b37f571cc27d5684125f694d719b114ad72a18.diff

LOG: [mlir] Conv ops vectorization pass

This commit introduces a new lowering for convolution ops.
The conv op vectorization pass lowers linalg convolution ops
into vector contractions. This lowering is possible when a conv op
is first tiled by 1 along specific dimensions, which transforms
it into a dot product between input and kernel subview memory buffers.
This pass converts such a conv op into a vector contraction and performs
all the vector transfers needed to make it work.

Differential Revision: https://reviews.llvm.org/D86619

Added: 
    mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
    mlir/test/lib/Transforms/TestConvVectorization.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f438b6587c8b..ce3b5fd2fd24 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,6 +30,10 @@ struct TiledLinalgOp {
   SmallVector<Operation *, 8> loops;
 };
 
+/// Populates `patterns` with ConvOpVectorization patterns for the linalg
+/// convolution ops ConvW, ConvNWC, ConvNCW, ConvHW, ConvNHWC, ConvNCHW,
+/// ConvDHW, ConvNDHWC and ConvNCDHW (all ConvN-D ops).
+void populateConvVectorizationPatterns(MLIRContext *context,
+                                       OwningRewritePatternList &patterns);
+
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify
@@ -531,6 +535,53 @@ struct AffineMinSCFCanonicalizationPattern
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Converts a convolution op into a vector contraction.
+///
+/// The conversion expects every dimension marked `false` in the *mask* to
+/// have size 1. This ensures that the ConvOp can be lowered to a vector
+/// contraction over the dimensions marked `true` in the *mask*. The input and
+/// kernel must have static shapes.
+///
+/// A good example is ConvNHWCOp, which is a 2D Conv op with channels as the
+/// last dimension. For this op we contract the last 3 dimensions.
+/// The initial op definition looks like this:
+/// ```
+/// linalg.conv_2d_nhwc  %arg0, %arg1, %arg2 :
+///   (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>)
+/// ```
+/// This op can be expressed as a dot product between %arg0 (input) and
+/// %arg1 (kernel) which is written into the first entry of %arg2 (output).
+/// This is the ConvOp this pass expects and converts into:
+/// ```
+/// #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+/// #map1 = affine_map<(d0, d1, d2) -> ()>
+/// .....
+/// %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %c0_f32
+///   : memref<1x3x3x3xf32>, vector<3x3x3xf32>
+/// %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %c0_f32
+///   : memref<1x3x3x3xf32>, vector<3x3x3xf32>
+/// %2 = vector.contract {indexing_maps = [#map0, #map0, #map1],
+///   iterator_types = ["reduction", "reduction", "reduction"]} %0, %1,
+///   %c0_f32 : vector<3x3x3xf32>, vector<3x3x3xf32> into f32
+/// store %2, %arg2[%c0, %c0, %c0, %c0] : memref<?x?x?x?xf32>
+/// ```
+/// where the first 2 operations read the input and kernel memory buffers into
+/// vectors. Subsequently, they are contracted together and the result is
+/// written to the first entry of the output buffer.
+///
+/// NOTE(review): the contraction starts from a zero accumulator and its result
+/// overwrites the first output entry, so a value already stored there is not
+/// accumulated into; confirm this matches the semantics of the ConvOp.
+///
+/// `N` is the rank of the ConvOp operands and must equal the mask size.
+template <typename ConvOp, int N>
+struct ConvOpVectorization : public OpRewritePattern<ConvOp> {
+  /// One entry per operand dimension; `true` marks a dimension that takes
+  /// part in the contraction, `false` one that must have size 1.
+  SmallVector<bool, 4> mask;
+
+  // The inherited OpRewritePattern constructors are deliberately not exposed:
+  // they would leave `mask` empty and matchAndRewrite would index it out of
+  // bounds.
+  ConvOpVectorization(MLIRContext *context, SmallVector<bool, 4> msk)
+      : OpRewritePattern<ConvOp>(context), mask(std::move(msk)) {
+    assert(mask.size() == N && "Mask size does not match rank");
+  }
+
+  LogicalResult matchAndRewrite(ConvOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ada89f1c82b5..cd36c753b6f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -367,3 +367,98 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
 
   return success();
 }
+
+/// Rewrites `op` into vector.transfer_read ops for the input and the kernel, a
+/// vector.contract that reduces over every dimension selected by `mask`, and a
+/// scalar store of the result into the first entry of the output buffer.
+///
+/// Fails to match unless:
+///  * the input and kernel have static shapes of rank N,
+///  * every dimension not selected by `mask` has size 1 in both operands,
+///  * every selected dimension has the same non-zero size in both operands,
+///  * at least one dimension is selected.
+template <class ConvOp, int N>
+LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
+    ConvOp op, PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MLIRContext *context = op.getContext();
+  edsc::ScopedContext scope(rewriter, loc);
+
+  ShapedType inShapeType = op.getInputShapedType(0);
+  ShapedType kShapeType = op.getInputShapedType(1);
+
+  // Query shapes only once they are known to be static (and thus ranked).
+  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
+    return failure();
+  if (inShapeType.getRank() != N || kShapeType.getRank() != N)
+    return failure();
+
+  ArrayRef<int64_t> inShape = inShapeType.getShape();
+  ArrayRef<int64_t> kShape = kShapeType.getShape();
+
+  SmallVector<AffineExpr, 4> mapping;
+  SmallVector<int64_t, 4> vectorDims;
+  for (unsigned i = 0; i < N; ++i) {
+    if (!mask[i]) {
+      // Non-vectorized dimensions must be degenerate.
+      if (inShape[i] != 1 || kShape[i] != 1)
+        return failure();
+      continue;
+    }
+    // Vectorized dimensions are multiplied element-wise and reduced, so the
+    // input window and the kernel must agree in size. Zero-sized dimensions
+    // cannot form a VectorType.
+    if (inShape[i] != kShape[i] || inShape[i] <= 0)
+      return failure();
+    mapping.push_back(getAffineDimExpr(i, context));
+    vectorDims.push_back(inShape[i]);
+  }
+  // A contraction needs at least one vector dimension.
+  if (mapping.empty())
+    return failure();
+
+  Value input = op.getInput(0);
+  Value kernel = op.getInput(1);
+  Value output = op.getOutputBuffer(0);
+
+  unsigned rank = inShapeType.getRank();
+  unsigned numDims = mapping.size();
+  Type elemType = inShapeType.getElementType();
+
+  // The transfer reads project the full-rank memref onto the vectorized
+  // dimensions only, starting at index 0 in every dimension.
+  auto map = AffineMap::get(rank, 0, mapping, context);
+  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
+  auto vecType = VectorType::get(vectorDims, elemType);
+
+  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
+  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);
+
+  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));
+
+  // Both operands are indexed identically and every dimension is reduced, so
+  // the contraction produces a scalar.
+  std::array<AffineMap, 3> indexingMaps{
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::getMultiDimIdentityMap(numDims, context),
+      AffineMap::get(numDims, 0, {}, context)};
+
+  std::vector<StringRef> iteratorTypes(numDims, "reduction");
+
+  auto result = rewriter.create<vector::ContractionOp>(
+      loc, inputVec, kernelVec, acc,
+      rewriter.getAffineMapArrayAttr(indexingMaps),
+      rewriter.getStrArrayAttr(iteratorTypes));
+
+  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
+  rewriter.eraseOp(op);
+  return success();
+}
+
+void mlir::linalg::populateConvVectorizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  // Contraction masks: `true` dimensions are vectorized and reduced; `false`
+  // dimensions (the leading `N` dimension of the N*-layout variants) must have
+  // size 1.
+  const SmallVector<bool, 4> plain1D{true};
+  const SmallVector<bool, 4> plain2D{true, true};
+  const SmallVector<bool, 4> plain3D{true, true, true};
+  const SmallVector<bool, 4> leadingUnit1D{false, true, true};
+  const SmallVector<bool, 4> leadingUnit2D{false, true, true, true};
+  const SmallVector<bool, 4> leadingUnit3D{false, true, true, true, true};
+
+  // 1-D convolutions.
+  patterns.insert<ConvOpVectorization<linalg::ConvWOp, 1>>(context, plain1D);
+  patterns.insert<ConvOpVectorization<linalg::ConvNWCOp, 3>>(context,
+                                                             leadingUnit1D);
+  patterns.insert<ConvOpVectorization<linalg::ConvNCWOp, 3>>(context,
+                                                             leadingUnit1D);
+
+  // 2-D convolutions.
+  patterns.insert<ConvOpVectorization<linalg::ConvHWOp, 2>>(context, plain2D);
+  patterns.insert<ConvOpVectorization<linalg::ConvNHWCOp, 4>>(context,
+                                                              leadingUnit2D);
+  patterns.insert<ConvOpVectorization<linalg::ConvNCHWOp, 4>>(context,
+                                                              leadingUnit2D);
+
+  // 3-D convolutions.
+  patterns.insert<ConvOpVectorization<linalg::ConvDHWOp, 3>>(context, plain3D);
+  patterns.insert<ConvOpVectorization<linalg::ConvNDHWCOp, 5>>(context,
+                                                               leadingUnit3D);
+  patterns.insert<ConvOpVectorization<linalg::ConvNCDHWOp, 5>>(context,
+                                                               leadingUnit3D);
+}

diff  --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
new file mode 100644
index 000000000000..487718301d00
--- /dev/null
+++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
@@ -0,0 +1,167 @@
+// RUN: mlir-opt %s -test-conv-vectorization --cse | FileCheck %s
+
+// CHECK-DAG:  #[[$map0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG:  #[[$map1:.*]] = affine_map<(d0) -> ()>
+// CHECK-DAG:  #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG:  #[[$map3:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG:  #[[$map4:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG:  #[[$map5:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-DAG:  #[[$map6:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG:  #[[$map7:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK-DAG:  #[[$map8:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
+// CHECK-DAG:  #[[$map9:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:  #[[$map10:.*]] = affine_map<(d0, d1, d2, d3) -> ()>
+
+// 1-D conv on 3-element buffers: expected to lower to a rank-1 vector.contract
+// whose scalar result is stored to the first output element.
+func @conv_1d(%arg0: memref<3xf32>, %arg1: memref<3xf32>, %arg2: memref<?xf32>) {
+  linalg.conv_1d %arg0, %arg1, %arg2 : (memref<3xf32>, memref<3xf32>, memref<?xf32>)
+  return
+}
+
+// CHECK-LABEL: @conv_1d
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<3xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<3xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?xf32
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[cst:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[v0:.*]] = vector.transfer_read %[[arg0]][%[[c0]]], %[[cst]] : memref<3xf32>, vector<3xf32>
+//       CHECK:   %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]]], %[[cst]] : memref<3xf32>, vector<3xf32>
+//       CHECK:   %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]], iterator_types = ["reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3xf32>, vector<3xf32> into f32
+//       CHECK:   store %[[v2]], %[[arg2]][%[[c0]]] : memref<?xf32>
+//       CHECK:   return
+
+// NCW 1-D conv with a unit leading dimension: the trailing 3x3 dimensions are
+// read as vectors and fully reduced.
+func @conv_1d_ncw(%arg0: memref<1x3x3xf32>, %arg1: memref<1x3x3xf32>, %arg2: memref<?x?x?xf32>) {
+  linalg.conv_1d_ncw %arg0, %arg1, %arg2 : (memref<1x3x3xf32>, memref<1x3x3xf32>, memref<?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: @conv_1d_ncw
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[cst:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[v0:.*]] = vector.transfer_read %[[arg0]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3xf32>, vector<3x3xf32>
+//       CHECK:   %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3xf32>, vector<3x3xf32>
+//       CHECK:   %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map3]], #[[$map3]], #[[$map4]]], iterator_types = ["reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3xf32>, vector<3x3xf32> into f32
+//       CHECK:   store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]]] : memref<?x?x?xf32>
+//       CHECK:   return
+
+
+// NWC 1-D conv with a unit leading dimension: the trailing 3x3 dimensions are
+// read as vectors and fully reduced.
+func @conv_1d_nwc(%arg0: memref<1x3x3xf32>, %arg1: memref<1x3x3xf32>, %arg2: memref<?x?x?xf32>) {
+  linalg.conv_1d_nwc %arg0, %arg1, %arg2 : (memref<1x3x3xf32>, memref<1x3x3xf32>, memref<?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: @conv_1d_nwc
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[cst:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[v0:.*]] = vector.transfer_read %[[arg0]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3xf32>, vector<3x3xf32>
+//       CHECK:   %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3xf32>, vector<3x3xf32>
+//       CHECK:   %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map3]], #[[$map3]], #[[$map4]]], iterator_types = ["reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3xf32>, vector<3x3xf32> into f32
+//       CHECK:   store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]]] : memref<?x?x?xf32>
+//       CHECK:   return
+
+// 2-D conv on 3x3 buffers: both dimensions are contracted into a scalar.
+func @conv_2d(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>, %arg2: memref<?x?xf32>) {
+  linalg.conv_2d %arg0, %arg1, %arg2 : (memref<3x3xf32>, memref<3x3xf32>, memref<?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: @conv_2d
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<3x3xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<3x3xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?xf32
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[cst:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[v0:.*]] = vector.transfer_read %[[arg0]][%[[c0]], %[[c0]]], %[[cst]] : memref<3x3xf32>, vector<3x3xf32>
+//       CHECK:   %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]]], %[[cst]] : memref<3x3xf32>, vector<3x3xf32>
+//       CHECK:   %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map3]], #[[$map3]], #[[$map4]]], iterator_types = ["reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3xf32>, vector<3x3xf32> into f32
+//       CHECK:   store %[[v2]], %[[arg2]][%[[c0]], %[[c0]]] : memref<?x?xf32>
+//       CHECK:   return
+
+// NCHW 2-D conv with a unit leading dimension: the trailing 3x3x3 dimensions
+// are read as vectors and fully reduced.
+func @conv_2d_nchw(%arg0: memref<1x3x3x3xf32>, %arg1: memref<1x3x3x3xf32>, %arg2: memref<?x?x?x?xf32>) {
+  linalg.conv_2d_nchw %arg0, %arg1, %arg2 : (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: @conv_2d_nchw
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[cst:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[v0:.*]] = vector.transfer_read %[[arg0]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3xf32>, vector<3x3x3xf32>
+//       CHECK:   %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3xf32>, vector<3x3x3xf32>
+//       CHECK:   %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map6]], #[[$map6]], #[[$map7]]], iterator_types = ["reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3xf32>, vector<3x3x3xf32> into f32
+//       CHECK:   store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref<?x?x?x?xf32>
+//       CHECK:   return
+
+// NHWC 2-D conv with a unit leading dimension: the trailing 3x3x3 dimensions
+// are read as vectors and fully reduced.
+func @conv_2d_nhwc(%arg0: memref<1x3x3x3xf32>, %arg1: memref<1x3x3x3xf32>, %arg2: memref<?x?x?x?xf32>) {
+  linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: @conv_2d_nhwc
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[cst:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[v0:.*]] = vector.transfer_read %[[arg0]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3xf32>, vector<3x3x3xf32>
+//       CHECK:   %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3xf32>, vector<3x3x3xf32>
+//       CHECK:   %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map6]], #[[$map6]], #[[$map7]]], iterator_types = ["reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3xf32>, vector<3x3x3xf32> into f32
+//       CHECK:   store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref<?x?x?x?xf32>
+//       CHECK:   return
+
+// 3-D conv on 3x3x3 buffers: all three dimensions are contracted into a scalar.
+func @conv_3d(%arg0: memref<3x3x3xf32>, %arg1: memref<3x3x3xf32>, %arg2: memref<?x?x?xf32>) {
+  linalg.conv_3d %arg0, %arg1, %arg2 : (memref<3x3x3xf32>, memref<3x3x3xf32>, memref<?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: @conv_3d
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<3x3x3xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<3x3x3xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?xf32
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[cst:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[v0:.*]] = vector.transfer_read %[[arg0]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<3x3x3xf32>, vector<3x3x3xf32>
+//       CHECK:   %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<3x3x3xf32>, vector<3x3x3xf32>
+//       CHECK:   %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map6]], #[[$map6]], #[[$map7]]], iterator_types = ["reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3xf32>, vector<3x3x3xf32> into f32
+//       CHECK:   store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]]] : memref<?x?x?xf32>
+//       CHECK:   return
+
+// NCDHW 3-D conv with a unit leading dimension: the trailing 3x3x3x3
+// dimensions are read as vectors and fully reduced.
+func @conv_3d_ncdhw(%arg0: memref<1x3x3x3x3xf32>, %arg1: memref<1x3x3x3x3xf32>, %arg2: memref<?x?x?x?x?xf32>) {
+  linalg.conv_3d_ncdhw %arg0, %arg1, %arg2 : (memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>, memref<?x?x?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: @conv_3d_ncdhw
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?x?xf32
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[cst:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[v0:.*]] = vector.transfer_read %[[arg0]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3x3xf32>, vector<3x3x3x3xf32>
+//       CHECK:   %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3x3xf32>, vector<3x3x3x3xf32>
+//       CHECK:   %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map9]], #[[$map9]], #[[$map10]]], iterator_types = ["reduction", "reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3x3xf32>, vector<3x3x3x3xf32> into f32
+//       CHECK:   store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref<?x?x?x?x?xf32>
+//       CHECK:   return
+
+// NDHWC 3-D conv with a unit leading dimension: the trailing 3x3x3x3
+// dimensions are read as vectors and fully reduced.
+func @conv_3d_ndhwc(%arg0: memref<1x3x3x3x3xf32>, %arg1: memref<1x3x3x3x3xf32>, %arg2: memref<?x?x?x?x?xf32>) {
+  linalg.conv_3d_ndhwc %arg0, %arg1, %arg2 : (memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>, memref<?x?x?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: @conv_3d_ndhwc
+//  CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32>
+//  CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32>
+//  CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref<?x?x?x?x?xf32
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[cst:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[v0:.*]] = vector.transfer_read %[[arg0]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3x3xf32>, vector<3x3x3x3xf32>
+//       CHECK:   %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3x3xf32>, vector<3x3x3x3xf32>
+//       CHECK:   %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map9]], #[[$map9]], #[[$map10]]], iterator_types = ["reduction", "reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3x3xf32>, vector<3x3x3x3xf32> into f32
+//       CHECK:   store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref<?x?x?x?x?xf32>
+//       CHECK:   return

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index de894467d63d..3ac1e7c55235 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTestTransforms
   TestExpandTanh.cpp
   TestCallGraph.cpp
   TestConstantFold.cpp
+  TestConvVectorization.cpp
   TestConvertCallOp.cpp
   TestConvertGPUKernelToCubin.cpp
   TestConvertGPUKernelToHsaco.cpp

diff  --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
new file mode 100644
index 000000000000..37e509cbbbe1
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -0,0 +1,51 @@
+//===- TestConvVectorization.cpp - Linalg to Vector dialect conversion ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that runs the linalg convolution vectorization patterns over the
+/// module, converting eligible Linalg conv ops into Vector dialect ops.
+class TestConvVectorization
+    : public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
+  /// Declares the dialects this pass depends on.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+    registry.insert<linalg::LinalgDialect>();
+    registry.insert<StandardOpsDialect>();
+  }
+
+  /// Applies the conv vectorization patterns; defined out of line.
+  void runOnOperation() override;
+};
+} // namespace
+
+void TestConvVectorization::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+
+  // Ops produced by the vectorization patterns, plus the module/function
+  // scaffolding, are legal after conversion.
+  ConversionTarget target(*ctx);
+  target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect,
+                         vector::VectorDialect>();
+  target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
+  target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
+
+  OwningRewritePatternList convPatterns;
+  linalg::populateConvVectorizationPatterns(ctx, convPatterns);
+
+  if (failed(applyPartialConversion(getOperation(), target, convPatterns)))
+    signalPassFailure();
+}
+
+namespace mlir {
+/// Registers TestConvVectorization under the `test-conv-vectorization` flag.
+void registerTestConvVectorization() {
+  PassRegistration<TestConvVectorization> convVectorizationPass(
+      "test-conv-vectorization", "Test vectorization of convolutions");
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 34e03a5f9920..437b5f4b6f1a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -45,6 +45,7 @@ void registerTestAllReduceLoweringPass();
 void registerTestBufferPlacementPreparationPass();
 void registerTestCallGraphPass();
 void registerTestConstantFold();
+void registerTestConvVectorization();
 void registerTestConvertGPUKernelToCubinPass();
 void registerTestConvertGPUKernelToHsacoPass();
 void registerTestDominancePass();
@@ -93,6 +94,7 @@ void registerTestPasses() {
   registerTestAffineLoopUnswitchingPass();
   registerTestLoopPermutationPass();
   registerTestCallGraphPass();
+  registerTestConvVectorization();
   registerTestConstantFold();
 #if MLIR_CUDA_CONVERSIONS_ENABLED
   registerTestConvertGPUKernelToCubinPass();


        


More information about the Mlir-commits mailing list