[Mlir-commits] [mlir] [mlir][sparse] use a consistent order between [dis]assembleOp and sto… (PR #84079)
Aart Bik
llvmlistbot at llvm.org
Tue Mar 5 14:45:08 PST 2024
================
@@ -73,34 +67,27 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
// Convert the external representation of the values array.
auto rtp = cast<RankedTensorType>(type);
const SparseTensorType stt(rtp);
- auto shape = stt.getBatchLvlShape();
- shape.push_back(ShapedType::kDynamic);
SmallVector<Value> inputs;
SmallVector<Type> retTypes;
SmallVector<Type> cntTypes;
- // Collect the external representation of the values array for
- // input or the outgoing sparse tensor for output.
- inputs.push_back(fromVals[idx++]);
- if (!isIn) {
- inputs.push_back(extraVals[extra++]);
- retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
- cntTypes.push_back(builder.getIndexType()); // nnz
- }
+ if (!isIn)
+ inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
// Collect the external representations of the pos/crd arrays.
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef) {
+ kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
} else {
ShapedType st = t.cast<ShapedType>();
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
- cntTypes.push_back(rtp.getElementType());
+ cntTypes.push_back(builder.getIndexType());
----------------
aartbik wrote:
Is this correct? Is it always an index type, or did we use the same type as the metadata?
https://github.com/llvm/llvm-project/pull/84079
More information about the Mlir-commits mailing list