[Mlir-commits] [mlir] [mlir][tensor] Rewrite tensor.pack as a constant (PR #93954)
Han-Chung Wang
llvmlistbot at llvm.org
Fri May 31 10:54:22 PDT 2024
================
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
}
};
+/// Rewrite tensor.pack with arith.constant if the pack is writing
+/// to an empty tensor and the destination shape is static.
+struct PackToConstant : OpRewritePattern<tensor::PackOp> {
+ using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+ if (!constOp)
+ return failure();
+ // Must be a dense constant.
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!denseAttr)
+ return failure();
+
+ // Bail out if the pack is used as a writing operation i.e.,
+ // the destination is not a tensor.empty.
+ if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+ return rewriter.notifyMatchFailure(packOp,
+ "expects empty tensor destination");
+ // Pack destination must have static shape.
+ if (!packOp.getDestType().hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ packOp, "expects destination with static shape");
+
+ // Pack with padding is not supported currently.
+ // TODO: Insert padding values as a part of rewrite.
+ if (packOp.getPaddingValue())
+ return rewriter.notifyMatchFailure(packOp, "expects no padding value");
----------------
hanhanW wrote:
nit: perhaps state in the failure message that this is NIY (not implemented yet), e.g. "NIY: pack with padding value".
https://github.com/llvm/llvm-project/pull/93954
More information about the Mlir-commits mailing list