[Mlir-commits] [mlir] [mlir][sparse] introduce MapRef, unify conversion/codegen for reader (PR #68360)
Peiming Liu
llvmlistbot at llvm.org
Thu Oct 5 16:30:22 PDT 2023
================
@@ -729,3 +729,92 @@ Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
return constantIndex(builder, loc, *stride);
return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}
+
+/// Populates `out` with one index constant per dimension of `stt`: the
+/// static size when it is known, or zero when the dimension is dynamic.
+/// Any previous contents of `out` are discarded.
+void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
+                                 SparseTensorType stt,
+                                 SmallVectorImpl<Value> &out) {
+  out.clear();
+  out.reserve(stt.getDimRank());
+  for (const DynSize dimSize : stt.getDimShape()) {
+    // Dynamic sizes are encoded as zero; static sizes pass through as-is.
+    DynSize encoded = 0;
+    if (!ShapedType::isDynamic(dimSize))
+      encoded = dimSize;
+    out.push_back(constantIndex(builder, loc, encoded));
+  }
+}
+
+/// Emits the call that creates a `CheckedSparseTensorReader` for `tensor`
+/// and returns the resulting reader value.
+///
+/// On return, `dimShapesValues` holds the per-dimension shape constants
+/// (static size, or zero for a dynamic size) and `dimSizesBuffer` refers to
+/// a buffer of dimension sizes: the shape buffer itself when every dimension
+/// is static, otherwise the buffer obtained from the reader.
+Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
+                               SparseTensorType stt, Value tensor,
+                               /*out*/ SmallVectorImpl<Value> &dimShapesValues,
+                               /*out*/ Value &dimSizesBuffer) {
+  // Build the dimShapes buffer: static size per dimension, zero if dynamic.
+  fillDimShape(builder, loc, stt, dimShapesValues);
+  Value shapesBuf = allocaBuffer(builder, loc, dimShapesValues);
+
+  // Create the checked reader. It verifies the static sizes for
+  // consistency and accepts any size for the dynamic dimensions.
+  Type readerTp = getOpaquePointerType(builder);
+  Type elemTp = stt.getElementType();
+  Value elemEncoding = constantPrimaryTypeEncoding(builder, loc, elemTp);
+  auto createCall =
+      createFuncCall(builder, loc, "createCheckedSparseTensorReader", readerTp,
+                     {tensor, shapesBuf, elemEncoding}, EmitCInterface::On);
+  Value reader = createCall.getResult(0);
+
+  // Fully static shape: the shape buffer already holds the actual sizes.
+  if (!stt.hasDynamicDimShape()) {
+    dimSizesBuffer = shapesBuf;
+    return reader;
+  }
+
+  // Otherwise ask the reader for a buffer that supplies the actual size of
+  // each dynamic dimension.
+  Type indexTp = builder.getIndexType();
+  auto sizesTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
+  auto sizesCall =
+      createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", sizesTp,
+                     reader, EmitCInterface::On);
+  dimSizesBuffer = sizesCall.getResult(0);
+  return reader;
+}
+
+Value sparse_tensor::genReaderBuffers(OpBuilder &builder, Location loc,
+ SparseTensorType stt,
+ SmallVectorImpl<Value> &dimShapesValues,
----------------
PeimingLiu wrote:
Can it be a `const ArrayRef`?
https://github.com/llvm/llvm-project/pull/68360
More information about the Mlir-commits
mailing list