[Mlir-commits] [mlir] [mlir][sparse] add a sparse_tensor.print operation (PR #83321)
Peiming Liu
llvmlistbot at llvm.org
Wed Feb 28 12:02:25 PST 2024
================
@@ -598,6 +600,96 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
}
};
+/// Sparse rewriting rule for the print operator. This operation is mainly used
+/// for debugging and testing. As such, it lowers to the vector.print operation
+/// which only requires very light-weight runtime support.
+struct PrintRewriter : public OpRewritePattern<PrintOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(PrintOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto tensor = op.getTensor();
+ auto stt = getSparseTensorType(tensor);
+ // Header with NSE.
+ auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
+ rewriter.create<vector::PrintOp>(
+ loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
+ rewriter.create<vector::PrintOp>(loc, nse);
+ // Use the "codegen" foreach loop construct to iterate over
+ // all typical sparse tensor components for printing.
+ foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
+ &tensor](Type tp, FieldIndex,
+ SparseTensorFieldKind kind,
+ Level l, LevelType) {
+ switch (kind) {
+ case SparseTensorFieldKind::StorageSpec: {
+ break;
+ }
+ case SparseTensorFieldKind::PosMemRef: {
+ auto lvl = constantIndex(rewriter, loc, l);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
+ rewriter.create<vector::PrintOp>(
+ loc, lvl, vector::PrintPunctuation::NoPunctuation);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+ auto poss = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
+ printContents(rewriter, loc, tp, poss);
+ break;
+ }
+ case SparseTensorFieldKind::CrdMemRef: {
+ auto lvl = constantIndex(rewriter, loc, l);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
+ rewriter.create<vector::PrintOp>(
+ loc, lvl, vector::PrintPunctuation::NoPunctuation);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+ auto crds = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
+ printContents(rewriter, loc, tp, crds);
+ break;
+ }
+ case SparseTensorFieldKind::ValMemRef: {
+ rewriter.create<vector::PrintOp>(loc,
+ rewriter.getStringAttr("values : "));
+ auto vals = rewriter.create<ToValuesOp>(loc, tp, tensor);
+ printContents(rewriter, loc, tp, vals);
+ break;
+ }
+ }
+ return true;
+ });
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ // Helper to print contents of a single memref. Note that for the "push_back"
+ // vectors, this prints the full capacity, not just the size. This is done
+ // on purpose, so that clients see how much storage has been allocated in
+ // total. Contents of the extra capacity in the buffer may be uninitialized
+ // (unless the flag enable-buffer-initialization is set to true).
+ //
+ // Generates code to print:
+ // ( a0, a1, ... )
+ static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
+ Value vec) {
+ // Open bracket.
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+ // For loop over elements.
+ auto zero = constantIndex(rewriter, loc, 0);
+ auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+ auto step = constantIndex(rewriter, loc, 1);
+ auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
----------------
PeimingLiu wrote:
It might need a nested loop to handle batch levels :)
https://github.com/llvm/llvm-project/pull/83321
More information about the Mlir-commits
mailing list