[Mlir-commits] [mlir] 0352690 - [mlir][sparse] make foreach operation support sparse tensor slices.
Peiming Liu
llvmlistbot at llvm.org
Wed Feb 8 10:58:41 PST 2023
Author: Peiming Liu
Date: 2023-02-08T18:58:35Z
New Revision: 03526904217568a95fb7369c397edc9d26975789
URL: https://github.com/llvm/llvm-project/commit/03526904217568a95fb7369c397edc9d26975789
DIFF: https://github.com/llvm/llvm-project/commit/03526904217568a95fb7369c397edc9d26975789.diff
LOG: [mlir][sparse] make foreach operation support sparse tensor slices.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D140713
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
Modified:
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Removed:
################################################################################
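In short: the LoopEmitter and sparse codegen now understand sparse tensor slices,
i.e. tensors whose encoding carries per-dimension (offset, size, stride) metadata,
so sparse_tensor.foreach can iterate directly over a tensor.extract_slice result.
A minimal usage sketch, assuming the #CSR_SLICE encoding and the 8x8 input from the
new integration test below:

  #CSR_SLICE = #sparse_tensor.encoding<{
    dimLevelType = [ "dense", "compressed" ],
    slice = [ (1, 4, 1), (1, 4, 2) ]   // per dimension: (offset, size, stride)
  }>

  %a = tensor.extract_slice %t[1, 1][4, 4][1, 2]
      : tensor<8x8xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE>
  sparse_tensor.foreach in %a : tensor<4x4xf64, #CSR_SLICE> do {
  ^bb0(%i: index, %j: index, %v: f64):
    vector.print %v : f64   // visits only entries inside the slice, slice-relative coords
  }

Only constant offsets, sizes, and strides are handled; dynamic slices remain a TODO
throughout the patch.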
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index f61e1f52ce8f9..fda7a5c481395 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -534,7 +534,10 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
enc.getContext(), dlts,
AffineMap(), // dimOrdering (irrelevant to storage specifier)
AffineMap(), // highLvlOrdering (irrelevant to storage specifier)
- enc.getPointerBitWidth(), enc.getIndexBitWidth());
+ enc.getPointerBitWidth(), enc.getIndexBitWidth(),
+ // FIXME: we should keep the slice information; for now this is okay because
+ // only constant values can be used for slices.
+ ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
}
StorageSpecifierType
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 88981fccaf403..df19b61dbede1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -42,6 +42,50 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value ptr,
return load;
}
+// TODO: Support dynamically sized slices.
+static Value getSliceOffset(OpBuilder &builder, Location loc,
+ SparseTensorEncodingAttr enc, unsigned lvl) {
+ return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl));
+}
+
+static Value getSliceSize(OpBuilder &builder, Location loc,
+ SparseTensorEncodingAttr enc, unsigned lvl) {
+ return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl));
+}
+
+static Value getSliceStride(OpBuilder &builder, Location loc,
+ SparseTensorEncodingAttr enc, unsigned lvl) {
+ return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl));
+}
+
+// Converts a coordinate relative to the slice to the coordinate relative
+// to the underlying tensor.
+static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
+ SparseTensorEncodingAttr enc, unsigned lvl) {
+
+ Value stride = getSliceStride(builder, loc, enc, lvl);
+ Value offset = getSliceOffset(builder, loc, enc, lvl);
+ // iv = iv * stride + offset
+ v = builder.create<arith::MulIOp>(loc, v, stride);
+ v = builder.create<arith::AddIOp>(loc, v, offset);
+ return v;
+}
+
+// Converts a coordinate relative to the underlying tensor to the coordinate
+// relative to the slice; also returns an extra remainder value.
+static std::pair<Value, Value> fromSliceCoord(OpBuilder &builder, Location loc,
+ Value v,
+ SparseTensorEncodingAttr enc,
+ unsigned lvl) {
+ Value stride = getSliceStride(builder, loc, enc, lvl);
+ Value offset = getSliceOffset(builder, loc, enc, lvl);
+ // iv = (iv - offset) / stride
+ v = builder.create<arith::SubIOp>(loc, v, offset);
+ Value rem = builder.create<arith::RemUIOp>(loc, v, stride);
+ v = builder.create<arith::DivUIOp>(loc, v, stride);
+ return std::make_pair(v, rem);
+}
+
//===----------------------------------------------------------------------===//
// Sparse tensor loop emitter class implementations
//===----------------------------------------------------------------------===//
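The helpers added above implement the slice <-> tensor coordinate mapping:
toSliceCoord maps a slice-relative coordinate into the underlying tensor
(coord * stride + offset), and fromSliceCoord inverts it, returning the remainder
as well so callers can reject coordinates that do not lie on the stride. A rough
sketch of the IR fromSliceCoord emits for a level with offset 1 and stride 2
(constants assumed, since only static slices are supported):

  %off = arith.constant 1 : index
  %str = arith.constant 2 : index
  %t   = arith.subi %c, %off : index    // c - offset
  %rem = arith.remui %t, %str : index   // (c - offset) % stride, must be 0
  %v   = arith.divui %t, %str : index   // slice-relative coordinate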
@@ -50,6 +94,10 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
size_t dim, Value iv) {
Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
+ if (isSparseSlices[tid]) {
+ auto enc = getSparseTensorEncoding(tensors[tid].getType());
+ iv = toSliceCoord(builder, loc, iv, enc, dim);
+ }
Value add = builder.create<arith::AddIOp>(loc, mul, iv);
return add;
}
@@ -67,6 +115,7 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
this->hasOutput = hasOutput;
this->isSparseOut = isSparseOut;
this->tensors.assign(tensors.begin(), tensors.end());
+ this->isSparseSlices.assign(tensors.size(), false);
this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
this->pidxs.assign(tensors.size(), std::vector<Value>());
this->coord.assign(tensors.size(), std::vector<Value>());
@@ -87,10 +136,11 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
auto enc = getSparseTensorEncoding(rtp);
// We always treat sparse output tensor as dense so that we always iterate
// it based on dim size.
- if (enc && !(isOutputTensor(tid) && isSparseOut))
+ if (enc && !(isOutputTensor(tid) && isSparseOut)) {
+ isSparseSlices[tid] = enc.isSlice();
for (auto dimTp : enc.getDimLevelType())
dimTypes[tid].push_back(dimTp);
- else
+ } else
dimTypes[tid].assign(rank, DimLevelType::Dense);
// Initialize using empty value.
@@ -218,7 +268,6 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
ArrayRef<size_t> dims, MutableArrayRef<Value> reduc, bool isParallel) {
// TODO: support multiple return on parallel for?
assert(!isParallel || reduc.size() <= 1);
-
bool isSparseInput = false;
size_t tid = tids.front(), dim = dims.front();
for (auto [t, d] : llvm::zip(tids, dims)) {
@@ -239,10 +288,13 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
isSparseInput = isSparseInput || isSparse;
}
+ auto enc = getSparseTensorEncoding(tensors[tid].getType());
+ // TODO: support dynamic slices.
Value step = constantIndex(builder, loc, 1);
Value lo = isSparseInput ? pidxs[tid][dim] // current offset
- : loopSeqStack.back(); // univeral tid
+ : loopSeqStack.back(); // universal index
Value hi = highs[tid][dim];
+
Operation *loop = nullptr;
Value iv;
if (isParallel) {
@@ -275,15 +327,64 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
}
assert(loop && iv);
+ Value c;
if (isSparseInput) {
pidxs[tid][dim] = iv;
// Generating a load on the indices array yields the coordinate.
Value ptr = idxBuffer[tid][dim];
- coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
+ c = genIndexLoad(builder, loc, ptr, iv);
} else {
// Dense tensor, the coordinate is the induction variable.
- coord[tid][dim] = iv;
+ c = iv;
}
+
+ if (isSparseSlices[tid] && isSparseInput) {
+ // For sparse level slices, we need to filter out invalid coordinates that
+ // are not included in the slice.
+ std::pair<Value, Value> trans = fromSliceCoord(builder, loc, c, enc, dim);
+ SmallVector<Type> types;
+ for (Value red : reduc)
+ types.push_back(red.getType());
+
+ // First, coord >= offset (TODO: seems unsigned >= 0 won't be folded, skip
+ // the check if the offset is zero).
+ auto geOff =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, c,
+ getSliceOffset(builder, loc, enc, dim));
+ // Second, coords < length
+ auto ltLen = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ult, trans.first,
+ getSliceSize(builder, loc, enc, dim));
+
+ // Third, rem == 0; confirmed that (a % 1) will be folded to 0
+ auto fitStride = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, trans.second,
+ constantIndex(builder, loc, 0));
+
+ auto pred = builder.create<arith::AndIOp>(loc, geOff, ltLen);
+ pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+ bool hasReduc = !types.empty();
+ scf::IfOp ifOp =
+ builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
+ if (hasReduc) {
+ // scf.for (a) -> v
+ // %s = scf.if (a) -> v
+ // user-generated code.
+ // else
+ // yield a
+ // yield %s
+ builder.create<scf::YieldOp>(loc, ifOp.getResults());
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // On mismatch.
+ builder.create<scf::YieldOp>(loc, reduc);
+ }
+ // Set the insertion point to matched branch.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ c = trans.first;
+ }
+
+ assert(c);
+ coord[tid][dim] = c;
// NOTE: we can also prepare for next dim here in advance
// Push the loop into stack
loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
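The slice filtering added above follows the structure sketched in the inline comment:
for a sparse slice input, each stored coordinate is translated back into slice space,
and the user body only executes when the coordinate actually belongs to the slice;
otherwise the current reduction values are passed through unchanged. Roughly, the
emitted IR has this shape (names and the f64 reduction type are illustrative only):

  scf.for %iv = %lo to %hi step %c1 iter_args(%red = %init) -> (f64) {
    %c = memref.load %idx[%iv] : memref<?xindex>   // tensor-space coordinate
    // %pred = (c >= offset) and (slice coord < size) and ((c - offset) % stride == 0)
    %s = scf.if %pred -> (f64) {
      ...                                          // user-generated code
      scf.yield %new : f64
    } else {
      scf.yield %red : f64                         // coordinate not in the slice
    }
    scf.yield %s : f64
  }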
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index a1db60c211957..832281a86359e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -259,22 +259,25 @@ class LoopEmitter {
std::vector<std::vector<Value>> idxBuffer; // to_indices
std::vector<Value> valBuffer; // to_value
- // Loop Stack, stores the information of all the nested loops that are
- // alive.
+ /// Whether the sparse input is a slice.
+ std::vector<bool> isSparseSlices;
+
+ /// Loop Stack, stores the information of all the nested loops that are
+ /// alive.
std::vector<LoopLevelInfo> loopStack;
- // Loop Sequence Stack, stores the unversial index for the current loop
- // sequence.
+ /// Loop Sequence Stack, stores the universal index for the current loop
+ /// sequence.
std::vector<Value> loopSeqStack;
- // Maps AffineDimExpr to the index of the loop in loopStack.
- // TODO: We should probably use a callback function here to make it more
- // general.
+ /// Maps AffineDimExpr to the index of the loop in loopStack.
+ /// TODO: We should probably use a callback function here to make it more
+ /// general.
std::vector<unsigned> sparsiferLoopLvlMap;
- // TODO: not yet used, it should track the current level for each tensor
- // to help eliminate `dim` paramters from above APIs.
- // std::vector<size_t> curLv;
+ /// TODO: not yet used, it should track the current level for each tensor
+ /// to help eliminate `dim` parameters from the above APIs.
+ /// std::vector<size_t> curLv;
};
} // namespace sparse_tensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 11348e0703e39..b2541c38a30b1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1010,6 +1010,40 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
}
};
+class SparseExtractSliceCoverter
+ : public OpConversionPattern<tensor::ExtractSliceOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcEnc = getSparseTensorEncoding(op.getSourceType());
+ auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
+ if (!srcEnc && !dstEnc)
+ return failure();
+
+ // TODO: We should check these in ExtractSliceOp::verify.
+ assert(srcEnc && dstEnc && dstEnc.isSlice());
+ assert(srcEnc.getDimLevelType() == dstEnc.getDimLevelType());
+ assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering());
+ assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering());
+ assert(srcEnc.getPointerBitWidth() == dstEnc.getPointerBitWidth());
+ assert(srcEnc.getIndexBitWidth() == dstEnc.getIndexBitWidth());
+
+ // TODO: support dynamic slices.
+ for (int i = 0, e = op.getSourceType().getRank(); i < e; i++) {
+ assert(op.getStaticStrides()[i] == dstEnc.getStaticDimSliceStride(i));
+ assert(op.getStaticOffsets()[i] == dstEnc.getStaticDimSliceOffset(i));
+ assert(op.getStaticSizes()[i] == dstEnc.getStaticDimSliceSize(i));
+ }
+
+ // TODO: create a new specifier for slices (need to encode slice metadata).
+ // It does not matter now because only constant offset/stride are allowed.
+ rewriter.replaceOp(op, adaptor.getSource());
+ return success();
+ }
+};
+
/// Sparse codegen rule for number of entries operator.
class SparseNumberOfEntriesConverter
: public OpConversionPattern<NumberOfEntriesOp> {
@@ -1133,13 +1167,13 @@ void mlir::populateSparseTensorCodegenPatterns(
bool enableBufferInitialization) {
patterns.add<SparsePackOpConverter, SparseReturnConverter,
SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
- SparseTensorDeallocConverter, SparseTensorLoadConverter,
- SparseExpandConverter, SparseCompressConverter,
- SparseInsertConverter, SparseToPointersConverter,
- SparseToIndicesConverter, SparseToIndicesBufferConverter,
- SparseToValuesConverter, SparseConvertConverter,
- SparseNumberOfEntriesConverter>(typeConverter,
- patterns.getContext());
+ SparseTensorDeallocConverter, SparseExtractSliceCoverter,
+ SparseTensorLoadConverter, SparseExpandConverter,
+ SparseCompressConverter, SparseInsertConverter,
+ SparseToPointersConverter, SparseToIndicesConverter,
+ SparseToIndicesBufferConverter, SparseToValuesConverter,
+ SparseConvertConverter, SparseNumberOfEntriesConverter>(
+ typeConverter, patterns.getContext());
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
}
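Note that SparseExtractSliceCoverter keeps codegen trivial for now: because the
slice's offsets, sizes, and strides are compile-time constants that must match the
slice metadata already recorded in the result encoding (hence the asserts above),
the extract_slice can simply be replaced by its source's storage. In other words,
something like

  %a = tensor.extract_slice %tmp[1, 1][4, 4][1, 2]
      : tensor<8x8xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE>

lowers to a plain reuse of %tmp's buffers, with the slice information carried only
in the #CSR_SLICE type; a dedicated storage specifier for slices is left as a TODO.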
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
new file mode 100644
index 0000000000000..560c64fa29a41
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
@@ -0,0 +1,94 @@
+// DEFINE: %{option} = enable-runtime-library=false
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+
+// TODO: support slices on lib path
+#CSR = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ]
+}>
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+module {
+ func.func @foreach_print_non_slice(%A: tensor<4x4xf64, #CSR>) {
+ sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR> do {
+ ^bb0(%1: index, %2: index, %v: f64) :
+ vector.print %1: index
+ vector.print %2: index
+ vector.print %v: f64
+ }
+ return
+ }
+
+ func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) {
+ sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR_SLICE> do {
+ ^bb0(%1: index, %2: index, %v: f64) :
+ vector.print %1: index
+ vector.print %2: index
+ vector.print %v: f64
+ }
+ return
+ }
+
+ func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %sa = arith.constant dense<[
+ [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ],
+ [ 0.0, 0.0, 0.1, 0.0, 0.0, 2.1, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 0.0, 3.1, 0.0, 0.0, 0.0 ],
+ [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 3.3, 0.0 ],
+ [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
+ ]> : tensor<8x8xf64>
+
+ %tmp = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+ %a = tensor.extract_slice %tmp[1, 1][4, 4][1, 2] : tensor<8x8xf64, #CSR> to
+ tensor<4x4xf64, #CSR_SLICE>
+ // Foreach on sparse tensor slices directly
+ //
+ // CHECK: 1
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 2.3
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 2.1
+ //
+ call @foreach_print_slice(%a) : (tensor<4x4xf64, #CSR_SLICE>) -> ()
+
+ // FIXME: investigate why a tensor copy is inserted for this slice
+// %dense = tensor.extract_slice %sa[1, 1][4, 4][1, 2] : tensor<8x8xf64> to
+// tensor<4x4xf64>
+// %b = sparse_tensor.convert %dense : tensor<4x4xf64> to tensor<4x4xf64, #CSR>
+// // Foreach on the sparse tensor instead of the slice; they should yield the same result.
+// //
+// // C_HECK-NEXT: 1
+// // C_HECK-NEXT: 0
+// // C_HECK-NEXT: 2.3
+// // C_HECK-NEXT: 2
+// // C_HECK-NEXT: 3
+// // C_HECK-NEXT: 1
+// // C_HECK-NEXT: 3
+// // C_HECK-NEXT: 2
+// // C_HECK-NEXT: 2.1
+// //
+// call @foreach_print_non_slice(%b) : (tensor<4x4xf64, #CSR>) -> ()
+// bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR>
+
+ bufferization.dealloc_tensor %tmp : tensor<8x8xf64, #CSR>
+ return
+ }
+}
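For reference, the expected output can be read off the dense input: the slice starts
at (1, 1) with strides (1, 2), so slice element (i, j) corresponds to %sa element
(1 + i, 1 + 2 * j). The only nonzeros of %sa that land inside the slice are
%sa[2][1] = 2.3 -> (1, 0), %sa[3][7] = 1.0 -> (2, 3), and %sa[4][5] = 2.1 -> (3, 2),
which is exactly the coordinate/value sequence expected by the CHECK lines above.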