[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)

Diego Caballero llvmlistbot at llvm.org
Fri Feb 9 09:02:13 PST 2024


================
@@ -1393,6 +1394,130 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+///   vector::TransferReadOp - Reads the vector array of source data.
+///   vector::TransposeOp - Transposes the source.
+///   vector::ShapeCastOp - Reshapes the data based on the target.
+///   vector::TransferWriteOp - Writes the result vector back.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+  // Handling this case requires a bit more change. Right now
+  // just the required attributes are handled.
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+  for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
+    readMaskShape[ii] = inputVectorSizes[ii];
+  }
+
+  // readMaskShape is the shape of the vector used both to read the source
+  // and to build the mask. It is set as follows: let the vectorSizes (VS)
+  // array have size 'N' and the sourceShape (SS) have size 'M', where M >= N.
+  // Then:
+  // readMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+  auto vectorType =
+      VectorType::get(readMaskShape, unpackTensorType.getElementType());
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+  int64_t unpackRank = unpackTensorType.getRank();
+  arith::ConstantIndexOp zeroOp =
+      rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+
+  vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+      unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+      SmallVector<Value>(unpackRank, zeroOp),
+      rewriter.getMultiDimIdentityMap(unpackRank));
+
+  auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      unpackOp.getLoc(), readMaskType,
----------------
dcaballe wrote:

Still seeing `unpackOp.getLoc()` in multiple places

https://github.com/llvm/llvm-project/pull/76087


More information about the Mlir-commits mailing list