[Mlir-commits] [mlir] [mlir][vector] add tensor.concat, bitcast, expand_shape, collapse_shape vectorization support (PR #97297)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Jul 9 02:11:44 PDT 2024


================
@@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   return success();
 }
 
+/// Vectorize a `tensor::expand_shape` or `tensor::collapse_shape` into these
+/// three ops:
+///   vector::TransferReadOp  - reads a vector from the source tensor.
+///   vector::ShapeCastOp     - reshapes the data to match the result shape.
+///   vector::TransferWriteOp - writes the result vector back to the destination
+///   tensor.
+static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
+                                        Operation *inputOp,
+                                        ArrayRef<int64_t> inputVectorSizes,
+                                        SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(inputOp);
+  auto src = inputOp->getOperand(0);
+  auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
+  auto result = inputOp->getResults()[0];
+  auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
+  ArrayRef<int64_t> resultShape = resultType.getShape();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  Location loc = inputOp->getLoc();
+
+  llvm::SmallVector<int64_t> srcVectorizedShape;
+  llvm::SmallDenseMap<int64_t, int64_t> shapeScales;
+
+  auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
+                               ArrayRef<int64_t> &inputShape) {
+    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+
+    int64_t cur = 1, resultIdx = 0;
+    for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
+      cur *= ss;
+      if (!isResultShapeBigger) {
+        // collapse
+        srcVectorizedShape.emplace_back(ss);
+        if (cur == retShape[resultIdx]) {
+          if (shapeScales.count(resultIdx)) {
+            srcVectorizedShape.back() *= shapeScales[resultIdx];
+          }
+          cur = 1;
+          resultIdx++;
+        }
----------------
ftynse wrote:

https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code

https://github.com/llvm/llvm-project/pull/97297


More information about the Mlir-commits mailing list