[Mlir-commits] [mlir] [mlir] Add direct vectorization lowering for `tensor.pack` ops (PR #78660)

Diego Caballero llvmlistbot at llvm.org
Fri Jan 19 14:56:12 PST 2024


================
@@ -1393,6 +1400,182 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Given a tensor::PackOp, return the permutation from the "tiled"
+/// shape to the "packed" shape, defined as the following:
+/// The "packed" shape is the same as the `dest` shape of the pack op.
+/// The "tiled" shape is a permutation of the `dest` shape such that
+/// each outer dimension is in the original `source` order, and the
+/// inner_tile dimensions immediately follow their corresponding outer
+/// dimension.
+/// i.e. for the following tensor.pack:
+/// ```mlir
+/// %pack = tensor.pack %0 padding_value(%1)
+///   outer_dims_perm = [0, 2, 1]
+///   inner_dims_pos = [2, 1]
+///   inner_tiles = [16, 2]
+///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
+/// ```
+/// The "packed" shape is `32x1x4x16x2`.
+/// The "tiled" shape is `32x(4x2)x(1x16)`.
+/// Entry `i` of the returned vector is the position, within the "tiled"
+/// shape, of dimension `i` of the "packed" shape; for the example above the
+/// result is `[0, 3, 1, 4, 2]`.
+static SmallVector<int64_t>
+getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+  auto innerTiles = packOp.getInnerTiles();
+  int64_t srcRank = packOp.getSourceRank();
+  // Keep both position lists in owning vectors. The op accessors hand back
+  // non-owning views, so assigning a temporary identity vector to such a
+  // view would leave it pointing at destroyed storage.
+  SmallVector<int64_t> innerDimsPos = to_vector(packOp.getInnerDimsPos());
+  if (innerDimsPos.empty())
+    innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
+  SmallVector<int64_t> outerDimsPerm = to_vector(packOp.getOuterDimsPerm());
+  if (outerDimsPerm.empty())
+    outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
+  // Map an index into the "packed" shape to its index in the "tiled" shape.
+  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
+    // Indices at or past `srcRank` address inner tile dimensions; the others
+    // address (possibly permuted) outer dimensions.
+    bool isInnerTileDim = idx >= srcRank;
+    int64_t srcIdx =
+        isInnerTileDim ? innerDimsPos[idx - srcRank] : outerDimsPerm[idx];
+    // Each tiled source dimension that precedes `srcIdx` inserts one extra
+    // inner tile dimension ahead of it in the "tiled" shape.
+    int64_t tiledIdx = srcIdx;
+    for (int64_t pos : innerDimsPos)
+      if (pos < srcIdx)
+        tiledIdx++;
+    // An inner tile dimension sits right after its own outer dimension.
+    if (isInnerTileDim)
+      tiledIdx++;
+    return tiledIdx;
+  };
+  int64_t destRank = packOp.getDestRank();
+  SmallVector<int64_t> perm;
+  perm.reserve(destRank);
+  for (int64_t i = 0; i < destRank; ++i)
+    perm.push_back(packedIdxToTiledIdx(i));
+  return perm;
+}
+
+/// Return `destShape` rearranged into the "tiled" dimension order of
+/// `packOp`, i.e. `destShape` permuted by the inverse of the permutation
+/// computed by `getTiledShapeToPackedShapePerm`.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
+                                              ArrayRef<int64_t> destShape) {
+  SmallVector<int64_t> tiledToPackedPerm =
+      getTiledShapeToPackedShapePerm(packOp);
+  auto inversePerm = invertPermutationVector(tiledToPackedPerm);
+  return applyPermutation(destShape, inversePerm);
+}
+
+/// Build a vector.transfer_read of `source` producing a vector of shape
+/// `readShape` (element type taken from `padValue`), starting at index zero
+/// in every dimension and marked in-bounds, then wrap it in a vector.mask
+/// whose mask is created from the mixed sizes of `source`. Returns the
+/// resulting mask op.
+static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
+                                               Value source,
+                                               ArrayRef<int64_t> readShape,
+                                               Value padValue) {
+  int64_t readRank = readShape.size();
+  auto vectorType = VectorType::get(readShape, padValue.getType());
+  auto maskType = VectorType::get(readShape, builder.getI1Type());
+
+  // The mask covers exactly the extent of `source`.
+  SmallVector<OpFoldResult> sourceSizes =
+      tensor::getMixedSizes(builder, loc, source);
+  Value readMask =
+      builder.create<vector::CreateMaskOp>(loc, maskType, sourceSizes);
+
+  // Read from the origin of `source` along every dimension.
+  auto zeroIndex = builder.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> readIndices(readRank, zeroIndex);
+  SmallVector<bool> inBounds(readRank, true);
+  auto readOp = builder.create<vector::TransferReadOp>(
+      loc, vectorType, source, readIndices, padValue, inBounds);
+
+  Operation *maskedRead =
+      mlir::vector::maskOperation(builder, readOp, readMask);
+  return cast<vector::MaskOp>(maskedRead);
+}
+
+/// Create a tensor.empty of sizes `destSizes` (element type taken from the
+/// vector `input`) and a vector.transfer_write of `input` into it, starting
+/// at index zero in every dimension and marked in-bounds. When any of the
+/// leading destination dims differs from the matching entry of
+/// `inputVectorSizes`, the write is wrapped in a vector.mask built from
+/// `destSizes`, and the masking op is returned instead of the bare write.
+static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+                                           Value input,
+                                           SmallVector<OpFoldResult> destSizes,
+                                           ArrayRef<int64_t> inputVectorSizes) {
+  Type elementType = cast<VectorType>(input.getType()).getElementType();
+  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes, elementType);
+  auto destType = cast<ShapedType>(dest.getType());
+  int64_t rank = destType.getRank();
+
+  auto zeroIndex = builder.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> writeIndices(rank, zeroIndex);
+  SmallVector<bool> inBounds(rank, true);
+  Operation *write = builder.create<vector::TransferWriteOp>(
+      loc, input, dest, writeIndices, inBounds);
+
+  // A mask is only required if a vector size and the corresponding
+  // destination dim disagree.
+  ArrayRef<int64_t> destShape = destType.getShape();
+  bool needMaskForWrite = false;
+  for (auto [vectorSize, destSize] : llvm::zip(inputVectorSizes, destShape)) {
+    if (vectorSize != destSize) {
+      needMaskForWrite = true;
+      break;
+    }
+  }
+  if (!needMaskForWrite)
+    return write;
+
+  // Mask shape: the vector sizes followed by the remaining destination dims.
+  SmallVector<int64_t> writeMaskShape(inputVectorSizes.begin(),
+                                      inputVectorSizes.end());
+  writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
+                        destShape.end());
+  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
+  Value maskForWrite =
+      builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
+}
+
+/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
+/// padding value into
+/// transfer_write_in_bounds(
+///     transpose(
+///         shape_cast(
+///             transfer_read_masked(pack_source, pad_value))))
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  auto padValue = packOp.getPaddingValue();
+  if (!padValue) {
+    padValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+  }
+  ReifiedRankedShapedTypeDims reifiedReturnShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedReturnShapes);
+  (void)status; // prevent unused variable warning on non-assert builds
+  assert(succeeded(status) && "failed to reify result shapes");
+
+  // Create masked TransferReadOp
----------------
dcaballe wrote:

Super nit: please end comments with a period, per the coding standard (here and below).

https://github.com/llvm/llvm-project/pull/78660


More information about the Mlir-commits mailing list