[Mlir-commits] [mlir] [mlir][tensor] Rewrite tensor.pack as a constant (PR #93954)
Han-Chung Wang
llvmlistbot at llvm.org
Fri May 31 10:54:21 PDT 2024
================
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
}
};
+/// Rewrite tensor.pack as an arith.constant if the pack writes into
+/// an empty tensor and the destination shape is static.
+struct PackToConstant : OpRewritePattern<tensor::PackOp> {
+ using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+ if (!constOp)
+ return failure();
+ // Must be a dense constant.
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!denseAttr)
+ return failure();
+
+ // Bail out if the pack is used as a writing operation, i.e.,
+ // the destination is not a tensor.empty.
+ if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+ return rewriter.notifyMatchFailure(packOp,
+ "expects empty tensor destination");
+ // Pack destination must have static shape.
+ if (!packOp.getDestType().hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ packOp, "expects destination with static shape");
+
+ // Pack with padding is not supported currently.
+ // TODO: Insert padding values as a part of rewrite.
+ if (packOp.getPaddingValue())
+ return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+ OpBuilder::InsertionGuard guard(rewriter);
+
+ // If it is a splat constant, rewrite the pack directly.
+ if (denseAttr.isSplat()) {
+ DenseElementsAttr packedDenseShape =
+ denseAttr.reshape(packOp.getDestType());
+ rewriter.setInsertionPoint(constOp);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+ return success();
+ }
+
+ // Constant contains non-splat dense values.
+ // Move the data into a new packed buffer. Each value is placed into its new
+ // position as defined by the pack operation.
+ ArrayRef<char> srcRawData = denseAttr.getRawData();
+ SmallVector<char> destRawData(srcRawData.size());
+
+ int64_t numberOfElements = denseAttr.getNumElements();
+ SmallVector<int64_t> strides =
+ computeStrides(packOp.getDestType().getShape());
+
+ // Parallelize raw data movement to speed up large constant packing.
+ parallelFor(
+ packOp.getContext(), 0, numberOfElements,
+ [&](size_t destLinearizedIdx) {
+ // Step 1: De-linearize destination index.
+ // f(lin) = tmp[A][B][C]
+ SmallVector<int64_t> destIndices =
+ delinearize(destLinearizedIdx, strides);
+
+ // Step 2: Arrange the indices based on the packing information.
+ // Compute the inverse of outerDimsPerm to bring the loops into the
+ // canonical form tmp[A][B][a][b].
+ if (!packOp.getOuterDimsPerm().empty()) {
+ SmallVector<int64_t> inversePermutation =
+ invertPermutationVector(packOp.getOuterDimsPerm());
+ SmallVector<int64_t> tileLoops;
+ for (int64_t i = 0; i < packOp.getSourceType().getRank(); i++)
+ tileLoops.push_back(destIndices[i]);
+ applyPermutationToVector(tileLoops, inversePermutation);
+
+ SmallVector<int64_t> pointLoops;
+ for (size_t i = packOp.getSourceType().getRank();
+ i < destIndices.size(); i++) {
+ pointLoops.push_back(destIndices[i]);
+ }
+
+ destIndices = tileLoops;
+ destIndices.append(pointLoops.begin(), pointLoops.end());
+ }
+ assert(destIndices.size() ==
+ static_cast<size_t>(packOp.getDestType().getRank()));
----------------
hanhanW wrote:
nit: use `getDestRank()` method.
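
To make the index remapping in the quoted hunk concrete, here is a small worked example; the shapes and permutation are made-up illustrations, not values from the patch or its tests. Assume the destination type is tensor<2x2x2x2xi32> (a 4x4 source tiled by 2x2) and outer_dims_perm = [1, 0]:

    destLinearizedIdx = 5
    strides     = [8, 4, 2, 1]      // computeStrides([2, 2, 2, 2])
    destIndices = [0, 1, 0, 1]      // Step 1: delinearize(5, strides)
    inverse     = [1, 0]            // invertPermutationVector([1, 0])
    tileLoops   = [0, 1] -> [1, 0]  // Step 2: applyPermutationToVector
    pointLoops  = [0, 1]
    destIndices = [1, 0, 0, 1]      // tileLoops followed by pointLoops

The canonical-form indices are then presumably used, later in the patch and beyond the quoted hunk, to locate the source element that lands at this destination position.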
https://github.com/llvm/llvm-project/pull/93954
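
For readers following along, a minimal sketch of the IR this pattern targets and the constant it would produce; the 4x4 i32 source, the 2x2 inner tiles, and the element values are illustrative assumptions, not taken from the patch:

    // Before: a dense constant packed into an empty destination.
    %cst  = arith.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7],
                                  [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi32>
    %init = tensor.empty() : tensor<2x2x2x2xi32>
    %pack = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [2, 2]
        into %init : tensor<4x4xi32> -> tensor<2x2x2x2xi32>

    // After: the pack folds into a reshuffled constant, where
    // dest[i][j][a][b] = src[i * 2 + a][j * 2 + b].
    %packed = arith.constant dense<[[[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
                                    [[[8, 9], [12, 13]], [[10, 11], [14, 15]]]]>
                                    : tensor<2x2x2x2xi32>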