[Mlir-commits] [mlir] 52028c1 - [mlir][sparse] Generate AOS subviews on-demand.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 11 08:57:07 PST 2023
Author: bixia1
Date: 2023-01-11T08:57:01-08:00
New Revision: 52028c1a48af87f3f56ca51fdfc13c8b89010302
URL: https://github.com/llvm/llvm-project/commit/52028c1a48af87f3f56ca51fdfc13c8b89010302
DIFF: https://github.com/llvm/llvm-project/commit/52028c1a48af87f3f56ca51fdfc13c8b89010302.diff
LOG: [mlir][sparse] Generate AOS subviews on-demand.
Previously, we generated array-of-structures (AOS) subviews for the indices
buffers when constructing an immutable sparse tensor descriptor. We now
generate such subviews only when getIdxMemRefOrView is requested.
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D141325
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index a34e66a2918bb..eaa4b420bbcd3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -138,7 +138,7 @@ class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
op.getDim().value().getZExtValue());
} else {
auto enc = op.getSpecifier().getType().getEncoding();
- StorageLayout<true> layout(enc);
+ StorageLayout layout(enc);
Optional<unsigned> dim = std::nullopt;
if (op.getDim())
dim = op.getDim().value().getZExtValue();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index a680ddf06d88d..38a7e0e0610fb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -295,11 +295,10 @@ static Value genCompressed(OpBuilder &builder, Location loc,
unsigned idxIndex;
unsigned idxStride;
std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(d);
- unsigned ptrIndex = desc.getPtrMemRefIndex(d);
Value one = constantIndex(builder, loc, 1);
Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
- Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos);
- Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1);
+ Value plo = genLoad(builder, loc, desc.getPtrMemRef(d), pos);
+ Value phi = genLoad(builder, loc, desc.getPtrMemRef(d), pp1);
Value msz = desc.getIdxMemSize(builder, loc, d);
Value idxStrideC;
if (idxStride > 1) {
@@ -325,7 +324,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
builder.create<scf::YieldOp>(loc, eq);
builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
if (d > 0)
- genStore(builder, loc, msz, desc.getMemRefField(ptrIndex), pos);
+ genStore(builder, loc, msz, desc.getPtrMemRef(d), pos);
builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
builder.setInsertionPointAfter(ifOp1);
Value p = ifOp1.getResult(0);
@@ -352,7 +351,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
// If !present (changes fields, update next).
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
- genStore(builder, loc, mszp1, desc.getMemRefField(ptrIndex), pp1);
+ genStore(builder, loc, mszp1, desc.getPtrMemRef(d), pp1);
createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d,
indices[d]);
// Prepare the next dimension "as needed".
@@ -638,10 +637,8 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
return failure();
- Location loc = op.getLoc();
- auto desc =
- getDescriptorFromTensorTuple(rewriter, loc, adaptor.getSource());
- auto sz = sizeFromTensorAtDim(rewriter, loc, desc, *index);
+ auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+ auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
if (!sz)
return failure();
@@ -756,8 +753,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
if (!getSparseTensorEncoding(op.getTensor().getType()))
return failure();
Location loc = op->getLoc();
- auto desc =
- getDescriptorFromTensorTuple(rewriter, loc, adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
RankedTensorType srcType =
op.getTensor().getType().cast<RankedTensorType>();
Type eltType = srcType.getElementType();
@@ -900,8 +896,7 @@ class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
- auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
- adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
uint64_t dim = op.getDimension().getZExtValue();
rewriter.replaceOp(op, desc.getPtrMemRef(dim));
return success();
@@ -919,17 +914,17 @@ class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
- auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
- adaptor.getTensor());
+ Location loc = op.getLoc();
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
uint64_t dim = op.getDimension().getZExtValue();
- Value field = desc.getIdxMemRef(dim);
+ Value field = desc.getIdxMemRefOrView(rewriter, loc, dim);
// Insert a cast to bridge the actual type to the user expected type. If the
// actual type and the user expected type aren't compatible, the compiler or
// the runtime will issue an error.
Type resType = op.getResult().getType();
if (resType != field.getType())
- field = rewriter.create<memref::CastOp>(op.getLoc(), resType, field);
+ field = rewriter.create<memref::CastOp>(loc, resType, field);
rewriter.replaceOp(op, field);
return success();
@@ -967,8 +962,7 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
- auto desc = getDescriptorFromTensorTuple(rewriter, op.getLoc(),
- adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getValMemRef());
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
index b0eb72e6fd668..e24a38d3947db 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -109,41 +109,24 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
// SparseTensorDescriptor methods.
//===----------------------------------------------------------------------===//
-sparse_tensor::SparseTensorDescriptor::SparseTensorDescriptor(
- OpBuilder &builder, Location loc, Type tp, ValueArrayRef buffers)
- : SparseTensorDescriptorImpl<false>(tp), expandedFields() {
- SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
- unsigned rank = enc.getDimLevelType().size();
+Value sparse_tensor::SparseTensorDescriptor::getIdxMemRefOrView(
+ OpBuilder &builder, Location loc, unsigned idxDim) const {
+ auto enc = getSparseTensorEncoding(rType);
unsigned cooStart = getCOOStart(enc);
- if (cooStart < rank) {
- ValueRange beforeFields = buffers.drop_back(3);
- expandedFields.append(beforeFields.begin(), beforeFields.end());
- Value buffer = buffers[buffers.size() - 3];
-
+ unsigned idx = idxDim >= cooStart ? cooStart : idxDim;
+ Value buffer = getMemRefField(SparseTensorFieldKind::IdxMemRef, idx);
+ if (idxDim >= cooStart) {
+ unsigned rank = enc.getDimLevelType().size();
Value stride = constantIndex(builder, loc, rank - cooStart);
- SmallVector<Value> buffersArray(buffers.begin(), buffers.end());
- MutSparseTensorDescriptor mutDesc(tp, buffersArray);
- // Calculate subbuffer size as memSizes[idx] / (stride).
- Value subBufferSize = mutDesc.getIdxMemSize(builder, loc, cooStart);
- subBufferSize = builder.create<arith::DivUIOp>(loc, subBufferSize, stride);
-
- // Create views of the linear idx buffer for the COO indices.
- for (unsigned i = cooStart; i < rank; i++) {
- Value subBuffer = builder.create<memref::SubViewOp>(
- loc, buffer,
- /*offset=*/ValueRange{constantIndex(builder, loc, i - cooStart)},
- /*size=*/ValueRange{subBufferSize},
- /*step=*/ValueRange{stride});
- expandedFields.push_back(subBuffer);
- }
- expandedFields.push_back(buffers[buffers.size() - 2]); // The Values memref.
- expandedFields.push_back(buffers.back()); // The specifier.
- fields = expandedFields;
- } else {
- fields = buffers;
+ Value size = getIdxMemSize(builder, loc, cooStart);
+ size = builder.create<arith::DivUIOp>(loc, size, stride);
+ buffer = builder.create<memref::SubViewOp>(
+ loc, buffer,
+ /*offset=*/ValueRange{constantIndex(builder, loc, idxDim - cooStart)},
+ /*size=*/ValueRange{size},
+ /*step=*/ValueRange{stride});
}
-
- sanityCheck();
+ return buffer;
}
//===----------------------------------------------------------------------===//
@@ -156,8 +139,7 @@ void sparse_tensor::foreachFieldInSparseTensor(
const SparseTensorEncodingAttr enc,
llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
DimLevelType)>
- callback,
- bool isBuffer) {
+ callback) {
assert(enc);
#define RETURN_ON_FALSE(idx, kind, dim, dlt) \
@@ -165,11 +147,13 @@ void sparse_tensor::foreachFieldInSparseTensor(
return;
unsigned rank = enc.getDimLevelType().size();
- unsigned cooStart = isBuffer ? getCOOStart(enc) : rank;
+ unsigned end = getCOOStart(enc);
+ if (end != rank)
+ end += 1;
static_assert(kDataFieldStartingIdx == 0);
unsigned fieldIdx = kDataFieldStartingIdx;
// Per-dimension storage.
- for (unsigned r = 0; r < rank; r++) {
+ for (unsigned r = 0; r < end; r++) {
// Dimension level types apply in order to the reordered dimension.
// As a result, the compound type can be constructed directly in the given
// order.
@@ -178,8 +162,7 @@ void sparse_tensor::foreachFieldInSparseTensor(
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
} else if (isSingletonDLT(dlt)) {
- if (r < cooStart)
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+ RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
} else {
assert(isDenseDLT(dlt)); // no fields
}
@@ -231,38 +214,32 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
};
llvm_unreachable("unrecognized field kind");
- },
- /*isBuffer=*/true);
+ });
}
-unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc,
- bool isBuffer) {
+unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
unsigned numFields = 0;
- foreachFieldInSparseTensor(
- enc,
- [&numFields](unsigned, SparseTensorFieldKind, unsigned,
- DimLevelType) -> bool {
- numFields++;
- return true;
- },
- isBuffer);
+ foreachFieldInSparseTensor(enc,
+ [&numFields](unsigned, SparseTensorFieldKind,
+ unsigned, DimLevelType) -> bool {
+ numFields++;
+ return true;
+ });
return numFields;
}
unsigned
sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
unsigned numFields = 0; // one value memref
- foreachFieldInSparseTensor(
- enc,
- [&numFields](unsigned fidx, SparseTensorFieldKind, unsigned,
- DimLevelType) -> bool {
- if (fidx >= kDataFieldStartingIdx)
- numFields++;
- return true;
- },
- /*isBuffer=*/true);
+ foreachFieldInSparseTensor(enc,
+ [&numFields](unsigned fidx, SparseTensorFieldKind,
+ unsigned, DimLevelType) -> bool {
+ if (fidx >= kDataFieldStartingIdx)
+ numFields++;
+ return true;
+ });
numFields -= 1; // the last field is MetaData field
- assert(numFields == getNumFieldsFromEncoding(enc, /*isBuffer=*/true) -
- kDataFieldStartingIdx - 1);
+ assert(numFields ==
+ getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
return numFields;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 8d25ba2160e44..9ca7149056ddd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -77,8 +77,7 @@ void foreachFieldInSparseTensor(
llvm::function_ref<bool(unsigned /*fieldIdx*/,
SparseTensorFieldKind /*fieldKind*/,
unsigned /*dim (if applicable)*/,
- DimLevelType /*DLT (if applicable)*/)>,
- bool isBuffer = false);
+ DimLevelType /*DLT (if applicable)*/)>);
/// Same as above, except that it also builds the Type for the corresponding
/// field.
@@ -90,7 +89,7 @@ void foreachFieldAndTypeInSparseTensor(
DimLevelType /*DLT (if applicable)*/)>);
/// Gets the total number of fields for the given sparse tensor encoding.
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc, bool isBuffer);
+unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
/// Gets the total number of data fields (index arrays, pointer arrays, and a
/// value array) for the given sparse tensor encoding.
@@ -107,12 +106,7 @@ inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
}
/// Provides methods to access fields of a sparse tensor with the given
-/// encoding. When isBuffer is true, the fields are the actual buffers of the
-/// sparse tensor storage. In particular, when a linear buffer is used to
-/// store the COO data as an array-of-structures, the fields include the
-/// linear buffer (isBuffer=true) or includes the subviews of the buffer for the
-/// indices (isBuffer=false).
-template <bool isBuffer>
+/// encoding.
class StorageLayout {
public:
explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
@@ -132,7 +126,7 @@ class StorageLayout {
}
static unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
- return sparse_tensor::getNumFieldsFromEncoding(enc, isBuffer);
+ return sparse_tensor::getNumFieldsFromEncoding(enc);
}
static void foreachFieldInSparseTensor(
@@ -140,7 +134,7 @@ class StorageLayout {
llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
DimLevelType)>
callback) {
- return sparse_tensor::foreachFieldInSparseTensor(enc, callback, isBuffer);
+ return sparse_tensor::foreachFieldInSparseTensor(enc, callback);
}
std::pair<unsigned, unsigned>
@@ -148,7 +142,7 @@ class StorageLayout {
std::optional<unsigned> dim) const {
unsigned fieldIdx = -1u;
unsigned stride = 1;
- if (isBuffer && kind == SparseTensorFieldKind::IdxMemRef) {
+ if (kind == SparseTensorFieldKind::IdxMemRef) {
assert(dim.has_value());
unsigned cooStart = getCOOStart(enc);
unsigned rank = enc.getDimLevelType().size();
@@ -222,18 +216,11 @@ class SparseTensorDescriptorImpl {
using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
ValueRange>::type;
- SparseTensorDescriptorImpl(Type tp)
- : rType(tp.cast<RankedTensorType>()), fields() {}
-
SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
: rType(tp.cast<RankedTensorType>()), fields(fields) {
- sanityCheck();
- }
-
- void sanityCheck() {
- assert(getSparseTensorEncoding(rType) &&
- StorageLayout<mut>::getNumFieldsFromEncoding(
- getSparseTensorEncoding(rType)) == fields.size());
+ assert(getSparseTensorEncoding(tp) &&
+ getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
+ fields.size());
// We should make sure the class is trivially copyable (and should be small
// enough) such that we can pass it by value.
static_assert(
@@ -244,22 +231,10 @@ class SparseTensorDescriptorImpl {
unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
std::optional<unsigned> dim) const {
// Delegates to storage layout.
- StorageLayout<mut> layout(getSparseTensorEncoding(rType));
+ StorageLayout layout(getSparseTensorEncoding(rType));
return layout.getMemRefFieldIndex(kind, dim);
}
- unsigned getPtrMemRefIndex(unsigned ptrDim) const {
- return getMemRefFieldIndex(SparseTensorFieldKind::PtrMemRef, ptrDim);
- }
-
- unsigned getIdxMemRefIndex(unsigned idxDim) const {
- return getMemRefFieldIndex(SparseTensorFieldKind::IdxMemRef, idxDim);
- }
-
- unsigned getValMemRefIndex() const {
- return getMemRefFieldIndex(SparseTensorFieldKind::ValMemRef, std::nullopt);
- }
-
unsigned getNumFields() const { return fields.size(); }
///
@@ -281,10 +256,6 @@ class SparseTensorDescriptorImpl {
return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
}
- Value getIdxMemRef(unsigned idxDim) const {
- return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxDim);
- }
-
Value getValMemRef() const {
return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
}
@@ -299,15 +270,19 @@ class SparseTensorDescriptorImpl {
return fields[fidx];
}
- Value getField(unsigned fidx) const {
- assert(fidx < fields.size());
- return fields[fidx];
+ Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
+ dim);
}
- ValueRange getMemRefFields() const {
- ValueRange ret = fields;
- // Drop the last metadata fields.
- return ret.slice(0, fields.size() - 1);
+ Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
+ dim);
+ }
+
+ Value getValMemSize(OpBuilder &builder, Location loc) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+ std::nullopt);
}
Type getMemRefElementType(SparseTensorFieldKind kind,
@@ -331,23 +306,15 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
: SparseTensorDescriptorImpl<true>(tp, buffers) {}
- ///
- /// Getters: get the value for required field.
- ///
-
- Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
- return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
- dim);
- }
-
- Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
- return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
- dim);
+ Value getField(unsigned fidx) const {
+ assert(fidx < fields.size());
+ return fields[fidx];
}
- Value getValMemSize(OpBuilder &builder, Location loc) const {
- return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
- std::nullopt);
+ ValueRange getMemRefFields() const {
+ ValueRange ret = fields;
+ // Drop the last metadata fields.
+ return ret.slice(0, fields.size() - 1);
}
///
@@ -384,7 +351,7 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
std::pair<unsigned, unsigned>
getIdxMemRefIndexAndStride(unsigned idxDim) const {
- StorageLayout<true> layout(getSparseTensorEncoding(rType));
+ StorageLayout layout(getSparseTensorEncoding(rType));
return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
idxDim);
}
@@ -393,19 +360,17 @@ class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
auto enc = getSparseTensorEncoding(rType);
unsigned cooStart = getCOOStart(enc);
assert(cooStart < enc.getDimLevelType().size());
- return getIdxMemRef(cooStart);
+ return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
}
};
class SparseTensorDescriptor : public SparseTensorDescriptorImpl<false> {
public:
- SparseTensorDescriptor(OpBuilder &builder, Location loc, Type tp,
- ValueArrayRef buffers);
+ SparseTensorDescriptor(Type tp, ValueArrayRef buffers)
+ : SparseTensorDescriptorImpl<false>(tp, buffers) {}
-private:
- // Store the fields passed to SparseTensorDescriptorImpl when the tensor has
- // a COO region.
- SmallVector<Value> expandedFields;
+ Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
+ unsigned idxDim) const;
};
/// Returns the "tuple" value of the adapted tensor.
@@ -425,11 +390,9 @@ inline Value genTuple(OpBuilder &builder, Location loc,
return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
}
-inline SparseTensorDescriptor
-getDescriptorFromTensorTuple(OpBuilder &builder, Location loc, Value tensor) {
+inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
auto tuple = getTuple(tensor);
- return SparseTensorDescriptor(builder, loc, tuple.getResultTypes()[0],
- tuple.getInputs());
+ return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
}
inline MutSparseTensorDescriptor
More information about the Mlir-commits
mailing list