[Mlir-commits] [mlir] e2e83f4 - [mlir][sparse] support coiteration over sparse tensor slices
Peiming Liu
llvmlistbot at llvm.org
Wed Feb 15 15:52:27 PST 2023
Author: Peiming Liu
Date: 2023-02-15T23:52:22Z
New Revision: e2e83f4c8f1dc36d5025841996be821f04a19953
URL: https://github.com/llvm/llvm-project/commit/e2e83f4c8f1dc36d5025841996be821f04a19953
DIFF: https://github.com/llvm/llvm-project/commit/e2e83f4c8f1dc36d5025841996be821f04a19953.diff
LOG: [mlir][sparse] support coiteration over sparse tensor slices
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D140736
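
At its core, the patch introduces genSliceLegitPredicate, which decides whether a stored coordinate of the underlying tensor falls inside a slice described by an (offset, size, stride) triple, and reuses that predicate both in the single-tensor loop and in the co-iteration while-loop. A rough standalone C++ sketch of the check (illustration only, with a hypothetical helper name; the patch emits the equivalent arith ops instead of calling such a function):

    #include <cstdint>
    #include <utility>

    // Returns {coordinate translated into the slice, whether it is legitimate}.
    std::pair<uint64_t, bool> sliceLegit(uint64_t coord, uint64_t offset,
                                         uint64_t size, uint64_t stride) {
      if (coord < offset)                       // first: coord >= offset
        return {0, false};
      uint64_t rel = (coord - offset) / stride; // coordinate relative to slice
      uint64_t rem = (coord - offset) % stride;
      // second: translated coordinate in bounds; third: stride divides evenly
      return {rel, rel < size && rem == 0};
    }
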
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/sparse_1d.mlir
mlir/test/Dialect/SparseTensor/sparse_2d.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 7f94b3aae8fb6..72d341925cdad 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -87,6 +87,30 @@ static std::pair<Value, Value> fromSliceCoord(OpBuilder &builder, Location loc,
return std::make_pair(v, rem);
}
+static std::pair<Value, Value>
+genSliceLegitPredicate(OpBuilder &builder, Location loc, Value coord,
+ SparseTensorEncodingAttr enc, unsigned lvl) {
+ std::pair<Value, Value> trans = fromSliceCoord(builder, loc, coord, enc, lvl);
+ // First, coord >= offset (TODO: an unsigned >= 0 does not seem to be folded;
+ // skip the check when the offset is zero).
+ auto geOffset =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, coord,
+ getSliceOffset(builder, loc, enc, lvl));
+ // Second, coord_in_slice < length
+ auto ltLength =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first,
+ getSliceSize(builder, loc, enc, lvl));
+
+ // Third, rem == 0; confirmed that (a % 1) will be folded to 0
+ auto fitStride =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second,
+ constantIndex(builder, loc, 0));
+
+ auto pred = builder.create<arith::AndIOp>(loc, geOffset, ltLength);
+ pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+ return {trans.first, pred};
+}
+
//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//
@@ -353,31 +377,14 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
if (isSparseSlices[tid] && isSparseInput) {
// For sparse level slices, we need to filter out invalid coordinates that
// are not included in the slice.
- std::pair<Value, Value> trans = fromSliceCoord(builder, loc, c, enc, dim);
SmallVector<Type> types;
for (Value red : reduc)
types.push_back(red.getType());
- // First, coord >= offset (TODO: seems unsigned >= 0 won't be folded, skip
- // the check if the offset is zero).
- auto geOff =
- builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, c,
- getSliceOffset(builder, loc, enc, dim));
- // Second, coords < length
- auto ltLen = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, trans.first,
- getSliceSize(builder, loc, enc, dim));
-
- // Third, rem == 0; confirmed that (a % 1) will be folded to 0
- auto fitStride = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, trans.second,
- constantIndex(builder, loc, 0));
-
- auto pred = builder.create<arith::AndIOp>(loc, geOff, ltLen);
- pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+ auto [trans, pred] = genSliceLegitPredicate(builder, loc, c, enc, dim);
bool hasReduc = !types.empty();
- scf::IfOp ifOp =
- builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
+ /*else*/ hasReduc);
if (hasReduc) {
// scf.for (a) -> v
// %s = scf.if (a) -> v
@@ -392,7 +399,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
}
// Set the insertion point to matched branch.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- c = trans.first;
+ c = trans;
}
assert(c);
@@ -400,7 +407,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
// NOTE: we can also prepare for next dim here in advance
// Push the loop into stack
loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
- coord[tid][dim], loopTag);
+ builder.getInsertionBlock(), coord[tid][dim], loopTag);
// Emit extra locals.
emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
@@ -470,7 +477,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtDim(
// NOTE: we can also prepare for next dim here in advance
// Push the loop into stack
loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), forOp,
- coord[tid][dim], nullptr);
+ builder.getInsertionBlock(), coord[tid][dim], nullptr);
return forOp;
}
@@ -536,7 +543,9 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
// Generates while body.
builder.setInsertionPointToStart(&whileOp.getAfter().front());
- Value min;
+
+ SmallVector<std::pair<Value, unsigned>> slicesPreds;
+ unsigned i = 0;
for (auto [tid, dim] : llvm::zip(tids, dims)) {
// Prepares for next level.
if (isCompressedDLT(dimTypes[tid][dim]) ||
@@ -545,26 +554,73 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
Value s = pidxs[tid][dim];
Value load = genIndexLoad(builder, loc, ptr, s);
coord[tid][dim] = load;
- if (!needsUniv) {
+ if (isSparseSlices[tid]) {
+ auto enc = getSparseTensorEncoding(tensors[tid].getType());
+ auto [trans, pred] =
+ genSliceLegitPredicate(builder, loc, load, enc, dim);
+ slicesPreds.emplace_back(pred, i);
+ // Update the coordinate to be relative to the slice.
+ coord[tid][dim] = trans;
+ }
+ i++;
+ }
+ }
+
+ if (!slicesPreds.empty()) {
+ // Skips invalid loop iterations when the coordinate is not in the slice.
+ SmallVector<Value> yields(after->getArguments());
+ // Generates a list of if statements
+ // pidx = in_slice ? pidx : pidx + 1
+ // TODO: instead of always picking pidx + 1, we should set pidx = high to
+ // break the loop when the coordinate is larger than the slice size.
+ for (auto [pred, idx] : slicesPreds) {
+ Value nextPidx = builder.create<arith::AddIOp>(
+ loc, yields[idx], constantIndex(builder, loc, 1));
+ yields[idx] =
+ builder.create<arith::SelectOp>(loc, pred, yields[idx], nextPidx);
+ }
+
+ Value pred = slicesPreds.front().first;
+ for (int i = 1, e = slicesPreds.size(); i < e; i++) {
+ pred = builder.create<arith::AndIOp>(loc, pred, slicesPreds[i].first);
+ }
+ auto ifOp = builder.create<scf::IfOp>(loc, types, pred, /*else*/ true);
+ ifOp->setAttr(getLoopEmitterLoopAttrName(),
+ StringAttr::get(builder.getContext(), "slice"));
+ builder.create<scf::YieldOp>(loc, ifOp->getResults());
+ assert(types.size() == yields.size());
+ // If not all slices are legit, yield the updated positions to skip the iteration.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, yields);
+
+ // If all slices are legit, start the user-generated code.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ }
+
+ Value min;
+ // Finds the minimum coordinate
+ if (!needsUniv) {
+ for (auto [tid, dim] : llvm::zip(tids, dims)) {
+ if (isCompressedDLT(dimTypes[tid][dim]) ||
+ isSingletonDLT(dimTypes[tid][dim])) {
if (min) {
Value cmp = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, load, min);
- min = builder.create<arith::SelectOp>(loc, cmp, load, min);
+ loc, arith::CmpIPredicate::ult, coord[tid][dim], min);
+ min = builder.create<arith::SelectOp>(loc, cmp, coord[tid][dim], min);
} else {
- min = load;
+ min = coord[tid][dim];
}
}
}
- }
-
- if (needsUniv) {
+ } else {
assert(!min);
// Otherwise, universal index is the minimal pidx.
min = after->getArguments().back();
}
// Sets up the loop stack.
- loopStack.emplace_back(tids, dims, whileOp, min, loopTag);
+ loopStack.emplace_back(tids, dims, whileOp, builder.getInsertionBlock(), min,
+ loopTag);
assert(loopStack.size() == loopSeqStack.size());
// Emits extra locals
@@ -642,6 +698,7 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder,
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc) {
LoopLevelInfo &loopInfo = loopStack.back();
+ rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
auto &dims = loopStack.back().dims;
auto &tids = loopStack.back().tids;
auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop);
@@ -722,12 +779,12 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc) {
- auto whileOp = llvm::cast<scf::WhileOp>(loopStack.back().loop);
- auto &dims = loopStack.back().dims;
- auto &tids = loopStack.back().tids;
- Value iv = loopStack.back().iv;
- // Generation while loop induction at the end.
- builder.setInsertionPointToEnd(&whileOp.getAfter().front());
+ const LoopLevelInfo &loopInfo = loopStack.back();
+ auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
+ builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
+ auto &dims = loopInfo.dims;
+ auto &tids = loopInfo.tids;
+ Value iv = loopInfo.iv;
// Finalize the induction. Note that the induction could be performed
// in the individual if-branches to avoid re-evaluating the conditions.
// However, that would result in a rather elaborate forest of yield
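
Read as a scalar model, the new co-iteration body above does the following: each sparse-slice operand loads its coordinate, tests it against the slice predicate, advances its own position by one when the test fails, and only when every operand passes does the minimum translated coordinate reach the user code. A hedged C++ sketch of one while-iteration, assuming every operand is a sparse slice (hypothetical names, illustration only, not the generated IR):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct SliceOperand {
      std::vector<uint64_t> coords;  // stored level coordinates of the base tensor
      uint64_t offset, size, stride; // slice description for this level
      uint64_t pidx = 0;             // current position into `coords`
    };

    // One co-iteration step. Returns true iff every coordinate is inside its
    // slice; then `minCoord` holds the minimum translated coordinate and the
    // user code runs. Otherwise only the failing positions advance by one
    // (the TODO in the patch suggests jumping to `high` instead).
    bool coiterateStep(std::vector<SliceOperand> &ops, uint64_t &minCoord) {
      bool allLegit = true;
      minCoord = UINT64_MAX;
      for (SliceOperand &op : ops) {
        uint64_t c = op.coords[op.pidx]; // caller guarantees pidx is in bounds
        bool legit = c >= op.offset && (c - op.offset) % op.stride == 0 &&
                     (c - op.offset) / op.stride < op.size;
        if (!legit) {
          ++op.pidx;
          allLegit = false;
        } else {
          minCoord = std::min(minCoord, (c - op.offset) / op.stride);
        }
      }
      return allLegit;
    }
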
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 832281a86359e..419d7039d1e6a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -170,8 +170,8 @@ class LoopEmitter {
private:
struct LoopLevelInfo {
LoopLevelInfo(ArrayRef<size_t> tids, ArrayRef<size_t> dims, Operation *loop,
- Value iv, StringAttr loopTag)
- : tids(tids), dims(dims), loop(loop), iv(iv) {
+ Block *userBlock, Value iv, StringAttr loopTag)
+ : tids(tids), dims(dims), loop(loop), userCodeBlock(userBlock), iv(iv) {
// Attached a special tag to loop emitter generated loop.
if (loopTag)
loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
@@ -181,8 +181,9 @@ class LoopEmitter {
const llvm::SmallVector<size_t> tids;
// The corresponding dims for the tensors
const llvm::SmallVector<size_t> dims;
- const Operation *loop; // the loop operation
- const Value iv; // the induction variable for the loop
+ const Operation *loop; // the loop operation
+ Block *const userCodeBlock; // the block holding user-generated code.
+ const Value iv; // the induction variable for the loop
};
/// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 7f88c6a4d9844..094806bddeaa6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1066,6 +1066,11 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
builder.getInsertionBlock()->getParentOp())) {
+ // Break on the IfOp generated for slice filtering.
+ if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
+ StringAttr::get(ifOp->getContext(), "slice"))
+ break;
+
unsigned y = 0;
SmallVector<Value> yields;
if (env.isReduc()) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
index 35f8b09fa784b..0d5d554f9f326 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -1300,11 +1300,11 @@ func.func @four_tensors_op(%arga: tensor<?xf64>,
// CHECK: scf.condition(%[[VAL_33]]) %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_28]] : index, index, index, f64
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index, %[[VAL_36:.*]]: index, %[[VAL_37:.*]]: f64):
-// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_34]]] : memref<?xindex>
-// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref<?xindex>
// CHECK: %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_39]], %[[VAL_38]] : index
// CHECK: %[[VAL_41:.*]] = arith.select %[[VAL_40]], %[[VAL_39]], %[[VAL_38]] : index
-// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref<?xindex>
// CHECK: %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_42]], %[[VAL_41]] : index
// CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[VAL_42]], %[[VAL_41]] : index
// CHECK: %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
index cc5ffa1b9f159..cd4b0cce8fafb 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1128,11 +1128,11 @@ func.func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
// CHECK: scf.condition(%[[VAL_56]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]] : index, index, index, f32
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_57:.*]]: index, %[[VAL_58:.*]]: index, %[[VAL_59:.*]]: index, %[[VAL_60:.*]]: f32):
-// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_57]]] : memref<?xindex>
-// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_61:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_57]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_62:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
// CHECK: %[[VAL_63:.*]] = arith.cmpi ult, %[[VAL_62]], %[[VAL_61]] : index
// CHECK: %[[VAL_64:.*]] = arith.select %[[VAL_63]], %[[VAL_62]], %[[VAL_61]] : index
-// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
// CHECK: %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_64]] : index
// CHECK: %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_65]], %[[VAL_64]] : index
// CHECK: %[[VAL_68:.*]] = arith.cmpi eq, %[[VAL_61]], %[[VAL_67]] : index
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
new file mode 100644
index 0000000000000..8f77b3d6a16bb
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
@@ -0,0 +1,193 @@
+// DEFINE: %{option} = enable-runtime-library=false
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+
+// TODO: support lib path.
+
+#DCSR = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#DCSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ],
+ slice = [ (0, 4, 1), (0, 8, 1) ]
+}>
+
+#CSR = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ]
+}>
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (0, 4, 1), (0, 8, 1) ]
+}>
+
+#CSR_SLICE_1 = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (0, 4, 2), (0, 4, 1) ]
+}>
+
+#DCSR_SLICE_1 = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ],
+ slice = [ (0, 4, 2), (1, 4, 1) ]
+}>
+
+module {
+ func.func private @printMemrefF64(%ptr : tensor<*xf64>)
+ func.func private @printMemref1dF64(%ptr : memref<?xf64>) attributes { llvm.emit_c_interface }
+
+
+ //
+ // Computes C = A x B, with one matrix a CSR sparse slice and the other a DCSR sparse slice.
+ //
+ func.func @matmul1(%A: tensor<4x4xf64, #CSR_SLICE_1>,
+ %B: tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR> {
+ %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+ %D = linalg.matmul
+ ins(%A, %B: tensor<4x4xf64, #CSR_SLICE_1>, tensor<4x4xf64, #DCSR_SLICE_1>)
+ outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+ return %D: tensor<4x4xf64, #CSR>
+ }
+
+ //
+ // Computes C = A x B, with one matrix a CSR sparse slice and the other a CSR sparse tensor.
+ //
+ func.func @matmul2(%A: tensor<4x8xf64, #CSR_SLICE>,
+ %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
+ %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+ %D = linalg.matmul
+ ins(%A, %B: tensor<4x8xf64, #CSR_SLICE>, tensor<8x4xf64, #CSR>)
+ outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+ return %D: tensor<4x4xf64, #CSR>
+ }
+
+ //
+ // Computes C = A x B, with one matrix a DCSR sparse slice and the other a DCSR sparse tensor.
+ //
+ func.func @matmul3(%A: tensor<4x8xf64, #DCSR_SLICE>,
+ %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
+ %C = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
+ %D = linalg.matmul
+ ins(%A, %B: tensor<4x8xf64, #DCSR_SLICE>, tensor<8x4xf64, #DCSR>)
+ outs(%C: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
+ return %D: tensor<4x4xf64, #DCSR>
+ }
+
+ //
+ // Main driver.
+ //
+ func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f64
+
+ %sa = arith.constant dense<[
+ [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ],
+ [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
+ ]> : tensor<8x8xf64>
+ %sb = arith.constant dense<[
+ [ 0.0, 0.0, 0.0, 1.0 ],
+ [ 0.0, 0.0, 2.0, 0.0 ],
+ [ 0.0, 3.0, 0.0, 0.0 ],
+ [ 4.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 5.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 6.0, 0.0 ],
+ [ 0.0, 0.0, 7.0, 8.0 ]
+ ]> : tensor<8x4xf64>
+ %zero = arith.constant dense<0.0> : tensor<4x4xf64>
+
+ // Convert all these matrices to sparse format.
+ %tmp = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #DCSR>
+ %a = tensor.extract_slice %tmp[0, 0][4, 8][1, 1] : tensor<8x8xf64, #DCSR> to tensor<4x8xf64, #DCSR_SLICE>
+ %b = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
+
+ %2 = call @matmul3(%a, %b)
+ : (tensor<4x8xf64, #DCSR_SLICE>,
+ tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
+
+ // DCSR test
+ //
+ // CHECK: [0, 30.5, 4.2, 0],
+ // CHECK-NEXT: [0, 0, 0, 0],
+ // CHECK-NEXT: [0, 0, 4.6, 0],
+ // CHECK-NEXT: [0, 0, 7, 8]
+ //
+ %c2 = sparse_tensor.convert %2 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
+ %c2u = tensor.cast %c2 : tensor<4x4xf64> to tensor<*xf64>
+ call @printMemrefF64(%c2u) : (tensor<*xf64>) -> ()
+
+ %t1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+ %a1 = tensor.extract_slice %t1[0, 0][4, 8][1, 1] : tensor<8x8xf64, #CSR> to tensor<4x8xf64, #CSR_SLICE>
+ %b1 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
+ %3 = call @matmul2(%a1, %b1)
+ : (tensor<4x8xf64, #CSR_SLICE>,
+ tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+
+ // CSR test
+ //
+ // CHECK: [0, 30.5, 4.2, 0],
+ // CHECK-NEXT: [0, 0, 0, 0],
+ // CHECK-NEXT: [0, 0, 4.6, 0],
+ // CHECK-NEXT: [0, 0, 7, 8]
+ //
+ %c3 = sparse_tensor.convert %3 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+ %c3u = tensor.cast %c3 : tensor<4x4xf64> to tensor<*xf64>
+ call @printMemrefF64(%c3u) : (tensor<*xf64>) -> ()
+
+ // slice x slice
+ //
+ // CHECK: [2.3, 0, 0, 0],
+ // CHECK-NEXT: [6.9, 0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0, 0],
+ // CHECK-NEXT: [12.6, 0, 0, 0]]
+ //
+ %s1 = tensor.extract_slice %tmp[0, 1][4, 4][2, 1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_1>
+ %s2 = tensor.extract_slice %b1[0, 0][4, 4][2, 1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_1>
+ %4 = call @matmul1(%s2, %s1)
+ : (tensor<4x4xf64, #CSR_SLICE_1>,
+ tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR>
+
+ %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+ %c4u = tensor.cast %c4 : tensor<4x4xf64> to tensor<*xf64>
+ call @printMemrefF64(%c4u) : (tensor<*xf64>) -> ()
+
+ // Sparse slices should generate the same result as dense slices.
+ //
+ // CHECK: [2.3, 0, 0, 0],
+ // CHECK-NEXT: [6.9, 0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0, 0],
+ // CHECK-NEXT: [12.6, 0, 0, 0]]
+ //
+ %ds1 = tensor.extract_slice %sa[0, 1][4, 4][2, 1] : tensor<8x8xf64> to tensor<4x4xf64>
+ %ds2 = tensor.extract_slice %sb[0, 0][4, 4][2, 1] : tensor<8x4xf64> to tensor<4x4xf64>
+
+ %d = bufferization.alloc_tensor() copy(%zero) : tensor<4x4xf64>
+ %r = linalg.matmul ins(%ds2, %ds1: tensor<4x4xf64>, tensor<4x4xf64>)
+ outs(%d: tensor<4x4xf64>) -> tensor<4x4xf64>
+ %du = tensor.cast %r : tensor<4x4xf64> to tensor<*xf64>
+ call @printMemrefF64(%du) : (tensor<*xf64>) -> ()
+
+ // Releases resources.
+ bufferization.dealloc_tensor %b1 : tensor<8x4xf64, #CSR>
+ bufferization.dealloc_tensor %t1 : tensor<8x8xf64, #CSR>
+ bufferization.dealloc_tensor %b : tensor<8x4xf64, #DCSR>
+ bufferization.dealloc_tensor %tmp: tensor<8x8xf64, #DCSR>
+ bufferization.dealloc_tensor %4 : tensor<4x4xf64, #CSR>
+ bufferization.dealloc_tensor %3 : tensor<4x4xf64, #CSR>
+ bufferization.dealloc_tensor %2 : tensor<4x4xf64, #DCSR>
+
+ return
+ }
+}
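
For reference when reading the slice encodings above: each slice triple is (offset, size, stride) per dimension, matching the offsets/sizes/strides of the corresponding tensor.extract_slice. For example, #DCSR_SLICE_1 with slice = [ (0, 4, 2), (1, 4, 1) ] is taken by tensor.extract_slice %tmp[0, 1][4, 4][2, 1], so slice element (i, j) reads base element (2 * i, 1 + j). A tiny illustrative mapping (hypothetical helper, not part of the test):

    #include <array>
    #include <cstdint>

    // Maps a coordinate of the 4x4 slice back into the 8x8 base tensor for
    // slice = [ (0, 4, 2), (1, 4, 1) ], i.e. (offset, size, stride) per dim.
    std::array<uint64_t, 2> sliceToBase(uint64_t i, uint64_t j) {
      return {0 + 2 * i,  // dim 0: offset 0, stride 2
              1 + 1 * j}; // dim 1: offset 1, stride 1
    }
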