[Mlir-commits] [mlir] [mlir][sparse] add a sparse_tensor.print operation (PR #83321)

Peiming Liu llvmlistbot at llvm.org
Wed Feb 28 12:02:25 PST 2024


================
@@ -598,6 +600,96 @@ struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
   }
 };
 
+/// Sparse rewriting rule for the print operator. This operation is mainly used
+/// for debugging and testing. As such, it lowers to the vector.print operation,
+/// which requires only very lightweight runtime support.
+struct PrintRewriter : public OpRewritePattern<PrintOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrintOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto tensor = op.getTensor();
+    auto stt = getSparseTensorType(tensor);
+    // Header with NSE.
+    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
+    rewriter.create<vector::PrintOp>(
+        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
+    rewriter.create<vector::PrintOp>(loc, nse);
+    // Use the "codegen" foreach loop construct to iterate over
+    // all typical sparse tensor components for printing.
+    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
+                                            &tensor](Type tp, FieldIndex,
+                                                     SparseTensorFieldKind kind,
+                                                     Level l, LevelType) {
+      switch (kind) {
+      case SparseTensorFieldKind::StorageSpec: {
+        break;
+      }
+      case SparseTensorFieldKind::PosMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto poss = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, poss);
+        break;
+      }
+      case SparseTensorFieldKind::CrdMemRef: {
+        auto lvl = constantIndex(rewriter, loc, l);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
+        rewriter.create<vector::PrintOp>(
+            loc, lvl, vector::PrintPunctuation::NoPunctuation);
+        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+        auto crds = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
+        printContents(rewriter, loc, tp, crds);
+        break;
+      }
+      case SparseTensorFieldKind::ValMemRef: {
+        rewriter.create<vector::PrintOp>(loc,
+                                         rewriter.getStringAttr("values : "));
+        auto vals = rewriter.create<ToValuesOp>(loc, tp, tensor);
+        printContents(rewriter, loc, tp, vals);
+        break;
+      }
+      }
+      return true;
+    });
+    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  // Helper to print contents of a single memref. Note that for the "push_back"
+  // vectors, this prints the full capacity, not just the size. This is done
+  // on purpose, so that clients see how much storage has been allocated in
+  // total. Contents of the extra capacity in the buffer may be uninitialized
+  // (unless the flag enable-buffer-initialization is set to true).
+  //
+  // Generates code to print:
+  //    ( a0, a1, ... )
+  static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
+                            Value vec) {
+    // Open bracket.
+    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+    // For loop over elements.
+    auto zero = constantIndex(rewriter, loc, 0);
+    auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+    auto step = constantIndex(rewriter, loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
----------------
PeimingLiu wrote:

It might need a nested loop to handle batch levels :)

https://github.com/llvm/llvm-project/pull/83321


More information about the Mlir-commits mailing list