[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)

Han-Chung Wang llvmlistbot at llvm.org
Tue Feb 6 21:15:49 PST 2024


================
@@ -1393,6 +1394,130 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerm to these 4 Ops:
+///   vector::TransferReadOp - Reads the vector array of source data.
+///   vector::TransposeOp - Transposes the source.
+///   vector::ShapeCastOp - Reshapes the data based on the target.
+///   vector::TransferWriteOp - Writes the result vector back.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+  // Handling this case requires a bit more change. Right now
+  // just the required attributes are handled.
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+  for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
+    readMaskShape[ii] = inputVectorSizes[ii];
+  }
+
+  // ReadMask is the size of tensor used to read and apply mask. It is
+  // set like this. Let's say the vectorSize (VS) array is size 'N' and
+  // the sourceShape(SS) is 'M' where M >= N
+  // Thus:
+  // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+  auto vectorType =
+      VectorType::get(readMaskShape, unpackTensorType.getElementType());
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+  int64_t unpackRank = unpackTensorType.getRank();
+  arith::ConstantIndexOp zeroOp =
+      rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+
+  vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+      unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+      SmallVector<Value>(unpackRank, zeroOp),
+      rewriter.getMultiDimIdentityMap(unpackRank));
+
+  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      unpackOp.getLoc(), readMaskType,
+      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+  vector::MaskOp maskedOp =
+      cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+  llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
+  PackingMetadata packMetadata =
+      computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+      unpackRank, lastDims, packMetadata.insertPositions);
+  ShapedType maskedOpShapedType =
+      cast<ShapedType>(maskedOp.getResult(0).getType());
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+          .setShape(stripMineShape);
+
+  // Collapse the tensor to the size required by result.
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  auto vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+
+  // Transpose the appropriate rows to match output.
+  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+      unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+
+  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
+  tensor::EmptyOp emptyOp =
+      rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
+                                       unpackTensorType.getElementType());
+
+  int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
+  Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
+      unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+      SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
+  auto resultShape = unpackOp.getResult().getType().getShape();
+
+  // If the shape of the result doesn't match the inputVectorSizes, a mask
+  // is necessary.
+  bool needMaskForWrite =
+      llvm::any_of(llvm::zip_equal(inputVectorSizes, resultShape),
+                   [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  mlir::OpResult result = writeOp->getResult(0);
+  if (needMaskForWrite) {
+    SmallVector<int64_t> writeMaskShape(inputVectorSizes);
+    llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+    llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+    for (auto [index, size] : enumerate(innerTiles)) {
+      writeMaskShape[innerDimPos[index]] *= size;
+    }
+    // WriteMaskShape is computed using the vectorSizes, inner Dim Position and
+    // innerTiles.
+    // WriteMaskShape (WMS) initialized to [inputVectorSizes]
+    // for-each index, value in inner-Tiles vector:
+    //      WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * value
+    auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
+    Value writeMask = rewriter.create<vector::CreateMaskOp>(
+        unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
+    Operation *writeOpWithMask =
+        mlir::vector::maskOperation(rewriter, writeOp, writeMask);
+    result = writeOpWithMask->getResult(0);
+  }
----------------
hanhanW wrote:

Same here: can we try to refactor/generalize it a bit more, like what Max has done in his PR?

https://github.com/llvm/llvm-project/pull/76087


More information about the Mlir-commits mailing list