[Mlir-commits] [mlir] 6db397a - [mlir][sparse] support dynamic sparse tensor slices.
Peiming Liu
llvmlistbot at llvm.org
Fri Mar 10 15:12:46 PST 2023
Author: Peiming Liu
Date: 2023-03-10T23:12:41Z
New Revision: 6db397a8d49c7c0f76c5aa556f5bffb1eb1fb13b
URL: https://github.com/llvm/llvm-project/commit/6db397a8d49c7c0f76c5aa556f5bffb1eb1fb13b
DIFF: https://github.com/llvm/llvm-project/commit/6db397a8d49c7c0f76c5aa556f5bffb1eb1fb13b.diff
LOG: [mlir][sparse] support dynamic sparse tensor slices.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D141532
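For reference, a dynamic sparse tensor slice carries `?` entries in the encoding's
`slice` attribute and takes its offsets, sizes, and strides as SSA values on
tensor.extract_slice. A minimal sketch, adapted from the tests added below
(the #CSR encoding and value names are as used in those tests):

  #CSR_SLICE_DYN = #sparse_tensor.encoding<{
    dimLevelType = [ "dense", "compressed" ],
    slice = [ (?, ?, ?), (?, ?, ?) ]  // dynamic offset, size, stride per dimension
  }>

  // Offsets, sizes, and strides are plain index values instead of constants.
  %slice = tensor.extract_slice %t[%c1, %c1] [%c4, %c4] [%c1, %c2]
      : tensor<8x8xf64, #CSR> to tensor<?x?xf64, #CSR_SLICE_DYN>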
Added:
mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir
Modified:
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 64112222f912a..3bf11189d3805 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -570,7 +570,7 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
/// We normalized sparse tensor encoding attribute by always using
/// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
/// as other variants) lead to the same storage specifier type, and stripping
-/// irrelevant fields that does not alter the sparse tensor memory layout.
+/// irrelevant fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
SmallVector<DimLevelType> dlts;
@@ -582,13 +582,10 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
AffineMap(), // dimOrdering (irrelavant to storage speicifer)
AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
// Always use `index` for memSize and lvlSize instead of reusing
- // `getPosWidth`/`getCrdWidth`.
- // It allows us to reuse the same SSA value for different bitwidth,
- // It also avoids casting between index/integer (returned by DimOp)
- 0, 0,
- // FIXME: we should keep the slice information, for now it is okay as only
- // constant can be used for slice
- ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
+ // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
+ // value for different bitwidth, it also avoids casting between index and
+ // integer (returned by DimOp)
+ 0, 0, enc.getDimSlices());
}
StorageSpecifierType
@@ -620,11 +617,10 @@ static LogicalResult verifySparsifierGetterSetter(
const auto enc = md.getType().getEncoding();
const Level lvlRank = enc.getLvlRank();
- // TODO:
- // if (mdKind == StorageSpecifierKind::DimOffset ||
- // mdKind == StorageSpecifierKind::DimStride)
- // if (!enc.isSlice())
- // return op->emitError("requested slice data on non-slice tensor");
+ if (mdKind == StorageSpecifierKind::DimOffset ||
+ mdKind == StorageSpecifierKind::DimStride)
+ if (!enc.isSlice())
+ return op->emitError("requested slice data on non-slice tensor");
if (mdKind != StorageSpecifierKind::ValMemSize) {
if (!lvl)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 40836f4d77781..bdd6020d9d0ac 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -694,3 +694,23 @@ Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
Value tensor) {
return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
}
+
+Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
+ Value tensor, Dimension dim) {
+ auto enc = getSparseTensorEncoding(tensor.getType());
+ assert(enc && enc.isSlice());
+ std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
+ if (offset.has_value())
+ return constantIndex(builder, loc, *offset);
+ return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
+}
+
+Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
+ Value tensor, Dimension dim) {
+ auto enc = getSparseTensorEncoding(tensor.getType());
+ assert(enc && enc.isSlice());
+ std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
+ if (stride.has_value())
+ return constantIndex(builder, loc, *stride);
+ return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
+}
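The new createOrFoldSliceOffsetOp/createOrFoldSliceStrideOp helpers fold to an
arith.constant when the encoding records a static offset or stride, and otherwise
emit the slice metadata query ops. A minimal sketch of the two shapes of IR they
produce, assuming a dynamic slice %t with the #CSR_SLICE_DYN encoding used in the
tests below (value names are illustrative):

  // Static slice metadata folds to a constant:
  %off0 = arith.constant 2 : index
  // Dynamic slice metadata is queried from the tensor at runtime:
  %off1 = sparse_tensor.slice.offset %t at 0 : tensor<?x?xf64, #CSR_SLICE_DYN>
  %str1 = sparse_tensor.slice.stride %t at 0 : tensor<?x?xf64, #CSR_SLICE_DYN>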
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 6fa30935e781f..6d6351cce6dba 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -364,6 +364,15 @@ Value genToValues(OpBuilder &builder, Location loc, Value tensor);
/// Generates code to retrieve the values size for the sparse tensor.
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
+/// Generates code to retrieve the slice offset for the sparse tensor slice,
+/// returning a constant if the offset is statically known.
+Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
+ Dimension dim);
+
+/// Generates code to retrieve the slice stride for the sparse tensor slice,
+/// returning a constant if the stride is statically known.
+Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
+ Dimension dim);
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index a8474a1a65dc2..f48520b2286b8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -43,29 +43,25 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
return load;
}
-// TODO: Support dynamic sized slice.
-static Value getSliceOffset(OpBuilder &builder, Location loc,
- SparseTensorEncodingAttr enc, unsigned lvl) {
- return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl));
+static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
+ unsigned lvl) {
+ auto enc = getSparseTensorEncoding(tensor.getType());
+ // FIXME: `toOrigDim` is deprecated
+ return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
}
-static Value getSliceSize(OpBuilder &builder, Location loc,
- SparseTensorEncodingAttr enc, unsigned lvl) {
- return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl));
-}
-
-static Value getSliceStride(OpBuilder &builder, Location loc,
- SparseTensorEncodingAttr enc, unsigned lvl) {
- return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl));
+static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
+ unsigned lvl) {
+ auto enc = getSparseTensorEncoding(tensor.getType());
+ // FIXME: `toOrigDim` is deprecated
+ return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
}
// Converts a coordinate relative to the slice to the coordinate relative
// to the underlying tensor.
static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
- SparseTensorEncodingAttr enc, unsigned lvl) {
-
- Value stride = getSliceStride(builder, loc, enc, lvl);
- Value offset = getSliceOffset(builder, loc, enc, lvl);
+ Value offset, Value stride, Value tensor,
+ unsigned lvl) {
// iv = iv * stride + offset
v = builder.create<arith::MulIOp>(loc, v, stride);
v = builder.create<arith::AddIOp>(loc, v, offset);
@@ -75,40 +71,58 @@ static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
// Converts a coordinate relative to the underlying tensor to the coordinate
// relative to the slice, returns an extra remainder value
static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
- Value v,
- SparseTensorEncodingAttr enc,
+ Value iv, Value offset,
+ Value stride, Value tensor,
unsigned lvl) {
- Value stride = getSliceStride(builder, loc, enc, lvl);
- Value offset = getSliceOffset(builder, loc, enc, lvl);
// iv = (iv - offset) / stride
- v = builder.create<arith::SubIOp>(loc, v, offset);
- Value rem = builder.create<arith::RemUIOp>(loc, v, stride);
- v = builder.create<arith::DivUIOp>(loc, v, stride);
- return std::make_pair(v, rem);
+ iv = builder.create<arith::SubIOp>(loc, iv, offset);
+ Value rem = builder.create<arith::RemUIOp>(loc, iv, stride);
+ iv = builder.create<arith::DivUIOp>(loc, iv, stride);
+ return std::make_pair(iv, rem);
}
-static std::pair<Value, Value>
-genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
- SparseTensorEncodingAttr enc, unsigned lvl) {
- std::pair<Value, Value> trans = fromSliceCrd(builder, loc, crd, enc, lvl);
- // First, crd >= offset (TODO: seems unsigned >= 0 won't be folded, skip
- // the check if the offset is zero).
- auto geOffset =
- builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, crd,
- getSliceOffset(builder, loc, enc, lvl));
+std::pair<Value, Value>
+LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
+ unsigned tid, unsigned lvl) {
+ assert(isSparseSlices[tid]);
+ Value slice = tensors[tid];
+ Value offset = sliceOffsets[tid][lvl];
+ Value stride = sliceStrides[tid][lvl];
+ auto enc = getSparseTensorEncoding(slice.getType());
+
+ std::pair<Value, Value> transformedCrd =
+ fromSliceCrd(builder, loc, crd, offset, stride, slice, lvl);
+
+ SmallVector<Value, 3> conds; // at most 3 conditions
+
+ // First, coord >= offset (skip the check if offset is known to be 0).
+ if (auto staticOffset = enc.getStaticLvlSliceOffset(lvl);
+ !(staticOffset.has_value() && *staticOffset == 0)) {
+ auto geOffset = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::uge, crd, offset);
+ conds.push_back(geOffset);
+ }
+
// Second, coord_in_slice < length
- auto ltLength =
- builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first,
- getSliceSize(builder, loc, enc, lvl));
-
- // Third, rem == 0; confirmed that (a % 1) will be folded to 0
- auto fitStride =
- builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second,
- constantIndex(builder, loc, 0));
-
- auto pred = builder.create<arith::AndIOp>(loc, geOffset, ltLength);
- pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
- return {trans.first, pred};
+ auto ltLength = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ult, transformedCrd.first, lvlSizes[tid][lvl]);
+ conds.push_back(ltLength);
+
+ // Third, rem == 0 (skip the check if stride is known to be 1).
+ if (auto staticStride = enc.getStaticLvlSliceStride(lvl);
+ !(staticStride.has_value() && *staticStride == 1)) {
+ auto fitStride = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, transformedCrd.second,
+ constantIndex(builder, loc, 0));
+ conds.push_back(fitStride);
+ }
+
+ // Must meet all conditions to be a valid coordinate in the slice.
+ auto pred = conds.front();
+ for (auto cond : ValueRange(conds).drop_front())
+ pred = builder.create<arith::AndIOp>(loc, pred, cond);
+
+ return {transformedCrd.first, pred};
}
//===----------------------------------------------------------------------===//
@@ -119,10 +133,9 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
size_t dim, Value iv) {
Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
- if (isSparseSlices[tid]) {
- auto enc = getSparseTensorEncoding(tensors[tid].getType());
- iv = toSliceCoord(builder, loc, iv, enc, dim);
- }
+ if (isSparseSlices[tid])
+ iv = toSliceCoord(builder, loc, iv, sliceOffsets[tid][dim],
+ sliceStrides[tid][dim], tensors[tid], dim);
Value add = builder.create<arith::AddIOp>(loc, mul, iv);
return add;
}
@@ -204,6 +217,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->isSparseOut = isSparseOut;
this->tensors.assign(ts.begin(), ts.end());
this->isSparseSlices.assign(tensors.size(), false);
+ this->sliceOffsets.assign(tensors.size(), std::vector<Value>());
+ this->sliceStrides.assign(tensors.size(), std::vector<Value>());
this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
this->pidxs.assign(tensors.size(), std::vector<Value>());
this->segHi.assign(tensors.size(), std::vector<Value>());
@@ -246,6 +261,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
dimTypes[tid].assign(rank, DimLevelType::Dense);
// Initialize using empty value.
+ sliceOffsets[tid].assign(rank, Value());
+ sliceStrides[tid].assign(rank, Value());
pidxs[tid].assign(rank, Value());
segHi[tid].assign(rank, Value());
coord[tid].assign(rank, Value());
@@ -300,11 +317,17 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
assert(isDenseDLT(dlt));
}
- // Find upper bound in current dimension.
// FIXME: `toOrigDim` is deprecated
- const Dimension d = toOrigDim(enc, l);
- lvlSizes[t][l] = highs[t][l] =
- mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d);
+ // Since we do not have HigherOrdering now, we can always rely on the 1:1
+ // mapping from level to dimension to retrieve the level size.
+ Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor,
+ toOrigDim(enc, l));
+ // Find upper bound in current dimension.
+ highs[t][l] = lvlSizes[t][l] = lvlSz;
+ if (isSparseSlices[t]) {
+ sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l);
+ sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l);
+ }
}
// Perform the required bufferization. Dense inputs materialize
@@ -405,7 +428,6 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
isSparseInput = isSparseInput || isSparse;
}
- auto enc = getSparseTensorEncoding(tensors[tid].getType());
const auto reassoc = getCollapseReassociation(tid, dim);
// TODO: support dynamic slices.
// Uses the first dimension here to build the loop bound (which is also the
@@ -468,7 +490,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
for (Value red : reduc)
types.push_back(red.getType());
- auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, enc, dim);
+ auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, dim);
bool hasReduc = !types.empty();
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
/*else*/ hasReduc);
@@ -660,11 +682,8 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
isSingletonDLT(dimTypes[tid][dim])) {
coord[tid][dim] = genSparseCrd(builder, loc, tid, dim);
if (isSparseSlices[tid]) {
- Value load =
- genIndexLoad(builder, loc, crdBuffer[tid][dim], pidxs[tid][dim]);
- auto enc = getSparseTensorEncoding(tensors[tid].getType());
auto [trans, pred] =
- genSliceLegitPredicate(builder, loc, load, enc, dim);
+ genSliceLegitPredicate(builder, loc, coord[tid][dim], tid, dim);
slicesPreds.emplace_back(pred, i);
// Updates to the relative coordinate to the slice.
coord[tid][dim] = trans;
@@ -679,7 +698,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
// Generates a list of if statments
// pidx = in_slice ? pidx : pidx + 1
// TODO: instead of always picking pidx + 1, we should set pidx = high to
- // break to loop the coordinates is larger than the slice size.
+ // break the loop if the coordinate is larger than the slice size.
for (auto [pred, idx] : slicesPreds) {
Value nextPidx = builder.create<arith::AddIOp>(
loc, yields[idx], constantIndex(builder, loc, 1));
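The reworked genSliceLegitPredicate maps a coordinate of the underlying tensor into
slice space via (crd - offset) / stride and combines up to three checks: crd >= offset
(skipped when the offset is statically 0), transformed coordinate < level size, and
remainder == 0 (skipped when the stride is statically 1). A rough sketch of the IR it
emits for a fully dynamic slice (value names illustrative; the CHECK lines added to
sparse_foreach.mlir below show the exact output):

  %d   = arith.subi %crd, %offset : index
  %rem = arith.remui %d, %stride : index
  %q   = arith.divui %d, %stride : index
  %ge  = arith.cmpi uge, %crd, %offset : index
  %lt  = arith.cmpi ult, %q, %lvlSize : index
  %eq  = arith.cmpi eq, %rem, %c0 : index
  %p0  = arith.andi %ge, %lt : i1
  %ok  = arith.andi %p0, %eq : i1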
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 1f6ee150dc4b4..8bc5da077c5c0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -202,6 +202,13 @@ class LoopEmitter {
Value genSparseCrd(OpBuilder &builder, Location loc, size_t tid,
size_t dstLvl);
+ /// Generates a predicate to determine whether the transformed coordinates are
+ /// in the given slice.
+ /// Returns std::pair<Transformed coordinates, Predicate>
+ std::pair<Value, Value> genSliceLegitPredicate(OpBuilder &builder,
+ Location loc, Value crd,
+ unsigned tid, unsigned lvl);
+
bool isOutputTensor(size_t tid) {
return hasOutput && tid == tensors.size() - 1;
}
@@ -278,6 +285,9 @@ class LoopEmitter {
/// Whether the sparse input is a slice.
std::vector<bool> isSparseSlices;
+ /// Values related to slices.
+ std::vector<std::vector<Value>> sliceOffsets;
+ std::vector<std::vector<Value>> sliceStrides;
/// Loop Stack, stores the information of all the nested loops that are
/// alive.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index f3a6adbf0eceb..0c68c4db4fe95 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -130,17 +130,18 @@ Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
/// Builds IR extracting the pos-th offset from the descriptor.
Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
Dimension dim) const {
- return builder.create<LLVM::ExtractValueOp>(
- loc, value,
- ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+ return extractField(
+ builder, loc,
+ ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
}
/// Builds IR inserting the pos-th offset into the descriptor.
void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
Dimension dim, Value size) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, value, size,
- ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+ insertField(
+ builder, loc,
+ ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
+ size);
}
/// Builds IR extracting the `lvl`-th level-size from the descriptor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 80f299692af1b..71c78d9061a9b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -18,6 +18,9 @@
#include "CodegenUtils.h"
#include "SparseTensorStorageLayout.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -28,7 +31,6 @@
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
#include <optional>
@@ -697,6 +699,23 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
}
};
+template <typename Op, StorageSpecifierKind kind>
+class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
+public:
+ using OpConversionPattern<Op>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Simply lowers to a specifier.get <field> operation.
+ auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+ auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
+ op.getDim().getZExtValue());
+
+ rewriter.replaceOp(op, v);
+ return success();
+ }
+};
+
/// Sparse codegen rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
@@ -1099,13 +1118,15 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
}
};
-class SparseExtractSliceCoverter
+class SparseExtractSliceConverter
: public OpConversionPattern<tensor::ExtractSliceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ MLIRContext *ctx = op.getContext();
auto srcEnc = getSparseTensorEncoding(op.getSourceType());
auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
if (!srcEnc && !dstEnc)
@@ -1119,16 +1140,43 @@ class SparseExtractSliceCoverter
assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
- // TODO: support dynamic slices.
- for (int i = 0, e = op.getSourceType().getRank(); i < e; i++) {
- assert(op.getStaticStrides()[i] == dstEnc.getStaticDimSliceStride(i));
- assert(op.getStaticOffsets()[i] == dstEnc.getStaticDimSliceOffset(i));
- assert(op.getStaticSizes()[i] == dstEnc.getStaticDimSliceSize(i));
+ SmallVector<Value> fields;
+ auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+
+ auto newSpec = rewriter.create<StorageSpecifierInitOp>(
+ loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
+ desc.setSpecifier(newSpec);
+
+ // Fills in slice information.
+ for (const auto &it : llvm::enumerate(llvm::zip(
+ op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()))) {
+ Dimension dim = it.index();
+ auto [offset, size, stride] = it.value();
+
+ Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
+ Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
+ Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
+ // TODO: We could probably set only the dynamic values here. But it would
+ // require us to fill the hole when casting a static slice to a dynamic
+ // slice.
+ desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
+ dim, offsetV);
+
+ // FIXME: we need to distinguish level sizes and dimension size for slices
+ // here. Maybe we should store slice level sizes in a
diff erent array
+ // instead of reusing it.
+ assert(srcEnc.hasIdDimOrdering());
+ desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
+ sizeV);
+ desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
+ dim, strideV);
}
- // TODO: create a new specifer for slices (need to encode slice metadata).
- // It does not matter now because only constant offset/stride are allowed.
- rewriter.replaceOp(op, adaptor.getSource());
+ // NOTE: we cannot generate tuples directly from the descriptor here, as the
+ // descriptor is holding the original type, yet we want the slice type
+ // here (they share every memref but with an updated specifier).
+ rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
+ desc.getFields()));
return success();
}
};
@@ -1449,13 +1497,18 @@ void mlir::populateSparseTensorCodegenPatterns(
patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
SparseCastConverter, SparseTensorDeallocConverter,
- SparseExtractSliceCoverter, SparseTensorLoadConverter,
+ SparseExtractSliceConverter, SparseTensorLoadConverter,
SparseExpandConverter, SparseCompressConverter,
- SparseInsertConverter, SparseToPositionsConverter,
- SparseToCoordinatesConverter, SparseToCoordinatesBufferConverter,
- SparseToValuesConverter, SparseConvertConverter,
- SparseNewOpConverter, SparseNumberOfEntriesConverter>(
- typeConverter, patterns.getContext());
+ SparseInsertConverter,
+ SparseSliceGetterOpConverter<ToSliceOffsetOp,
+ StorageSpecifierKind::DimOffset>,
+ SparseSliceGetterOpConverter<ToSliceStrideOp,
+ StorageSpecifierKind::DimStride>,
+ SparseToPositionsConverter, SparseToCoordinatesConverter,
+ SparseToCoordinatesBufferConverter, SparseToValuesConverter,
+ SparseConvertConverter, SparseNewOpConverter,
+ SparseNumberOfEntriesConverter>(typeConverter,
+ patterns.getContext());
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 69cc3af3a5bdd..788ad28ee4221 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -403,6 +403,8 @@ class MutSparseTensorDescriptor
fields[fidx] = v;
}
+ void setSpecifier(Value newSpec) { fields.back() = newSpec; }
+
void setSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind, std::optional<Level> lvl,
Value v) {
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index caf994cf8c192..6c0f13ae1b564 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -259,16 +259,16 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>)
return %0 : index
}
-//// -----
-//
-//#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
-//
-//func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
-// // _e_xpected-error at +1 {{requested slice data on non-slice tensor}}
-// %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0
-// : !sparse_tensor.storage_specifier<#SparseVector> to i64
-// return %0 : i64
-//}
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
+ // expected-error at +1 {{requested slice data on non-slice tensor}}
+ %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0
+ : !sparse_tensor.storage_specifier<#SparseVector>
+ return %0 : index
+}
// -----
diff --git a/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir b/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir
new file mode 100644
index 0000000000000..745b0a8f376d5
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --sparse-tensor-codegen --cse | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ]
+}>
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (0, 4, 1), (0, 8, 1) ]
+}>
+
+// CHECK-LABEL: func.func @sparse_slice(
+// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>)
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.storage_specifier.init with %[[VAL_3]]
+// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.storage_specifier.set %[[VAL_4]] dim_offset at 0 with %[[VAL_5]]
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.storage_specifier.set %[[VAL_8]] lvl_sz at 0 with %[[VAL_6]]
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.set %[[VAL_9]] dim_stride at 0 with %[[VAL_7]]
+// CHECK: %[[VAL_11:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_offset at 1 with %[[VAL_5]]
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] lvl_sz at 1 with %[[VAL_11]]
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]] dim_stride at 1 with %[[VAL_7]]
+// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_14]]
+func.func @sparse_slice(%t1 : tensor<8x8xf64, #CSR>) -> tensor<4x8xf64, #CSR_SLICE> {
+ %a1 = tensor.extract_slice %t1[0, 0][4, 8][1, 1] : tensor<8x8xf64, #CSR> to
+ tensor<4x8xf64, #CSR_SLICE>
+ return %a1 : tensor<4x8xf64, #CSR_SLICE>
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
index ad203abba0886..8d72e9455c187 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" | FileCheck %s
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" --canonicalize | FileCheck %s
// CHECK-LABEL: func.func @sparse_foreach_constant
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -27,3 +27,115 @@ func.func @sparse_foreach_constant() -> () {
}
return
}
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ],
+ slice = [ (0, 4, 1), (2, 4, 1) ]
+}>
+
+#CSR_SLICE_DYN = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ],
+ slice = [ (?, ?, ?), (?, ?, ?) ]
+}>
+
+
+// CHECK-LABEL: func.func @foreach_print_slice_dyn(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_3:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_4:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 0 : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 0 : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_10:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.slice.offset %[[VAL_0]] at 1 : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.slice.stride %[[VAL_0]] at 1 : tensor<?x?xf64,
+// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64,
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_2]] {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_6]] : index
+// CHECK: %[[VAL_19:.*]] = arith.remui %[[VAL_18]], %[[VAL_7]] : index
+// CHECK: %[[VAL_20:.*]] = arith.divui %[[VAL_18]], %[[VAL_7]] : index
+// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_17]], %[[VAL_6]] : index
+// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_5]] : index
+// CHECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_1]] : index
+// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1
+// CHECK: %[[VAL_25:.*]] = arith.andi %[[VAL_24]], %[[VAL_23]] : i1
+// CHECK: scf.if %[[VAL_25]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_2]] {
+// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_29]]] : memref<?xindex>
+// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : index
+// CHECK: %[[VAL_32:.*]] = arith.remui %[[VAL_31]], %[[VAL_12]] : index
+// CHECK: %[[VAL_33:.*]] = arith.divui %[[VAL_31]], %[[VAL_12]] : index
+// CHECK: %[[VAL_34:.*]] = arith.cmpi uge, %[[VAL_30]], %[[VAL_11]] : index
+// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_10]] : index
+// CHECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_1]] : index
+// CHECK: %[[VAL_37:.*]] = arith.andi %[[VAL_34]], %[[VAL_35]] : i1
+// CHECK: %[[VAL_38:.*]] = arith.andi %[[VAL_37]], %[[VAL_36]] : i1
+// CHECK: scf.if %[[VAL_38]] {
+// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_29]]] : memref<?xf64>
+// CHECK: "test.use"(%[[VAL_39]]) : (f64) -> ()
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return
+//
+func.func @foreach_print_slice_dyn(%A: tensor<?x?xf64, #CSR_SLICE_DYN>) {
+ sparse_tensor.foreach in %A : tensor<?x?xf64, #CSR_SLICE_DYN> do {
+ ^bb0(%1: index, %2: index, %v: f64) :
+ "test.use" (%v) : (f64) -> ()
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @foreach_print_slice(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf64,
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64,
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64,
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64,
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64,
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf64,
+// CHECK-DAG: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_1]] : index
+// CHECK: scf.if %[[VAL_14]] {
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_4]] {
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_2]] : index
+// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_19]], %[[VAL_2]] : index
+// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_1]] : index
+// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1
+// CHECK: scf.if %[[VAL_23]] {
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK: "test.use"(%[[VAL_24]]) : (f64) -> ()
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return
+//
+func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
+ sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR_SLICE> do {
+ ^bb0(%1: index, %2: index, %v: f64) :
+ "test.use" (%v) : (f64) -> ()
+ }
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
index 548818ccb02db..a3e94263a40fc 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
// DEFINE: mlir-cpu-runner \
// DEFINE: -e entry -entry-point-result=void \
-// DEFINE: -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
// DEFINE: FileCheck %s
//
// RUN: %{command}
@@ -18,6 +18,12 @@
slice = [ (1, 4, 1), (1, 4, 2) ]
}>
+#CSR_SLICE_DYN = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (?, ?, ?), (?, ?, ?) ]
+}>
+
+
module {
func.func @foreach_print_non_slice(%A: tensor<4x4xf64, #CSR>) {
sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR> do {
@@ -39,8 +45,22 @@ module {
return
}
+ func.func @foreach_print_slice_dyn(%A: tensor<?x?xf64, #CSR_SLICE_DYN>) {
+ sparse_tensor.foreach in %A : tensor<?x?xf64, #CSR_SLICE_DYN> do {
+ ^bb0(%1: index, %2: index, %v: f64) :
+ vector.print %1: index
+ vector.print %2: index
+ vector.print %v: f64
+ }
+ return
+ }
+
func.func @entry() {
%c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+
%sa = arith.constant dense<[
[ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
@@ -52,6 +72,7 @@ module {
[ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
]> : tensor<8x8xf64>
+
%tmp = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
%a = tensor.extract_slice %tmp[1, 1][4, 4][1, 2] : tensor<8x8xf64, #CSR> to
tensor<4x4xf64, #CSR_SLICE>
@@ -72,7 +93,7 @@ module {
%dense = tensor.extract_slice %sa[1, 1][4, 4][1, 2] : tensor<8x8xf64> to
tensor<4x4xf64>
%b = sparse_tensor.convert %dense : tensor<4x4xf64> to tensor<4x4xf64, #CSR>
- // Foreach on sparse tensor instead of slice should yield the same result.
+ // Foreach on the sparse tensor instead of the slice; they should yield the same result.
//
// CHECK-NEXT: 1
// CHECK-NEXT: 0
@@ -86,8 +107,28 @@ module {
//
call @foreach_print_non_slice(%b) : (tensor<4x4xf64, #CSR>) -> ()
- bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR>
+ // The same slice, but with dynamic encoding.
+ // TODO: Investigate why reusing the same %tmp above would cause bufferization
+ // errors.
+ %tmp1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+ %a_dyn = tensor.extract_slice %tmp1[%c1, %c1][%c4, %c4][%c1, %c2] :
+ tensor<8x8xf64, #CSR> to tensor<?x?xf64, #CSR_SLICE_DYN>
+ //
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 2.3
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 2.1
+ //
+ call @foreach_print_slice_dyn(%a_dyn) : (tensor<?x?xf64, #CSR_SLICE_DYN>) -> ()
+
bufferization.dealloc_tensor %tmp : tensor<8x8xf64, #CSR>
+ bufferization.dealloc_tensor %tmp1 : tensor<8x8xf64, #CSR>
+ bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR>
return
}
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
index 8f77b3d6a16bb..ffa6ad8d7618b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
@@ -38,10 +38,32 @@
slice = [ (0, 4, 2), (1, 4, 1) ]
}>
+#CSR_SLICE_dyn = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (?, 4, ?), (?, 4, ?) ]
+}>
+
+#DCSR_SLICE_dyn = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed", "compressed" ],
+ slice = [ (?, 4, ?), (?, 4, ?) ]
+}>
+
+
module {
func.func private @printMemrefF64(%ptr : tensor<*xf64>)
func.func private @printMemref1dF64(%ptr : memref<?xf64>) attributes { llvm.emit_c_interface }
+ //
+ // Computes C = A x B with all matrices as dynamic sparse slices (SpMSpM) in CSR and DCSR.
+ //
+ func.func @matmul_dyn(%A: tensor<4x4xf64, #CSR_SLICE_dyn>,
+ %B: tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR> {
+ %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+ %D = linalg.matmul
+ ins(%A, %B: tensor<4x4xf64, #CSR_SLICE_dyn>, tensor<4x4xf64, #DCSR_SLICE_dyn>)
+ outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+ return %D: tensor<4x4xf64, #CSR>
+ }
//
// Computes C = A x B with one matrix CSR sparse slices and the other DSCR sparse slice.
@@ -83,7 +105,9 @@ module {
// Main driver.
//
func.func @entry() {
- %c0 = arith.constant 0 : index
+ %c_0 = arith.constant 0 : index
+ %c_1 = arith.constant 1 : index
+ %c_2 = arith.constant 2 : index
%f0 = arith.constant 0.0 : f64
%sa = arith.constant dense<[
@@ -158,11 +182,27 @@ module {
%4 = call @matmul1(%s2, %s1)
: (tensor<4x4xf64, #CSR_SLICE_1>,
tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR>
-
%c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
%c4u = tensor.cast %c4 : tensor<4x4xf64> to tensor<*xf64>
call @printMemrefF64(%c4u) : (tensor<*xf64>) -> ()
+ // slice x slice (same as above, but with dynamic stride information)
+ //
+ // CHECK: [2.3, 0, 0, 0],
+ // CHECK-NEXT: [6.9, 0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0, 0],
+ // CHECK-NEXT: [12.6, 0, 0, 0]]
+ //
+ %s1_dyn = tensor.extract_slice %tmp[%c_0, %c_1][4, 4][%c_2, %c_1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_dyn>
+ %s2_dyn = tensor.extract_slice %b1[%c_0, %c_0][4, 4][%c_2, %c_1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_dyn>
+ %dyn_4 = call @matmul_dyn(%s2_dyn, %s1_dyn)
+ : (tensor<4x4xf64, #CSR_SLICE_dyn>,
+ tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR>
+
+ %c4_dyn = sparse_tensor.convert %dyn_4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+ %c4u_dyn = tensor.cast %c4_dyn : tensor<4x4xf64> to tensor<*xf64>
+ call @printMemrefF64(%c4u_dyn) : (tensor<*xf64>) -> ()
+
// sparse slices should generate the same result as dense slices
//
// CHECK: [2.3, 0, 0, 0],
@@ -179,7 +219,7 @@ module {
%du = tensor.cast %r : tensor<4x4xf64> to tensor<*xf64>
call @printMemrefF64(%du) : (tensor<*xf64>) -> ()
- // Releases resources.
+ // Releases resources (we do not need to deallocate slices).
bufferization.dealloc_tensor %b1 : tensor<8x4xf64, #CSR>
bufferization.dealloc_tensor %t1 : tensor<8x8xf64, #CSR>
bufferization.dealloc_tensor %b : tensor<8x4xf64, #DCSR>
@@ -187,6 +227,7 @@ module {
bufferization.dealloc_tensor %4 : tensor<4x4xf64, #CSR>
bufferization.dealloc_tensor %3 : tensor<4x4xf64, #CSR>
bufferization.dealloc_tensor %2 : tensor<4x4xf64, #DCSR>
+ bufferization.dealloc_tensor %dyn_4 : tensor<4x4xf64, #CSR>
return
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 3f98d8247bdcc..a3f58b06a7606 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2223,6 +2223,7 @@ cc_library(
deps = [
":AffineDialect",
":ArithDialect",
+ ":ArithUtils",
":BufferizationDialect",
":BufferizationTransforms",
":ComplexDialect",