[Mlir-commits] [mlir] 550288c - [mlir][sparse] Add new utility class to help generates loop structures over sparse tensors; Implement foreach operator.

Peiming Liu llvmlistbot at llvm.org
Fri Sep 30 14:42:51 PDT 2022


Author: Peiming Liu
Date: 2022-09-30T21:42:42Z
New Revision: 550288cbc3967333886f87677ac8df93b928d32b

URL: https://github.com/llvm/llvm-project/commit/550288cbc3967333886f87677ac8df93b928d32b
DIFF: https://github.com/llvm/llvm-project/commit/550288cbc3967333886f87677ac8df93b928d32b.diff

LOG: [mlir][sparse] Add new utility class to help generates loop structures over sparse tensors; Implement foreach operator.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134782

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index e2d8e621655f3..89ae924cf8b27 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -8,12 +8,237 @@
 
 #include "CodegenUtils.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+/// Loads the pointer/index entry `ptr[s]` from the sparse storage scheme and
+/// returns it as an `index` value usable for looping and indexing. Overhead
+/// entries narrower than 64 bits are zero-extended to i64 first (a scalar
+/// extension carries no real cost), then index-cast. Entries that are
+/// already 64-bit are index-cast directly, so in theory they cannot express
+/// their full unsigned range.
+static Value genIndexLoad(OpBuilder &builder, Location loc, Value ptr,
+                          Value s) {
+  Value elem = builder.create<memref::LoadOp>(loc, ptr, s);
+  Type elemTp = elem.getType();
+  // Already an index: nothing to convert.
+  if (elemTp.isa<IndexType>())
+    return elem;
+  // Zero-extend narrower unsigned overhead values before the index cast.
+  if (elemTp.getIntOrFloatBitWidth() < 64)
+    elem = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), elem);
+  return builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), elem);
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor loop emitter class implementations
+//===----------------------------------------------------------------------===//
+
+SparseTensorLoopEmitter::SparseTensorLoopEmitter(ValueRange tensors)
+    : tensors(tensors.begin(), tensors.end()), dims(tensors.size()),
+      pidxs(tensors.size()), coord(tensors.size()), highs(tensors.size()),
+      sizes(tensors.size()), ptrBuffer(tensors.size()),
+      idxBuffer(tensors.size()), valBuffer(tensors.size()),
+      curLv(tensors.size(), 0) {
+  // Inside this body `tensors` names the constructor argument, which shadows
+  // the member of the same name.
+  for (size_t tid = 0, numTensors = tensors.size(); tid < numTensors; ++tid) {
+    auto tensorTp = tensors[tid].getType().cast<RankedTensorType>();
+    const auto rank = static_cast<size_t>(tensorTp.getRank());
+    // Level types come from the sparse encoding when there is one; tensors
+    // without an encoding are treated as dense in every dimension.
+    if (auto enc = getSparseTensorEncoding(tensorTp)) {
+      auto levelTps = enc.getDimLevelType();
+      dims[tid].assign(levelTps.begin(), levelTps.end());
+    } else {
+      dims[tid].assign(rank, SparseTensorEncodingAttr::DimLevelType::Dense);
+    }
+    // Every per-dimension slot starts out null; initializeLoopEmit() and the
+    // loop-entering methods fill them in later.
+    for (std::vector<Value> *slots : {&pidxs[tid], &coord[tid], &highs[tid],
+                                      &sizes[tid], &ptrBuffer[tid],
+                                      &idxBuffer[tid]})
+      slots->assign(rank, Value());
+  }
+}
+
+/// Starts a loop-emitting session. For every tensor this creates:
+///  - the pointer/indices buffers of each compressed dimension,
+///  - the size (and initial upper bound) of each stored dimension,
+///  - the values buffer (to_memref for dense, to_values for sparse tensors),
+/// and then prepares iteration over the first stored dimension.
+/// Must be called exactly once, before any loop is entered.
+void SparseTensorLoopEmitter::initializeLoopEmit(OpBuilder &builder,
+                                                 Location loc) {
+  // For every tensor, find lower and upper bound on dimensions, set the
+  // same bounds on loop indices, and obtain dense or sparse buffer(s).
+  // TODO: Provides ability to generate loop on output buffer (with undef
+  // dim level in Merger in GenericOp Sparsification).
+  //
+  // Shape shared by every 1-D dynamically sized buffer created below; it is
+  // loop-invariant, so declare it once.
+  auto dynShape = {ShapedType::kDynamicSize};
+  for (size_t t = 0, e = tensors.size(); t < e; t++) {
+    auto tensor = tensors[t];
+    auto rtp = tensor.getType().cast<RankedTensorType>();
+    auto rank = rtp.getRank();
+    auto shape = rtp.getShape();
+    auto enc = getSparseTensorEncoding(rtp);
+    // Scan all (stored) dimensions of the current tensor.
+    for (int64_t d = 0; d < rank; d++) {
+      // This should be called only once at beginning.
+      assert(!ptrBuffer[t][d] && !idxBuffer[t][d] && !sizes[t][d] &&
+             !highs[t][d]);
+      // Handle sparse storage schemes.
+      if (isCompressedDim(dims[t][d])) {
+        auto ptrTp =
+            MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
+        auto indTp =
+            MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+        auto dim = builder.getIndexAttr(d);
+        // Generate sparse primitives to obtain the pointer and index arrays.
+        ptrBuffer[t][d] = builder.create<ToPointersOp>(loc, ptrTp, tensor, dim);
+        idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
+      } else if (isSingletonDim(dims[t][d])) {
+        llvm_unreachable("TODO: not implemented yet");
+      }
+
+      // The size of stored dimension `d` is the size of the original
+      // dimension it maps to. It is also the fixed upper bound of a dense
+      // dimension; compressed bounds are overwritten when prepared.
+      unsigned p = toOrigDim(enc, d);
+      Value up = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, p);
+      sizes[t][d] = highs[t][d] = up;
+    }
+    // Perform the required bufferization. Dense inputs materialize
+    // from the input tensors. Dense outputs need special handling.
+    // Sparse inputs use sparse primitives to obtain the values.
+    Type elementType = rtp.getElementType();
+
+    if (!enc) {
+      // Non-annotated dense tensors.
+      auto denseTp = MemRefType::get(shape, elementType);
+      // This is not the output tensor.
+      valBuffer[t] =
+          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
+    } else {
+      // Annotated sparse tensors.
+      auto sparseTp = MemRefType::get(dynShape, elementType);
+      valBuffer[t] = builder.create<ToValuesOp>(loc, sparseTp, tensor);
+    }
+    // Prepare to enter the first dim for all (input) tensors.
+    prepareLoopOverTensorAtDim(builder, loc, t, 0);
+  }
+}
+
+/// Emits an scf.for loop over stored dimension `dim` of tensor `tid` and
+/// leaves the builder's insertion point inside the new loop body.
+///
+/// A compressed dimension iterates positions [pidxs, highs), as set up by
+/// prepareLoopOverTensorAtDim(). Its coordinate is loaded from the indices
+/// array. A dense dimension iterates [0, highs) and its coordinate is the
+/// induction variable; its position is linearized from the parent level's
+/// position. Afterwards the next dimension (if any) is prepared and the loop
+/// is recorded on `loopStack`.
+///
+/// Returns the generated scf.for. A non-empty `reduc` is not supported yet
+/// and reaches llvm_unreachable.
+Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
+    OpBuilder &builder, Location loc, size_t tid, size_t dim,
+    ArrayRef<Value> reduc) {
+  assert(dims[tid].size() > dim);
+  // We can not re-enter the same level.
+  assert(!coord[tid][dim]);
+  Value step = constantIndex(builder, loc, 1);
+  bool isCompressed = isCompressedDim(dims[tid][dim]);
+  // Only dense and compressed dimensions are handled here.
+  assert(isDenseDim(dims[tid][dim]) || isCompressedDim(dims[tid][dim]));
+
+  // Compressed: start at the position loaded by prepareLoopOverTensorAtDim.
+  // Dense: start at coordinate 0.
+  Value lo = isCompressed ? pidxs[tid][dim] : constantIndex(builder, loc, 0);
+  Value hi = highs[tid][dim];
+
+  // TODO: support reduction.
+  if (!reduc.empty())
+    llvm_unreachable("TODO: not implemented yet");
+
+  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
+  // Everything built with `builder` from here on goes into the loop body.
+  builder.setInsertionPointToStart(forOp.getBody());
+  Value iv = forOp.getInductionVar();
+  Operation *loop = forOp;
+
+  assert(iv);
+  if (isCompressed) {
+    pidxs[tid][dim] = iv;
+    // Generating a load on the indices array yields the coordinate.
+    Value ptr = idxBuffer[tid][dim];
+    // TODO: generates load for vector value.
+    coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
+  } else {
+    // Dense dimension: the coordinate is the induction variable.
+    coord[tid][dim] = iv;
+    // Linearized position: pidx = parentPidx * size + iv, where the parent
+    // position is 0 at the outermost dimension.
+    // TODO: handle vector loop.
+    Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
+    Value mul = builder.create<arith::MulIOp>(loc, sizes[tid][dim], p);
+    Value add = builder.create<arith::AddIOp>(loc, mul, iv);
+    pidxs[tid][dim] = add;
+  }
+
+  // Prepares for next dim if this is not currently the innermost dimension.
+  if (dim != dims[tid].size() - 1)
+    prepareLoopOverTensorAtDim(builder, loc, tid, dim + 1);
+
+  // Record the loop so exitCurrentLoop() knows which state to reset.
+  loopStack.push_back(LoopLevelInfo({tid}, {dim}, coord[tid][dim]));
+  return loop;
+}
+
+/// Placeholder for co-iterating several (tensor, dim) pairs in one loop.
+/// Not implemented yet: reaching this is a programming error.
+void SparseTensorLoopEmitter::enterCoiterationOverTensorsAtDims(
+    OpBuilder & /*builder*/, Location /*loc*/, ArrayRef<size_t> /*ts*/,
+    ArrayRef<size_t> /*ds*/) {
+  llvm_unreachable("TODO: unimplemented");
+}
+
+/// Prepares iteration over stored dimension `dim` of tensor `tid` before its
+/// loop is entered.
+///
+/// For a compressed dimension, loads the position range [ptr[p], ptr[p + 1])
+/// into `pidxs`/`highs` and returns true. Here `p` is the parent dimension's
+/// current position, or 0 for the outermost dimension.
+///
+/// Dense dimensions need no preparation, since their bounds were fixed by
+/// initializeLoopEmit(), and return false. Singleton dimensions are not
+/// supported yet.
+bool SparseTensorLoopEmitter::prepareLoopOverTensorAtDim(OpBuilder &builder,
+                                                         Location loc,
+                                                         size_t tid,
+                                                         size_t dim) {
+  // TODO: generate loop iteration on output tensor based on the shape
+  // instead of pointer/indices arrays.
+  assert(dims[tid].size() > dim);
+
+  if (isDenseDim(dims[tid][dim]))
+    return false;
+
+  // Either the first dimension, or the parent position has been set.
+  assert(dim == 0 || pidxs[tid][dim - 1]);
+  if (isCompressedDim(dims[tid][dim])) {
+    Value ptr = ptrBuffer[tid][dim];
+    Value c1 = constantIndex(builder, loc, 1);
+    Value pLo = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
+    Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
+
+    pidxs[tid][dim] = genIndexLoad(builder, loc, ptr, pLo);
+    highs[tid][dim] = genIndexLoad(builder, loc, ptr, pHi);
+
+    return true;
+  }
+
+  if (isSingletonDim(dims[tid][dim]))
+    llvm_unreachable("TODO: not implemented yet");
+
+  llvm_unreachable("Unrecognizable dimension type!");
+}
+
+/// Placeholder for materializing locals of tensor `tid` at `dim` that the
+/// loops themselves do not provide (see the declaration's comment). Not
+/// implemented yet: reaching this is a programming error.
+Value SparseTensorLoopEmitter::emitExtraLocalsForTensorsAtDims(
+    OpBuilder & /*builder*/, Location /*loc*/, size_t /*tid*/,
+    size_t /*dim*/) {
+  llvm_unreachable("TODO: not implemented yet");
+}
+
+/// Pops the innermost loop recorded on `loopStack` and resets the iteration
+/// state of every (tensor, dim) pair that loop iterated. This covers the
+/// positions and coordinates, plus the upper bounds of non-dense dimensions.
+///
+/// Clearing these values makes accidental reuse of a stale value show up
+/// early (as a null Value) instead of silently producing wrong IR. No
+/// builder insertion point is modified here.
+void SparseTensorLoopEmitter::exitCurrentLoop() {
+  // loopStack.back() on an empty stack would be undefined behavior.
+  assert(!loopStack.empty() && "no loop to exit");
+  LoopLevelInfo &loopInfo = loopStack.back();
+  assert(loopInfo.tensors.size() == loopInfo.dims.size());
+  for (auto info : llvm::zip(loopInfo.tensors, loopInfo.dims)) {
+    auto tid = std::get<0>(info);
+    auto dim = std::get<1>(info);
+    assert(pidxs[tid][dim] && coord[tid][dim] && highs[tid][dim]);
+    // Reset to null.
+    pidxs[tid][dim] = Value();
+    coord[tid][dim] = Value();
+    // A dense upper bound is the fixed dimension size and stays valid. Other
+    // bounds were loaded for a specific parent position and must be
+    // re-prepared before the next loop over this dimension.
+    if (!isDenseDim(dims[tid][dim]))
+      highs[tid][dim] = Value();
+  }
+  loopStack.pop_back();
+}
+
 //===----------------------------------------------------------------------===//
 // ExecutionEngine/SparseTensorUtils helper functions.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 72ae4dea54d75..63456a8bcc2cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -28,9 +28,128 @@ class Value;
 namespace sparse_tensor {
 
 //===----------------------------------------------------------------------===//
-// ExecutionEngine/SparseTensorUtils helper functions.
+// SparseTensorLoopEmitter class, manages sparse tensors and helps to generate
+// loop structures to (co-)iterate sparse tensors.
+//
+// An example usage:
+// To generate the following loops over T1<?x?> and T2<?x?>
+//
+// for i : T1[0] {
+//   for j : T2[0] {
+//     for k : T1[1] {}
+//     for k : T2[1] {}
+//   }
+// }
+//
+// One can use
+//
+// SparseTensorLoopEmitter loopEmitter({T1, T2});
+// loopEmitter.initializeLoopEmit();
+// loopEmitter.enterLoopOverTensorAtDim(T1, 0);
+// loopEmitter.enterLoopOverTensorAtDim(T2, 0);
+// loopEmitter.enterLoopOverTensorAtDim(T1, 1);
+// loopEmitter.exitCurrentLoop();
+// loopEmitter.enterLoopOverTensorAtDim(T2, 1);
+// for 0 -> 3:
+//    loopEmitter.exitCurrentLoop();
 //===----------------------------------------------------------------------===//
 
+// TODO: Sparsification should also rely on this class to generate loops.
+class SparseTensorLoopEmitter {
+public:
+  /// Constructor: takes an array of input tensors on which the generated
+  /// loops will iterate. The index of a tensor in the array is also its
+  /// tensor id (tid) used by the other methods.
+  explicit SparseTensorLoopEmitter(ValueRange tensors);
+
+  ///
+  /// Core functions.
+  ///
+
+  /// Starts a loop emitting session:
+  /// 1. Generates all the buffers needed to iterate tensors.
+  /// 2. Generates the lo/hi bounds to iterate the first (stored) dimension
+  ///    of every tensor.
+  void initializeLoopEmit(OpBuilder &builder, Location loc);
+
+  // TODO: Gets rid of `dim` in the argument list? Track the dimension we
+  // are currently at internally. Then it would be enterNextDimForTensor.
+
+  /// Emits a loop over tensor[dim]; it assumes that loops over
+  /// tensor[0...dim - 1] have already been generated. It also prepares to
+  /// enter tensor[dim + 1]. The builder's insertion point is left inside the
+  /// new loop body.
+  Operation *enterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
+                                      size_t tid, size_t dim,
+                                      ArrayRef<Value> reduc = {});
+
+  /// Emits a coiteration loop over a set of tensors.
+  // TODO: not yet implemented
+  void enterCoiterationOverTensorsAtDims(OpBuilder &builder, Location loc,
+                                         ArrayRef<size_t> ts,
+                                         ArrayRef<size_t> ds);
+
+  /// Emits extra locals, since the locals might not be in the simplified
+  /// lattice points used to generate the loops, but are still required to
+  /// generate expressions.
+  Value emitExtraLocalsForTensorsAtDims(OpBuilder &builder, Location loc,
+                                        size_t tid, size_t dim);
+
+  /// Exits the innermost loop entered so far and resets its iteration state.
+  /// Does not move any builder's insertion point.
+  void exitCurrentLoop();
+
+  /// Appends the coordinate of every loop generated so far (outermost
+  /// first) to `coords`.
+  void getCoordinateArray(SmallVectorImpl<Value> &coords) const {
+    for (const auto &l : loopStack)
+      coords.push_back(l.idx);
+  }
+
+  ///
+  /// Getters.
+  ///
+
+  /// Returns the values buffer of tensor `tid` (set by initializeLoopEmit).
+  Value getTensorValueBuffer(size_t tid) const { return valBuffer[tid]; }
+  /// Returns the current position in the innermost stored dimension of
+  /// tensor `tid`; only meaningful inside the innermost loop.
+  Value getLastLevelTensorPointerIndex(size_t tid) const {
+    return pidxs[tid].back();
+  }
+
+private:
+  /// Bookkeeping for one emitted loop.
+  struct LoopLevelInfo {
+    LoopLevelInfo(ArrayRef<size_t> ts, ArrayRef<size_t> ds, Value idx)
+        : tensors(ts), dims(ds), idx(idx) {}
+    /// The (tensor id, dim) pairs iterated by this loop.
+    llvm::SmallVector<size_t, 4> tensors;
+    llvm::SmallVector<size_t, 4> dims;
+    /// The coordinate produced by this loop.
+    Value idx;
+  };
+
+  /// Return false if tid[dim] is a dense dimension that does not need to be
+  /// prepared (to be used by sparsification for needUniv).
+  bool prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid,
+                                  size_t dim);
+
+  /// Input (TODO: and output) tensors.
+  std::vector<Value> tensors;
+  /// The dim level type array for each tensor (in storage order).
+  std::vector<std::vector<SparseTensorEncodingAttr::DimLevelType>> dims;
+  /// Sparse iteration information (by tensor and dim): current position,
+  /// current coordinate and upper bound. These arrays are updated to remain
+  /// current within the current loop.
+  std::vector<std::vector<Value>> pidxs;
+  std::vector<std::vector<Value>> coord;
+  std::vector<std::vector<Value>> highs;
+  /// Size of each stored dimension (by tensor and dim), set once with the
+  /// inferred dimension sizes.
+  std::vector<std::vector<Value>> sizes;
+  std::vector<std::vector<Value>> ptrBuffer; // to_pointers
+  std::vector<std::vector<Value>> idxBuffer; // to_indices
+  std::vector<Value> valBuffer;              // to_values / to_memref
+
+  std::vector<LoopLevelInfo> loopStack;
+  // TODO: not yet used, it should track the current level for each tensor
+  // to help eliminate `dim` parameters from above APIs.
+  std::vector<size_t> curLv;
+};
+
+//===----------------------------------------------------------------------===//
+// ExecutionEngine/SparseTensorUtils helper functions.
+//===----------------------------------------------------------------------===//
+//
 /// Converts an overhead storage bitwidth to its internal type-encoding.
 OverheadType overheadTypeEncoding(unsigned width);
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 18dd53947bab2..f0c7f599e2853 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -279,7 +280,8 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
       op->setOperand(0, convert);
       return success();
-    } else if (encDst) {
+    }
+    if (encDst) {
       RankedTensorType rtp =
           op.getResult().getType().template cast<RankedTensorType>();
       auto denseTp =
@@ -294,6 +296,60 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
+/// Sparse rewriting rule for the foreach operator. It lowers
+/// `sparse_tensor.foreach` into a nest of scf.for loops, one per stored
+/// dimension, and inlines the operator's body into the innermost loop. The
+/// body receives the coordinates in original dimension order, followed by
+/// the stored value.
+struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForeachOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value input = op.getTensor();
+    auto rtp = input.getType().cast<RankedTensorType>();
+    int64_t rank = rtp.getRank();
+    auto enc = getSparseTensorEncoding(rtp);
+
+    // 1. Generates loop for the sparse input. Loops are emitted in storage
+    //    order: loop `l` iterates stored dimension `l`.
+    SparseTensorLoopEmitter loopEmitter(ValueRange{input});
+    loopEmitter.initializeLoopEmit(rewriter, loc);
+    for (int64_t i = 0; i < rank; i++)
+      loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
+
+    // The rewriter's insertion point is now inside the innermost loop body.
+    Value vals = loopEmitter.getTensorValueBuffer(0);
+    Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
+    Value val = rewriter.create<memref::LoadOp>(loc, vals, idx);
+
+    // coords[l] is the coordinate of stored dimension l.
+    SmallVector<Value, 4> coords;
+    coords.reserve(rank);
+    loopEmitter.getCoordinateArray(coords);
+
+    // Only clears the emitter's bookkeeping; the insertion point stays in
+    // the innermost loop body, where the foreach body is inlined below.
+    for (int64_t i = 0; i < rank; i++)
+      loopEmitter.exitCurrentLoop();
+
+    // 2. Inline the block in the foreach operator.
+    Block::iterator inlinePos = rewriter.getInsertionPoint();
+    Block *srcBlock = op.getBody();
+    // Remove sparse_tensor.yield.
+    rewriter.eraseOp(srcBlock->getTerminator());
+
+    // The body expects coordinates in original dimension order, so original
+    // dimension i reads the coordinate of the stored dimension it maps to.
+    // (Indexing with toOrigDim would apply the inverse permutation, which
+    // only coincides for self-inverse orderings such as a 2-D transpose.)
+    SmallVector<Value, 4> args;
+    args.reserve(rank + 1);
+    for (int64_t i = 0; i < rank; i++)
+      args.push_back(coords[toStoredDim(enc, i)]);
+    // Remap value.
+    args.push_back(val);
+
+    // Inline body.
+    rewriter.mergeBlockBefore(srcBlock, &*inlinePos, args);
+    // Delete the foreach operator.
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -301,9 +357,10 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
-                                         bool /*enableRT*/) {
+                                         bool enableRT) {
   patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
                ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+               ReshapeRewriter<tensor::CollapseShapeOp>, ForeachRewriter>(
+      patterns.getContext());
+  // `enableRT` is not consulted yet: the same patterns (including the
+  // foreach lowering) are registered whether or not it is set.
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
 }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
new file mode 100644
index 0000000000000..1a56565787ef5
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Sparse encodings exercised below (all on 2-D tensors):
+//   #Row:  compressed first dimension, dense second dimension.
+#Row = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "dense" ]
+}>
+
+//   #CSR:  dense first dimension, compressed second dimension.
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ]
+}>
+
+//   #DCSC: both dimensions compressed; dimOrdering swaps i and j, so the
+//          second (column) dimension is stored outermost.
+#DCSC = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+module {
+
+  /// Uses the foreach operator to print the coordinates and the value of
+  /// every stored entry of a #Row tensor.
+  func.func @foreach_print_1(%arg0: tensor<2x2xf64, #Row>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64, #Row> do {
+      ^bb0(%1: index, %2: index, %v: f64) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f64
+     }
+     return
+  }
+
+  /// Uses the foreach operator to print the coordinates and the value of
+  /// every stored entry of a #CSR tensor.
+  func.func @foreach_print_2(%arg0: tensor<2x2xf64, #CSR>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64, #CSR> do {
+      ^bb0(%1: index, %2: index, %v: f64) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f64
+     }
+     return
+  }
+
+  /// Uses the foreach operator to print the coordinates and the value of
+  /// every stored entry of a #DCSC tensor.
+  func.func @foreach_print_3(%arg0: tensor<2x2xf64, #DCSC>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64, #DCSC> do {
+      ^bb0(%1: index, %2: index, %v: f64) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f64
+     }
+     return
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    //
+    // Initialize a 2-dim dense tensor.
+    //
+    %src = arith.constant dense<
+       [[  1.0,  2.0],
+        [  5.0,  6.0]]
+    > : tensor<2x2xf64>
+
+    //
+    // Convert dense tensor directly to various sparse tensors.
+    //
+    %s1 = sparse_tensor.convert %src : tensor<2x2xf64> to tensor<2x2xf64, #Row>
+    %s2 = sparse_tensor.convert %src : tensor<2x2xf64> to tensor<2x2xf64, #CSR>
+    %s3 = sparse_tensor.convert %src : tensor<2x2xf64> to tensor<2x2xf64, #DCSC>
+    // Each visited entry prints three lines: first coordinate, second
+    // coordinate, value. #Row and #CSR store the first dimension outermost,
+    // so entries appear row by row.
+    // CHECK: 0
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 6
+    call @foreach_print_1(%s1) : (tensor<2x2xf64, #Row>) -> ()
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 6
+    call @foreach_print_2(%s2) : (tensor<2x2xf64, #CSR>) -> ()
+    // #DCSC stores the second dimension outermost, so entries appear column
+    // by column, while the body still receives coordinates in original
+    // (row, column) order.
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 6
+    call @foreach_print_3(%s3) : (tensor<2x2xf64, #DCSC>) -> ()
+    
+    bufferization.dealloc_tensor %s1 : tensor<2x2xf64, #Row>
+    bufferization.dealloc_tensor %s2 : tensor<2x2xf64, #CSR>
+    bufferization.dealloc_tensor %s3 : tensor<2x2xf64, #DCSC>
+
+    return
+  }
+}


        


More information about the Mlir-commits mailing list