[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)

Han-Chung Wang llvmlistbot at llvm.org
Tue Feb 6 21:15:48 PST 2024


================
@@ -1393,6 +1394,130 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+///   vector::TransferReadOp - Reads a vector of the source data.
+///   vector::TransposeOp - Transposes the source.
+///   vector::ShapeCastOp - Reshapes the data based on the target.
+///   vector::TransferWriteOp - Writes the result vector back.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+  // Handling this case requires a bit more change. Right now
+  // just the required attributes are handled.
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+  for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
+    readMaskShape[ii] = inputVectorSizes[ii];
+  }
+
+  // ReadMask is the size of tensor used to read and apply mask. It is
+  // set like this. Let's say the vectorSize (VS) array is size 'N' and
+  // the sourceShape(SS) is 'M' where M >= N
+  // Thus:
+  // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+  auto vectorType =
+      VectorType::get(readMaskShape, unpackTensorType.getElementType());
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+  int64_t unpackRank = unpackTensorType.getRank();
+  arith::ConstantIndexOp zeroOp =
+      rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+
+  vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+      unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+      SmallVector<Value>(unpackRank, zeroOp),
+      rewriter.getMultiDimIdentityMap(unpackRank));
+
+  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      unpackOp.getLoc(), readMaskType,
+      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+  vector::MaskOp maskedOp =
+      cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+  llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
+  PackingMetadata packMetadata =
+      computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+      unpackRank, lastDims, packMetadata.insertPositions);
+  ShapedType maskedOpShapedType =
+      cast<ShapedType>(maskedOp.getResult(0).getType());
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+          .setShape(stripMineShape);
+
+  // Collapse the tensor to the size required by result.
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  auto vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
----------------
hanhanW wrote:

Please move these to right before where they are used. The comment seems to be outdated: you are doing a shape_cast on vectors, not tensors.

https://github.com/llvm/llvm-project/pull/76087


More information about the Mlir-commits mailing list