[Mlir-commits] [mlir] e2e83f4 - [mlir][sparse] support coiteration over sparse tensor slices

Peiming Liu llvmlistbot at llvm.org
Wed Feb 15 15:52:27 PST 2023


Author: Peiming Liu
Date: 2023-02-15T23:52:22Z
New Revision: e2e83f4c8f1dc36d5025841996be821f04a19953

URL: https://github.com/llvm/llvm-project/commit/e2e83f4c8f1dc36d5025841996be821f04a19953
DIFF: https://github.com/llvm/llvm-project/commit/e2e83f4c8f1dc36d5025841996be821f04a19953.diff

LOG: [mlir][sparse] support coiteration over sparse tensor slices

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140736

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/sparse_1d.mlir
    mlir/test/Dialect/SparseTensor/sparse_2d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 7f94b3aae8fb6..72d341925cdad 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -87,6 +87,30 @@ static std::pair<Value, Value> fromSliceCoord(OpBuilder &builder, Location loc,
   return std::make_pair(v, rem);
 }
 
+static std::pair<Value, Value>
+genSliceLegitPredicate(OpBuilder &builder, Location loc, Value coord,
+                       SparseTensorEncodingAttr enc, unsigned lvl) {
+  std::pair<Value, Value> trans = fromSliceCoord(builder, loc, coord, enc, lvl);
+  // First, coord >= offset (TODO: seems unsigned >= 0 won't be folded, skip
+  // the check if the offset is zero).
+  auto geOffset =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, coord,
+                                    getSliceOffset(builder, loc, enc, lvl));
+  // Second, coord_in_slice < length
+  auto ltLength =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first,
+                                    getSliceSize(builder, loc, enc, lvl));
+
+  // Third, rem == 0; confirmed that (a % 1) will be folded to 0
+  auto fitStride =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second,
+                                    constantIndex(builder, loc, 0));
+
+  auto pred = builder.create<arith::AndIOp>(loc, geOffset, ltLength);
+  pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+  return {trans.first, pred};
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse tensor loop emitter class implementations
 //===----------------------------------------------------------------------===//
@@ -353,31 +377,14 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
   if (isSparseSlices[tid] && isSparseInput) {
     // For sparse level slices, we need to filter out invalid coordinates that
     // are not included in the slice.
-    std::pair<Value, Value> trans = fromSliceCoord(builder, loc, c, enc, dim);
     SmallVector<Type> types;
     for (Value red : reduc)
       types.push_back(red.getType());
 
-    // First, coord >= offset (TODO: seems unsigned >= 0 won't be folded, skip
-    // the check if the offset is zero).
-    auto geOff =
-        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, c,
-                                      getSliceOffset(builder, loc, enc, dim));
-    // Second, coords < length
-    auto ltLen = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, trans.first,
-        getSliceSize(builder, loc, enc, dim));
-
-    // Third, rem == 0; confirmed that (a % 1) will be folded to 0
-    auto fitStride = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, trans.second,
-        constantIndex(builder, loc, 0));
-
-    auto pred = builder.create<arith::AndIOp>(loc, geOff, ltLen);
-    pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+    auto [trans, pred] = genSliceLegitPredicate(builder, loc, c, enc, dim);
     bool hasReduc = !types.empty();
-    scf::IfOp ifOp =
-        builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
+    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
+                                               /*else*/ hasReduc);
     if (hasReduc) {
       // scf.for (a) -> v
       //  %s = scf.if (a) -> v
@@ -392,7 +399,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
     }
     // Set the insertion point to matched branch.
     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-    c = trans.first;
+    c = trans;
   }
 
   assert(c);
@@ -400,7 +407,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
   // NOTE: we can also prepare for next dim here in advance
   // Push the loop into stack
   loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
-                         coord[tid][dim], loopTag);
+                         builder.getInsertionBlock(), coord[tid][dim], loopTag);
   // Emit extra locals.
   emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
 
@@ -470,7 +477,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtDim(
   // NOTE: we can also prepare for next dim here in advance
   // Push the loop into stack
   loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), forOp,
-                         coord[tid][dim], nullptr);
+                         builder.getInsertionBlock(), coord[tid][dim], nullptr);
   return forOp;
 }
 
@@ -536,7 +543,9 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
 
   // Generates while body.
   builder.setInsertionPointToStart(&whileOp.getAfter().front());
-  Value min;
+
+  SmallVector<std::pair<Value, unsigned>> slicesPreds;
+  unsigned i = 0;
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
     // Prepares for next level.
     if (isCompressedDLT(dimTypes[tid][dim]) ||
@@ -545,26 +554,73 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
       Value s = pidxs[tid][dim];
       Value load = genIndexLoad(builder, loc, ptr, s);
       coord[tid][dim] = load;
-      if (!needsUniv) {
+      if (isSparseSlices[tid]) {
+        auto enc = getSparseTensorEncoding(tensors[tid].getType());
+        auto [trans, pred] =
+            genSliceLegitPredicate(builder, loc, load, enc, dim);
+        slicesPreds.emplace_back(pred, i);
+        // Updates the coordinate to be relative to the slice.
+        coord[tid][dim] = trans;
+      }
+      i++;
+    }
+  }
+
+  if (!slicesPreds.empty()) {
+    // Skips invalid loop iteration when slice coordinate is inapplicable.
+    SmallVector<Value> yields(after->getArguments());
+    // Generates a list of if statements
+    //  pidx = in_slice ? pidx : pidx + 1
+    // TODO: instead of always picking pidx + 1, we should set pidx = high to
+    // break out of the loop when the coordinate exceeds the slice size.
+    for (auto [pred, idx] : slicesPreds) {
+      Value nextPidx = builder.create<arith::AddIOp>(
+          loc, yields[idx], constantIndex(builder, loc, 1));
+      yields[idx] =
+          builder.create<arith::SelectOp>(loc, pred, yields[idx], nextPidx);
+    }
+
+    Value pred = slicesPreds.front().first;
+    for (int i = 1, e = slicesPreds.size(); i < e; i++) {
+      pred = builder.create<arith::AndIOp>(loc, pred, slicesPreds[i].first);
+    }
+    auto ifOp = builder.create<scf::IfOp>(loc, types, pred, /*else*/ true);
+    ifOp->setAttr(getLoopEmitterLoopAttrName(),
+                  StringAttr::get(builder.getContext(), "slice"));
+    builder.create<scf::YieldOp>(loc, ifOp->getResults());
+    assert(types.size() == yields.size());
+    // If not all slices are legit
+    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    builder.create<scf::YieldOp>(loc, yields);
+
+    // If all slices are legit, start the user generated code.
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  }
+
+  Value min;
+  // Finds the minimum coordinate
+  if (!needsUniv) {
+    for (auto [tid, dim] : llvm::zip(tids, dims)) {
+      if (isCompressedDLT(dimTypes[tid][dim]) ||
+          isSingletonDLT(dimTypes[tid][dim])) {
         if (min) {
           Value cmp = builder.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::ult, load, min);
-          min = builder.create<arith::SelectOp>(loc, cmp, load, min);
+              loc, arith::CmpIPredicate::ult, coord[tid][dim], min);
+          min = builder.create<arith::SelectOp>(loc, cmp, coord[tid][dim], min);
         } else {
-          min = load;
+          min = coord[tid][dim];
         }
       }
     }
-  }
-
-  if (needsUniv) {
+  } else {
     assert(!min);
     // Otherwise, universal index is the minimal pidx.
     min = after->getArguments().back();
   }
 
   // Sets up the loop stack.
-  loopStack.emplace_back(tids, dims, whileOp, min, loopTag);
+  loopStack.emplace_back(tids, dims, whileOp, builder.getInsertionBlock(), min,
+                         loopTag);
   assert(loopStack.size() == loopSeqStack.size());
 
   // Emits extra locals
@@ -642,6 +698,7 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder,
 void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                               MutableArrayRef<Value> reduc) {
   LoopLevelInfo &loopInfo = loopStack.back();
+  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
   auto &dims = loopStack.back().dims;
   auto &tids = loopStack.back().tids;
   auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop);
@@ -722,12 +779,12 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
 
 void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
                                       MutableArrayRef<Value> reduc) {
-  auto whileOp = llvm::cast<scf::WhileOp>(loopStack.back().loop);
-  auto &dims = loopStack.back().dims;
-  auto &tids = loopStack.back().tids;
-  Value iv = loopStack.back().iv;
-  // Generation while loop induction at the end.
-  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
+  const LoopLevelInfo &loopInfo = loopStack.back();
+  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
+  builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
+  auto &dims = loopInfo.dims;
+  auto &tids = loopInfo.tids;
+  Value iv = loopInfo.iv;
   // Finalize the induction. Note that the induction could be performed
   // in the individual if-branches to avoid re-evaluating the conditions.
   // However, that would result in a rather elaborate forest of yield

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 832281a86359e..419d7039d1e6a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -170,8 +170,8 @@ class LoopEmitter {
 private:
   struct LoopLevelInfo {
     LoopLevelInfo(ArrayRef<size_t> tids, ArrayRef<size_t> dims, Operation *loop,
-                  Value iv, StringAttr loopTag)
-        : tids(tids), dims(dims), loop(loop), iv(iv) {
+                  Block *userBlock, Value iv, StringAttr loopTag)
+        : tids(tids), dims(dims), loop(loop), userCodeBlock(userBlock), iv(iv) {
       // Attached a special tag to loop emitter generated loop.
       if (loopTag)
         loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
@@ -181,8 +181,9 @@ class LoopEmitter {
     const llvm::SmallVector<size_t> tids;
     // The corresponding dims for the tensors
     const llvm::SmallVector<size_t> dims;
-    const Operation *loop; // the loop operation
-    const Value iv;        // the induction variable for the loop
+    const Operation *loop;      // the loop operation
+    Block *const userCodeBlock; // the block holding users' generated code.
+    const Value iv;             // the induction variable for the loop
   };
 
   /// Linearizes address for dense dimension (i.e., p = (i * d0) + j).

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 7f88c6a4d9844..094806bddeaa6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1066,6 +1066,11 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
   if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
                builder.getInsertionBlock()->getParentOp())) {
+      // Break on IfOp for slicing filtering.
+      if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
+          StringAttr::get(ifOp->getContext(), "slice"))
+        break;
+
       unsigned y = 0;
       SmallVector<Value> yields;
       if (env.isReduc()) {

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
index 35f8b09fa784b..0d5d554f9f326 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -1300,11 +1300,11 @@ func.func @four_tensors_op(%arga: tensor<?xf64>,
 // CHECK:             scf.condition(%[[VAL_33]]) %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_28]] : index, index, index, f64
 // CHECK:           } do {
 // CHECK:           ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index, %[[VAL_36:.*]]: index, %[[VAL_37:.*]]: f64):
-// CHECK:             %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_34]]] : memref<?xindex>
-// CHECK:             %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK-DAG:         %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK-DAG:         %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK-DAG:         %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref<?xindex>
 // CHECK:             %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_39]], %[[VAL_38]] : index
 // CHECK:             %[[VAL_41:.*]] = arith.select %[[VAL_40]], %[[VAL_39]], %[[VAL_38]] : index
-// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref<?xindex>
 // CHECK:             %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_42]], %[[VAL_41]] : index
 // CHECK:             %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[VAL_42]], %[[VAL_41]] : index
 // CHECK:             %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index cc5ffa1b9f159..cd4b0cce8fafb 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1128,11 +1128,11 @@ func.func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
 // CHECK:                 scf.condition(%[[VAL_56]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]] : index, index, index, f32
 // CHECK:               } do {
 // CHECK:               ^bb0(%[[VAL_57:.*]]: index, %[[VAL_58:.*]]: index, %[[VAL_59:.*]]: index, %[[VAL_60:.*]]: f32):
-// CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_57]]] : memref<?xindex>
-// CHECK:                 %[[VAL_62:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
+// CHECK-DAG:             %[[VAL_61:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_57]]] : memref<?xindex>
+// CHECK-DAG:             %[[VAL_62:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
+// CHECK-DAG:             %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
 // CHECK:                 %[[VAL_63:.*]] = arith.cmpi ult, %[[VAL_62]], %[[VAL_61]] : index
 // CHECK:                 %[[VAL_64:.*]] = arith.select %[[VAL_63]], %[[VAL_62]], %[[VAL_61]] : index
-// CHECK:                 %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
 // CHECK:                 %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_64]] : index
 // CHECK:                 %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_65]], %[[VAL_64]] : index
 // CHECK:                 %[[VAL_68:.*]] = arith.cmpi eq, %[[VAL_61]], %[[VAL_67]] : index

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
new file mode 100644
index 0000000000000..8f77b3d6a16bb
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
@@ -0,0 +1,193 @@
+// DEFINE: %{option} = enable-runtime-library=false
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+
+// TODO: support lib path.
+
+#DCSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#DCSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  slice = [ (0, 4, 1), (0, 8, 1) ]
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ]
+}>
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (0, 4, 1), (0, 8, 1) ]
+}>
+
+#CSR_SLICE_1 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (0, 4, 2), (0, 4, 1) ]
+}>
+
+#DCSR_SLICE_1 = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  slice = [ (0, 4, 2), (1, 4, 1) ]
+}>
+
+module {
+  func.func private @printMemrefF64(%ptr : tensor<*xf64>)
+  func.func private @printMemref1dF64(%ptr : memref<?xf64>) attributes { llvm.emit_c_interface }
+
+
+  //
+  // Computes C = A x B with one matrix a CSR sparse slice and the other a DCSR sparse slice.
+  //
+  func.func @matmul1(%A: tensor<4x4xf64, #CSR_SLICE_1>,
+                     %B: tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR> {
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x4xf64, #CSR_SLICE_1>, tensor<4x4xf64, #DCSR_SLICE_1>)
+         outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+    return %D: tensor<4x4xf64, #CSR>
+  }
+
+  //
+  // Computes C = A x B with one matrix CSR sparse slice and the other CSR sparse tensor.
+  //
+  func.func @matmul2(%A: tensor<4x8xf64, #CSR_SLICE>,
+                     %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x8xf64, #CSR_SLICE>, tensor<8x4xf64, #CSR>)
+         outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+    return %D: tensor<4x4xf64, #CSR>
+  }
+
+  //
+  // Computes C = A x B with one matrix DCSR sparse slice and the other DCSR sparse tensor.
+  //
+  func.func @matmul3(%A: tensor<4x8xf64, #DCSR_SLICE>,
+                     %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x8xf64, #DCSR_SLICE>, tensor<8x4xf64, #DCSR>)
+         outs(%C: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
+    return %D: tensor<4x4xf64, #DCSR>
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %f0 = arith.constant 0.0 : f64
+
+    %sa = arith.constant dense<[
+        [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ],
+        [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
+    ]> : tensor<8x8xf64>
+    %sb = arith.constant dense<[
+        [ 0.0, 0.0, 0.0, 1.0 ],
+        [ 0.0, 0.0, 2.0, 0.0 ],
+        [ 0.0, 3.0, 0.0, 0.0 ],
+        [ 4.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 5.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 6.0, 0.0 ],
+        [ 0.0, 0.0, 7.0, 8.0 ]
+    ]> : tensor<8x4xf64>
+    %zero = arith.constant dense<0.0> : tensor<4x4xf64>
+
+    // Convert all these matrices to sparse format.
+    %tmp = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #DCSR>
+    %a = tensor.extract_slice %tmp[0, 0][4, 8][1, 1] : tensor<8x8xf64, #DCSR> to tensor<4x8xf64, #DCSR_SLICE>
+    %b = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
+
+    %2 = call @matmul3(%a, %b)
+       : (tensor<4x8xf64, #DCSR_SLICE>,
+          tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
+
+    // DCSR test
+    //
+    // CHECK:       [0,   30.5,   4.2,   0],
+    // CHECK-NEXT:  [0,   0,   0,   0],
+    // CHECK-NEXT:  [0,   0,   4.6,   0],
+    // CHECK-NEXT:  [0,   0,   7,   8]
+    //
+    %c2 = sparse_tensor.convert %2 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
+    %c2u = tensor.cast %c2 : tensor<4x4xf64> to tensor<*xf64>
+    call @printMemrefF64(%c2u) : (tensor<*xf64>) -> ()
+
+    %t1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+    %a1 = tensor.extract_slice %t1[0, 0][4, 8][1, 1] : tensor<8x8xf64, #CSR> to tensor<4x8xf64, #CSR_SLICE>
+    %b1 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
+    %3 = call @matmul2(%a1, %b1)
+       : (tensor<4x8xf64, #CSR_SLICE>,
+          tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+
+    // CSR test
+    //
+    // CHECK:       [0,   30.5,   4.2,   0],
+    // CHECK-NEXT:  [0,   0,   0,   0],
+    // CHECK-NEXT:  [0,   0,   4.6,   0],
+    // CHECK-NEXT:  [0,   0,   7,   8]
+    //
+    %c3 = sparse_tensor.convert %3 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+    %c3u = tensor.cast %c3 : tensor<4x4xf64> to tensor<*xf64>
+    call @printMemrefF64(%c3u) : (tensor<*xf64>) -> ()
+
+    // slice x slice
+    //
+    // CHECK:      [2.3,   0,   0,   0],
+    // CHECK-NEXT: [6.9,   0,   0,   0],
+    // CHECK-NEXT: [0,   0,   0,   0],
+    // CHECK-NEXT: [12.6,   0,   0,   0]]
+    //
+    %s1 = tensor.extract_slice %tmp[0, 1][4, 4][2, 1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_1>
+    %s2 = tensor.extract_slice %b1[0, 0][4, 4][2, 1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_1>
+    %4 = call @matmul1(%s2, %s1)
+       : (tensor<4x4xf64, #CSR_SLICE_1>,
+          tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR>
+
+    %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+    %c4u = tensor.cast %c4 : tensor<4x4xf64> to tensor<*xf64>
+    call @printMemrefF64(%c4u) : (tensor<*xf64>) -> ()
+
+    // sparse slices should generate the same result as dense slices
+    //
+    // CHECK:      [2.3,   0,   0,   0],
+    // CHECK-NEXT: [6.9,   0,   0,   0],
+    // CHECK-NEXT: [0,   0,   0,   0],
+    // CHECK-NEXT: [12.6,   0,   0,   0]]
+    //
+    %ds1 = tensor.extract_slice %sa[0, 1][4, 4][2, 1] : tensor<8x8xf64> to tensor<4x4xf64>
+    %ds2 = tensor.extract_slice %sb[0, 0][4, 4][2, 1] : tensor<8x4xf64> to tensor<4x4xf64>
+
+    %d = bufferization.alloc_tensor() copy(%zero) : tensor<4x4xf64>
+    %r = linalg.matmul ins(%ds2, %ds1: tensor<4x4xf64>, tensor<4x4xf64>)
+                       outs(%d: tensor<4x4xf64>) -> tensor<4x4xf64>
+    %du = tensor.cast %r : tensor<4x4xf64> to tensor<*xf64>
+    call @printMemrefF64(%du) : (tensor<*xf64>) -> ()
+
+    // Releases resources.
+    bufferization.dealloc_tensor %b1 : tensor<8x4xf64, #CSR>
+    bufferization.dealloc_tensor %t1 : tensor<8x8xf64, #CSR>
+    bufferization.dealloc_tensor %b  : tensor<8x4xf64, #DCSR>
+    bufferization.dealloc_tensor %tmp: tensor<8x8xf64, #DCSR>
+    bufferization.dealloc_tensor %4  : tensor<4x4xf64, #CSR>
+    bufferization.dealloc_tensor %3  : tensor<4x4xf64, #CSR>
+    bufferization.dealloc_tensor %2  : tensor<4x4xf64, #DCSR>
+
+    return
+  }
+}


        


More information about the Mlir-commits mailing list