[Mlir-commits] [mlir] [mlir] Add direct vectorization lowering for `tensor.pack` ops (PR #78660)

Han-Chung Wang llvmlistbot at llvm.org
Mon Jan 22 00:02:10 PST 2024


================
@@ -1393,6 +1400,182 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Given a tensor::PackOp, return the permutation from the "tiled"
+/// shape to the "packed" shape, defined as follows:
+/// The "packed" shape is the same as the `dest` shape of the pack op.
+/// The "tiled" shape is a permutation of the `dest` shape such that
+/// each outer dimension is in the original `source` order, and the
+/// inner_tile dimensions immediately follow their corresponding outer
+/// dimension.
+/// E.g., for the following tensor.pack:
+/// ```mlir
+/// %pack = tensor.pack %0 padding_value(%1)
+///   outer_dims_perm = [0, 2, 1]
+///   inner_dims_pos = [2, 1]
+///   inner_tiles = [16, 2]
+///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
+/// ```
+/// The "packed" shape is `32x1x4x16x2`
+/// The "tiled" shape is `32x(4x2)x(1x16)`
+static SmallVector<int64_t>
+getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+  auto innerTiles = packOp.getInnerTiles();
+  int64_t srcRank = packOp.getSourceRank();
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  if (innerDimsPos.empty())
+    innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  if (outerDimsPerm.empty())
+    outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
+  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
+    int64_t srcIdx;
+    if (idx >= srcRank)
+      srcIdx = innerDimsPos[idx - srcRank];
+    else
+      srcIdx = outerDimsPerm[idx];
+    int64_t tiledIdx = srcIdx;
+    for (int64_t pos : innerDimsPos)
+      if (pos < srcIdx)
+        tiledIdx++;
+    if (idx >= srcRank)
+      tiledIdx++;
+    return tiledIdx;
+  };
+  SmallVector<int64_t> perm;
+  for (size_t i = 0; i < packOp.getDestRank(); i++)
+    perm.push_back(packedIdxToTiledIdx(i));
+  return perm;
+}
+
+/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
+/// above in `getTiledShapeToPackedShapePerm`.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
+                                              ArrayRef<int64_t> destShape) {
+  auto perm = getTiledShapeToPackedShapePerm(packOp);
+  return applyPermutation(destShape, invertPermutationVector(perm));
+}
+
+/// Create a masked TransferReadOp from `source` with shape `readShape`.
+static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
+                                               Value source,
+                                               ArrayRef<int64_t> readShape,
+                                               Value padValue) {
+  auto maskType = VectorType::get(readShape, builder.getI1Type());
+  auto vectorType = VectorType::get(readShape, padValue.getType());
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(builder, loc, source);
+  Value mask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  int64_t readRank = readShape.size();
+  auto transferReadOp = builder.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/source,
+      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(readRank, true));
----------------
hanhanW wrote:

This part is not clear to me. My understanding is that, if we have a mask, we don't need the padValue and all the accesses are in bounds. E.g.,

https://github.com/llvm/llvm-project/blob/aa4547fcc8eeb9bf4f3cf48cc926f62544e58767/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L1331-L1342

The padValue is needed only when masking is not involved. @dcaballe, can you provide some guidance? How do we handle the padding value if the op is masked? Should we bail out if the padding value is not a constant zero?

(I thought I left a similar comment in the first round of review, but it seems I didn't...)
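To make the question concrete: the `vector.transfer_read` documentation states that lanes whose mask bit is 0 are masked out and replaced with the padding value. The sketch below is a plain-C++ model of that rule for a 2-D read, not MLIR code. `maskedRead2D` is a hypothetical name, the mask is assumed to follow `vector.create_mask` with the source sizes (as `createMaskedTransferRead` builds it), and the sketch does not model any pass-through semantics of `vector.mask` itself.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model, not MLIR API: a 2-D masked read of `source` into a
// `readRows x readCols` result. The mask follows vector.create_mask with the
// source sizes, so lane (i, j) is active iff i < srcRows && j < srcCols.
// Per the vector.transfer_read docs, masked-off lanes receive `pad`.
// Assumes `source` is rectangular.
std::vector<std::vector<float>>
maskedRead2D(const std::vector<std::vector<float>> &source, int64_t readRows,
             int64_t readCols, float pad) {
  assert(readRows >= 0 && readCols >= 0 && "read shape must be non-negative");
  const int64_t srcRows = static_cast<int64_t>(source.size());
  const int64_t srcCols =
      srcRows == 0 ? 0 : static_cast<int64_t>(source[0].size());
  // Start with every lane holding the padding value.
  std::vector<std::vector<float>> result(
      static_cast<std::size_t>(readRows),
      std::vector<float>(static_cast<std::size_t>(readCols), pad));
  // Copy only the active (in-source) lanes.
  for (int64_t i = 0; i < readRows && i < srcRows; ++i)
    for (int64_t j = 0; j < readCols && j < srcCols; ++j)
      result[i][j] = source[i][j];
  return result;
}
```

For example, reading a 3x5 source into a 4x8 vector shape leaves `pad` in row 3 and in columns 5 through 7, so under this model the padding value is what fills every lane outside the source bounds.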

https://github.com/llvm/llvm-project/pull/78660
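The index arithmetic in `getTiledShapeToPackedShapePerm` and `getTiledPackShape` from the quoted hunk can be checked outside MLIR. The sketch below re-implements it on `std::vector`; the names `tiledToPackedPerm`, `invert`, and `applyPerm` are hypothetical, and `applyPerm` is assumed to follow the `result[i] = input[perm[i]]` convention of MLIR's `applyPermutation`. For the doc-comment example (`tensor<32x8x16> -> tensor<32x1x4x16x2>`, `outer_dims_perm = [0, 2, 1]`, `inner_dims_pos = [2, 1]`), it yields the permutation `[0, 3, 1, 4, 2]` and the tiled shape `32x4x2x1x16`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical standalone version of the index arithmetic in
// getTiledShapeToPackedShapePerm. Returns `perm` with
// packed[i] == tiled[perm[i]]. Assumes innerDimsPos has distinct entries.
std::vector<int64_t> tiledToPackedPerm(int64_t srcRank,
                                       const std::vector<int64_t> &innerDimsPos,
                                       std::vector<int64_t> outerDimsPerm) {
  if (outerDimsPerm.empty()) { // identity when outer_dims_perm is omitted
    outerDimsPerm.resize(static_cast<std::size_t>(srcRank));
    std::iota(outerDimsPerm.begin(), outerDimsPerm.end(), int64_t{0});
  }
  assert(static_cast<int64_t>(outerDimsPerm.size()) == srcRank);
  const int64_t destRank =
      srcRank + static_cast<int64_t>(innerDimsPos.size());
  std::vector<int64_t> perm;
  for (int64_t idx = 0; idx < destRank; ++idx) {
    const bool isTileDim = idx >= srcRank;
    const int64_t srcIdx =
        isTileDim ? innerDimsPos[idx - srcRank] : outerDimsPerm[idx];
    // Each tiled source dim before srcIdx adds one tile dim ahead of it.
    int64_t tiledIdx = srcIdx;
    for (int64_t pos : innerDimsPos)
      if (pos < srcIdx)
        ++tiledIdx;
    // A tile dim sits immediately after its outer dim.
    if (isTileDim)
      ++tiledIdx;
    perm.push_back(tiledIdx);
  }
  return perm;
}

// inverse[perm[i]] == i
std::vector<int64_t> invert(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    inv[static_cast<std::size_t>(perm[i])] = static_cast<int64_t>(i);
  return inv;
}

// result[i] == input[perm[i]] (assumed applyPermutation convention).
std::vector<int64_t> applyPerm(const std::vector<int64_t> &input,
                               const std::vector<int64_t> &perm) {
  std::vector<int64_t> result;
  for (int64_t p : perm)
    result.push_back(input[static_cast<std::size_t>(p)]);
  return result;
}
```

Applying `perm` to the tiled shape recovers the packed (`dest`) shape; applying its inverse to the `dest` shape gives the tiled shape, which is what `getTiledPackShape` computes.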

