[Mlir-commits] [mlir] 392e16c - [mlir][Linalg] NFC - Cleanup conv1d generators

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jan 17 09:39:34 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-17T17:39:19Z
New Revision: 392e16c27ffce3c88d6d14050327bb7409756900

URL: https://github.com/llvm/llvm-project/commit/392e16c27ffce3c88d6d14050327bb7409756900
DIFF: https://github.com/llvm/llvm-project/commit/392e16c27ffce3c88d6d14050327bb7409756900.diff

LOG: [mlir][Linalg] NFC - Cleanup conv1d generators

Differential Revision: https://reviews.llvm.org/D117330

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3daf243ce472..db3e7197a0aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1287,6 +1287,22 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
 //===----------------------------------------------------------------------===//
 // Convolution vectorization patterns
 //===----------------------------------------------------------------------===//
+
+template <int N>
+static void bindShapeDims(ShapedType shapedType) {} // Base case: pack is empty.
+
+template <int N, typename IntTy, typename... IntTy2>
+static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
+  val = shapedType.getShape()[N]; // Dimension N goes to the first reference.
+  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...); // Bind the rest.
+}
+
+/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
+template <typename... IntTy>
+static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
+  bindShapeDims<0>(shapedType, vals...); // Start the recursion at dimension 0.
+}
+
 namespace {
 /// Generate a vector implementation for either:
 /// ```
@@ -1354,11 +1370,11 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     if (!valid)
       return failure();
 
-    int nSize = lhsShapedType.getShape()[0];
-    int wSize = resShapedType.getShape()[1];
-    int cSize = lhsShapedType.getShape()[2];
-    int kwSize = rhsShapedType.getShape()[0];
-    int fSize = rhsShapedType.getShape()[2];
+    int64_t nSize, wSize, cSize, kwSize, fSize;
+    // kernel{kw, c, f}
+    bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
+    // out{n, w, f}
+    bindShapeDims(resShapedType, nSize, wSize);
 
     vector::TransferWriteOp write;
     Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
@@ -1398,31 +1414,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
+    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      // Extract rhs slice of size {c, f} @ [kw].
-      rhsVals.push_back(builder.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
-
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        // Extract lhs slice of size {n, wSizeStep, c}
-        //   @ [0, sw * w + dw * kw, 0].
         lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
             /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-
-        // This does not depend on kw.
-        if (kw == 0) {
-          // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
-          resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
-              loc, res,
-              /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-              /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
-              /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-        }
       }
     }
+    // Extract rhs slice of size {c, f} @ [kw].
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      rhsVals.push_back(builder.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+    }
+    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+          loc, res,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+    }
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
@@ -1476,14 +1490,15 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> dilatedConv() {
+  FailureOr<Operation *> depthwiseConv() {
     if (!valid)
       return failure();
 
-    int nSize = lhsShapedType.getShape()[0];
-    int wSize = resShapedType.getShape()[1];
-    int cSize = lhsShapedType.getShape()[2];
-    int kwSize = rhsShapedType.getShape()[0];
+    int64_t nSize, wSize, cSize, kwSize;
+    // kernel{kw, c}
+    bindShapeDims(rhsShapedType, kwSize, cSize);
+    // out{n, w, c}
+    bindShapeDims(resShapedType, nSize, wSize);
 
     vector::TransferWriteOp write;
     Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
@@ -1522,31 +1537,30 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
+    // Extract lhs slice of size {n, wSizeStep, c}
+    //   @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      // Extract rhs slice of size {c} @ [kw].
-      rhsVals.push_back(builder.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
-
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        // Extract lhs slice of size {n, wSizeStep, c}
-        //   @ [0, sw * w + dw * kw, 0].
         lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
             /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-
-        // This does not depend on kw.
-        if (kw == 0) {
-          // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
-          resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
-              loc, res,
-              /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-              /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-              /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
-        }
       }
     }
+    // Extract rhs slice of size {c} @ [kw].
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      rhsVals.push_back(builder.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+    }
+    // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
+    for (int64_t w = 0; w < wSize; w += wSizeStep) {
+      resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+          loc, res,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+    }
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
@@ -1555,7 +1569,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = dilatedConv1dSliceAsFma(
+        resVals[w] = depthwiseConv1dSliceAsFma(
             builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
@@ -1580,8 +1594,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
   }
 
   /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
-  Value dilatedConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
-                                Value rhs, Value res) {
+  Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
+                                  Value rhs, Value res) {
     Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
     return b.create<vector::FMAOp>(loc, lhs, bcast, res);
   }
@@ -1614,7 +1628,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return dilatedConv();
+      return depthwiseConv();
     return failure();
   }
 

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index e545e7f191aa..3168e7f67d51 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -23,15 +23,15 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
 //  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 //  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 
-//      CHECK:    %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
-/// w == 0, kw == 0
 //      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
-//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
-// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
-/// w == 1, kw == 0
 //      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+
+//      CHECK:    %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
+
+//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
 //      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
 // CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
 
@@ -84,27 +84,23 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 //  CHECK-DAG:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 //  CHECK-DAG:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 
-
-/// w == 0, kw == 0
-//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
 //      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
-//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
-// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
-/// w == 1, kw == 0
 //      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
-//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
-// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
-
-/// w == 0, kw == 1
-//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
 //      CHECK:   %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
-/// w == 1, kw == 0
 //      CHECK:   %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
 
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+
+//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+
 /// w == 0, kw == 0
 //      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
@@ -165,15 +161,14 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
 //  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 //  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
 
-/// w == 0, kw == 0
-//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
 //      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
-/// w == 0, kw == 1
-//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
 //      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
 
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+
 /// w == 0, kw == 0
 //      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
 // CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
@@ -211,15 +206,14 @@ func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: m
 //      CHECK:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
 //      CHECK:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
-/// w == 0, kw == 0
-//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
 //      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
-/// w == 0, kw == 1
-//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
 //      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
 
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
+
 /// w == 0, kw = 0
 //      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
 //      CHECK:  %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32>


        


More information about the Mlir-commits mailing list