[Mlir-commits] [mlir] [mlir][sparse] add a sparse_tensor.print operation (PR #83321)

Yinying Li llvmlistbot at llvm.org
Wed Feb 28 11:48:08 PST 2024


================
@@ -598,6 +600,96 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
   }
 };
 
+/// Sparse rewriting rule for the print operator. This operation is mainly used
+/// for debugging and testing. As such, it lowers to the vector.print operation
+/// which only requires very light-weight runtime support.
+struct PrintRewriter : public OpRewritePattern<PrintOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrintOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto tensor = op.getTensor();
+    auto stt = getSparseTensorType(tensor);
+    // Header with NSE.
+    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
+    rewriter.create<vector::PrintOp>(
+        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
+    rewriter.create<vector::PrintOp>(loc, nse);
+    // Use the "codegen" foreach loop construct to iterate over
+    // all typical sparse tensor components for printing.
+    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
+                                            &tensor](Type tp, FieldIndex,
+                                                     SparseTensorFieldKind kind,
+                                                     Level l, LevelType) {
+      switch (kind) {
+      case SparseTensorFieldKind::StorageSpec: {
+        break;
+      }
+      case SparseTensorFieldKind::PosMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto poss = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
----------------
yinying-lisa-li wrote:

I personally prefer pos as the variable name; poss took me a second to grasp. But I know it's consistent with other variable names like vals, so it's up to you. :)

https://github.com/llvm/llvm-project/pull/83321

