[Mlir-commits] [mlir] [mlir][vector] add tensor.concat, bitcast, expand_shape, collapse_shape vectorization support (PR #97297)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Jul 9 02:11:44 PDT 2024


================
@@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   return success();
 }
 
+/// Vectorize a `tensor::ExpandShapeOp` or `tensor::CollapseShapeOp` into
+/// these 3 ops:
+///   vector::TransferReadOp - reads a vector from the source tensor
+///   vector::ShapeCastOp - reshapes the data based on the target shape
+///   vector::TransferWriteOp - writes the result vector back to the
+///   destination tensor
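+///
+/// For example (a sketch; masking and in_bounds details elided):
+///   %e = tensor.expand_shape %src [[0, 1]] output_shape [2, 2]
+///       : tensor<4xf32> into tensor<2x2xf32>
+/// becomes roughly:
+///   %r = vector.transfer_read %src[%c0], %pad : tensor<4xf32>, vector<4xf32>
+///   %s = vector.shape_cast %r : vector<4xf32> to vector<2x2xf32>
+///   %w = vector.transfer_write %s, %empty[%c0, %c0]
+///       : vector<2x2xf32>, tensor<2x2xf32>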
+static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
+                                        Operation *inputOp,
+                                        ArrayRef<int64_t> inputVectorSizes,
+                                        SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(inputOp);
+  auto src = inputOp->getOperand(0);
+  auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
+  auto result = inputOp->getResults()[0];
+  auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
+  ArrayRef<int64_t> resultShape = resultType.getShape();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  Location loc = inputOp->getLoc();
+
+  llvm::SmallVector<int64_t> srcVectorizedShape;
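+  // `shapeScales` maps a result-dim index to the factor by which the
+  // requested vector size exceeds that dim's static size.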
+  llvm::SmallDenseMap<int64_t, int64_t> shapeScales;
+
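+  // Computes `srcVectorizedShape` by accumulating dims of `inputShape` until
+  // they cover the next dim of `retShape`, folding in any requested scales.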
+  auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
+                               ArrayRef<int64_t> &inputShape) {
+    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+
+    int64_t cur = 1, resultIdx = 0;
+    for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
+      cur *= ss;
+      if (!isResultShapeBigger) {
+        // Collapse: keep each source dim; fold any scale into the source
+        // dim that completes a result dim.
+        srcVectorizedShape.emplace_back(ss);
+        if (cur == retShape[resultIdx]) {
+          if (shapeScales.count(resultIdx)) {
+            srcVectorizedShape.back() *= shapeScales[resultIdx];
+          }
+          cur = 1;
+          resultIdx++;
+        }
+      } else {
+        // Expand: accumulate result dims until they cover one source dim,
+        // then emit that dim, folding in any scale.
+        if (cur == retShape[resultIdx]) {
+          srcVectorizedShape.emplace_back(cur);
+          if (shapeScales.count(srcIdx)) {
+            srcVectorizedShape.back() *= shapeScales[srcIdx];
+          }
+          cur = 1;
+          resultIdx++;
+        }
+      }
+    }
+  };
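+  // With explicit vector sizes, record per-dim scale factors and derive the
+  // vectorized read shape of the source; otherwise use the source shape.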
+  if (!inputVectorSizes.empty()) {
+    for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) {
+      if (vs != resultShape[idx])
+        shapeScales[idx] = vs / resultShape[idx];
+    }
+
+    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+    if (!isResultShapeBigger) {
+      getVectorizeShape(resultShape, srcShape);
+    } else {
+      getVectorizeShape(srcShape, resultShape);
+    }
+  } else {
+    srcVectorizedShape.assign(srcShape.begin(), srcShape.end());
+  }
+  // Read the source tensor into a vector, padding with zeros where the
+  // vector shape exceeds the static tensor shape.
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(srcType.getElementType()));
+  Value readResult = vector::createReadOrMaskedRead(
+      rewriter, loc, src,
+      inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape,
+      padValue, false);
+
+  auto shapeCastType =
+      VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes,
+                      resultType.getElementType());
+  vector::ShapeCastOp shapeCastOp =
+      rewriter.create<vector::ShapeCastOp>(loc, shapeCastType, readResult);
+
+  // Write the reshaped vector into an empty tensor of the result shape.
+  SmallVector<OpFoldResult> destSizes;
+  for (auto size : resultShape) {
+    destSizes.emplace_back(rewriter.getIndexAttr(size));
+  }
----------------
ftynse wrote:

Reserve space before appending in a loop. Or better, use a proper combinator like `map_to_vector`.
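
For instance, a sketch with `llvm::map_to_vector` (from
`llvm/ADT/SmallVectorExtras.h`):

```c++
// Builds destSizes in one shot; the range constructor reserves up front.
SmallVector<OpFoldResult> destSizes =
    llvm::map_to_vector(resultShape, [&](int64_t size) -> OpFoldResult {
      return rewriter.getIndexAttr(size);
    });
```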

https://github.com/llvm/llvm-project/pull/97297

