[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)

Balaji V. Iyer. llvmlistbot at llvm.org
Thu Feb 15 14:20:53 PST 2024


================
@@ -1559,6 +1571,90 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   return success();
 }
 
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms into these 4 ops:
+///   vector::TransferReadOp  - Reads a vector from the source tensor.
+///   vector::TransposeOp     - Transposes the vector that was read.
+///   vector::ShapeCastOp     - Reshapes the data based on the target.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///                             destination tensor.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+
+  SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+  llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+  llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+  for (unsigned int i = 0; i < inputVectorSizes.size(); i++) {
+    readMaskShape[i] = inputVectorSizes[i];
+  }
+  for (auto [index, size] : enumerate(innerTiles)) {
+    readMaskShape[innerDimPos[index]] =
+        llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+  }
+
+  // ReadMask is the size of tensor used to read and apply mask. It is
+  // set like this. Let's say the vectorSize (VS) array is size 'N' and
+  // the sourceShape(SS) is 'M' where M >= N
+  // Thus:
+  // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+  Location loc = unpackOp->getLoc();
+
+  // Read result, mask if necessary.
+  Value readResult = createReadOrMaskedRead(
+      rewriter, loc, unpackOp.getSource(),
+      llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
+      nullptr);
+
+  PackingMetadata packMetadata;
+  SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
+      tensor::getPackUnPackInverseDestPerm(unpackOp, packMetadata));
+  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+          .setShape(stripMineShape);
----------------
bviyer wrote:

Fixed

https://github.com/llvm/llvm-project/pull/76087


More information about the Mlir-commits mailing list