[Mlir-commits] [mlir] [mlir][sparse] add a sparse_tensor.print operation (PR #83321)
Aart Bik
llvmlistbot at llvm.org
Wed Feb 28 12:11:29 PST 2024
================
@@ -598,6 +600,96 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
}
};
+/// Sparse rewriting rule for the print operator. This operation is mainly used
+/// for debugging and testing. As such, it lowers to the vector.print operation
+/// which only requires very lightweight runtime support.
+struct PrintRewriter : public OpRewritePattern<PrintOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrintOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto tensor = op.getTensor();
+    auto stt = getSparseTensorType(tensor);
+    // Header with NSE.
+    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
+    rewriter.create<vector::PrintOp>(
+        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
+    rewriter.create<vector::PrintOp>(loc, nse);
+    // Use the "codegen" foreach loop construct to iterate over
+    // all typical sparse tensor components for printing.
+    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
+                                            &tensor](Type tp, FieldIndex,
+                                                     SparseTensorFieldKind kind,
+                                                     Level l, LevelType) {
+      switch (kind) {
+      case SparseTensorFieldKind::StorageSpec: {
+        break;
+      }
+      case SparseTensorFieldKind::PosMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto poss = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
----------------
aartbik wrote:
Agreed. I went for consistency rather than beauty here. Perhaps the compromise is to remove the trailing "s" from all of these names; let me do that.
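
For readers following the hunk without an MLIR build at hand, here is a minimal standalone sketch of the same visitor-style printing, written in plain C++ rather than against the MLIR API. Everything in it is an assumption made for illustration: `FieldKind`, `CsrStorage`, `forEachField`, `printArray`, and `renderSparse` are hypothetical names, the storage is a hand-written CSR matrix, and only the header string and the `pos[` ... `] : ` prefix are taken from the diff above. The `crd[` and `values : ` labels and the parenthesized array layout are guesses, not the output of the actual lowering.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for SparseTensorFieldKind; this is NOT the MLIR enum.
enum class FieldKind { PosMemRef, CrdMemRef, ValMemRef };

// Hypothetical CSR storage: level 0 is dense, level 1 is compressed, so there
// is a single positions array and a single coordinates array.
struct CsrStorage {
  std::vector<std::int64_t> pos;
  std::vector<std::int64_t> crd;
  std::vector<double> val;
};

using FieldCallback = std::function<void(FieldKind kind, int level)>;

// Visits each storage field once, loosely mirroring the shape of the
// "foreach field" helper used in the diff (assumed, simplified ordering).
void forEachField(const FieldCallback &callback) {
  callback(FieldKind::PosMemRef, 1);
  callback(FieldKind::CrdMemRef, 1);
  callback(FieldKind::ValMemRef, 0);
}

// Writes a vector as "( a, b, c )" followed by a newline (assumed layout).
template <typename T>
void printArray(std::ostream &os, const std::vector<T> &data) {
  os << "( ";
  for (std::size_t i = 0; i < data.size(); ++i) {
    if (i != 0)
      os << ", ";
    os << data[i];
  }
  os << " )\n";
}

// Renders the header with the number of stored entries, then one line per
// storage field, dispatching on the field kind as the rewriter's switch does.
std::string renderSparse(const CsrStorage &t) {
  std::ostringstream os;
  os << "---- Sparse Tensor ----\nnse = " << t.val.size() << "\n";
  forEachField([&os, &t](FieldKind kind, int level) {
    switch (kind) {
    case FieldKind::PosMemRef:
      os << "pos[" << level << "] : ";
      printArray(os, t.pos);
      break;
    case FieldKind::CrdMemRef:
      os << "crd[" << level << "] : "; // label is an assumption
      printArray(os, t.crd);
      break;
    case FieldKind::ValMemRef:
      os << "values : "; // label is an assumption
      printArray(os, t.val);
      break;
    }
  });
  return os.str();
}
```

Calling `renderSparse` on a `CsrStorage` holding `pos = {0, 2, 3}`, `crd = {0, 2, 1}`, and `val = {1.5, 2.5, 3.5}` yields the header, an `nse = 3` line, and one line per field. The sketch uses singular field names throughout (`pos`, `crd`, `val`), in the spirit of the naming compromise discussed here.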
https://github.com/llvm/llvm-project/pull/83321