[Mlir-commits] [mlir] [mlir][tensor] Rewrite tensor.pack as a constant (PR #93954)
Han-Chung Wang
llvmlistbot at llvm.org
Fri May 31 10:54:21 PDT 2024
================
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
}
};
+/// Rewrite tensor.pack with arith.constant if the pack is writing
+/// to an empty tensor and the destination shape is static.
+struct PackToConstant : OpRewritePattern<tensor::PackOp> {
+ using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+ if (!constOp)
+ return failure();
+ // Must be a dense constant.
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!denseAttr)
+ return failure();
+
+ // Bail out if the pack is used as a writing operation i.e.,
+ // the destination is not a tensor.empty.
+ if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+ return rewriter.notifyMatchFailure(packOp,
+ "expects empty tensor destination");
+ // Pack destination must have static shape.
+ if (!packOp.getDestType().hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ packOp, "expects destination with static shape");
+
+ // Pack with padding is not supported currently.
+ // TODO: Insert padding values as a part of rewrite.
+ if (packOp.getPaddingValue())
+ return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+ OpBuilder::InsertionGuard guard(rewriter);
+
+ // If it is a splat constant, rewrite the pack directly.
+ if (denseAttr.isSplat()) {
+ DenseElementsAttr packedDenseShape =
+ denseAttr.reshape(packOp.getDestType());
+ rewriter.setInsertionPoint(constOp);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+ return success();
+ }
----------------
hanhanW wrote:
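This splat-constant case is already handled by the existing folders (see the link below), so this branch is redundant and can be dropped from the pattern.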
https://github.com/llvm/llvm-project/blob/07bd43945789e3fc8f57d21484a7f683d17166f3/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L4278-L4287
https://github.com/llvm/llvm-project/pull/93954
More information about the Mlir-commits
mailing list