[Mlir-commits] [mlir] [mlir][linalg][conv] Flatten the channel dimension when vectorizing (PR #71918)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 13 05:35:53 PST 2023


================
@@ -2936,6 +2938,176 @@ struct Conv1DGenerator
         .getOperation();
   }
 
+  /// Generate a vector implementation for the "flatten channel dim" variant of:
+  /// ```
+  ///   Op def: (     n,     w,     c,    kw)
+  ///    Iters: ({Par(), Par(), Par(), Red()})
+  ///   Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+  /// ```
+  /// c of the input/output is collapsed with w. kw is always unrolled and
+  /// broadcast to match w.
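+  ///
+  /// For example (unit stride/dilation, hypothetical sizes): with n = 1, w = 8,
+  /// c = 4, kw = 3, the input {1, 10, 4} is collapsed to {1, 40} and read as a
+  /// vector<1x40>, the kernel is read as a vector<3x4>, and the output
+  /// {1, 8, 4} is collapsed to {1, 32}. Each unrolled kw tap then operates on a
+  /// contiguous 1x32 window of the flattened input, with the corresponding
+  /// kernel row broadcast along w.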
+  ///
+  /// TODO: Add support for non-unit stride/dilation
+  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is > 1.
+  FailureOr<Operation *> depthwiseConvFlatten() {
+    if (!valid)
+      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
+
+    int64_t nSize, iSize, wSize, cSize, kwSize;
+    // kernel{kw, c}
+    bindShapeDims(rhsShapedType, kwSize, cSize);
+    // out{n, w, c}
+    bindShapeDims(resShapedType, nSize, wSize);
+    // in{n, iw, c}
+    bindShapeDims(lhsShapedType, nSize, iSize);
+
+    vector::TransferWriteOp write;
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+    if (strideW != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit strides are not supported yet");
+    if (dilationW != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit dilations are not supported yet");
+
+    Type lhsEltType = lhsShapedType.getElementType();
+    Type rhsEltType = rhsShapedType.getElementType();
+    Type resEltType = resShapedType.getElementType();
+    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+    VectorType lhsType = VectorType::get(
+        {nSize,
+         // iw = (ow * sw + kw *  dw - 1) * c
+         //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+         (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
+             cSize},
+        lhsEltType);
+
+    VectorType resType = VectorType::get({nSize, wSize * cSize}, resEltType);
+
+    Value res, lhs, lhsFlat, resFlat;
+    // Read rhs slice of size {kw, c} @ [0, 0].
+    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                        ValueRange{zero, zero});
+
+    SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2}};
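+    // {{0}, {1, 2}}: keep dim 0 (n), collapse dims 1 and 2 (w and c) into a
+    // single contiguous dim.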
+
+    // Flatten w and c dimensions
+    auto lhsTypeCollapsed = VectorType::get({nSize, iSize * cSize}, lhsEltType);
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    lhsFlat =
+        linalgOp.hasTensorSemantics()
+            ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+                  loc,
+                  RankedTensorType::get(lhsTypeCollapsed.getShape(),
+                                        lhsEltType),
+                  lhsShaped, reassociation)
+            : (Value)rewriter.create<memref::CollapseShapeOp>(
+                  loc, MemRefType::get(lhsTypeCollapsed.getShape(), lhsEltType),
+                  lhsShaped, reassociation);
+    resFlat =
+        linalgOp.hasTensorSemantics()
+            ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+                  loc, RankedTensorType::get(resType.getShape(), resEltType),
+                  resShaped, reassociation)
+            : (Value)rewriter.create<memref::CollapseShapeOp>(
+                  loc, MemRefType::get(resType.getShape(), resEltType),
+                  resShaped, reassociation);
+
+    // Read lhs slice of size {n, (w * strideW + kw * dilationW) * c} @ [0, 0].
+    lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
+                                                  ValueRange{zero, zero});
+    // Read res slice of size {n, w * c} @ [0, 0].
+    res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
+                                                  ValueRange{zero, zero});
+
+    //===------------------------------------------------------------------===//
+    // Begin vector-only rewrite part
+    //===------------------------------------------------------------------===//
+    // Unroll along kw and read slices of lhs and rhs.
+    SmallVector<Value> lhsVals, rhsVals, resVals;
+    // Extract lhs slice of size {n, wSize * c}
+    //   @ [0, (sw * w + dw * kw) * c].
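+    // Because c is folded into w, each kw tap reads a contiguous
+    // (wSize * cSize)-element window of the flattened input; the w loop runs a
+    // single iteration since the whole w extent is processed at once.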
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSize) {
+        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, lhs,
+            /*offsets=*/
+            ArrayRef<int64_t>{0, (w * strideW + kw * dilationW) * cSize},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1}));
+      }
+    }
+    // Extract rhs slice of size {c} @ [kw].
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+    }
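+    // Each rhsVals[kw] holds the {c}-element kernel row for tap kw; per the
+    // doc comment above, it is broadcast to match w.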
+
+    // Extract res slice
+    // Flattened case:  {n, wSize * c} @ [0, w * c].
+    for (int64_t w = 0; w < wSize; w += wSize) {
+      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res,
+          /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+          /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1}));
+    }
+
+    auto linearIndex = [&](int64_t kw, int64_t w) {
+      return kw * (wSize / wSize) + w;
----------------
nicolasvasilache wrote:

this looks wrong

https://github.com/llvm/llvm-project/pull/71918


More information about the Mlir-commits mailing list