[Mlir-commits] [mlir] f5ce99a - [mlir][sparse] Factoring out NewCallParams class in SparseTensorConversion.cpp
wren romano
llvmlistbot at llvm.org
Tue Nov 8 17:20:02 PST 2022
Author: wren romano
Date: 2022-11-08T17:19:54-08:00
New Revision: f5ce99afa72fd74d57a3b9fba658f48626b3aef5
URL: https://github.com/llvm/llvm-project/commit/f5ce99afa72fd74d57a3b9fba658f48626b3aef5
DIFF: https://github.com/llvm/llvm-project/commit/f5ce99afa72fd74d57a3b9fba658f48626b3aef5.diff
LOG: [mlir][sparse] Factoring out NewCallParams class in SparseTensorConversion.cpp
The new class helps encapsulate the arguments to `_mlir_ciface_newSparseTensor` so that client code doesn't depend on the details of the API. (This makes way for the next differential which significantly alters the API.)
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D137680
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index f41a5798e18d..6ca6cfc88ef8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -70,16 +70,6 @@ static Value genDimSizeCall(OpBuilder &builder, Location loc,
.getResult(0);
}
-/// Generates a call into the "swiss army knife" method of the sparse runtime
-/// support library for materializing sparse tensors into the computation.
-static Value genNewCall(OpBuilder &builder, Location loc,
- ArrayRef<Value> params) {
- StringRef name = "newSparseTensor";
- Type pTp = getOpaquePointerType(builder);
- return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
- .getResult(0);
-}
-
/// Compute the size from type (for static sizes) or from an already-converted
/// opaque pointer source (for dynamic sizes) at the given dimension.
static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
@@ -168,41 +158,132 @@ static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
return buffer;
}
-/// Populates parameters required to call the "swiss army knife" method of the
-/// sparse runtime support library for materializing sparse tensors into the
-/// computation.
-static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
- Location loc, ShapedType stp,
- SparseTensorEncodingAttr &enc, Action action,
- ValueRange szs, Value ptr = Value()) {
- ArrayRef<DimLevelType> dlt = enc.getDimLevelType();
- unsigned sz = dlt.size();
+/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
+/// the "swiss army knife" method of the sparse runtime support library
+/// for materializing sparse tensors into the computation. This abstraction
+/// reduces the need to make modifications to client code whenever that
+/// API changes.
+class NewCallParams final {
+public:
+ /// Allocates the `ValueRange` for the `func::CallOp` parameters,
+ /// but does not initialize them.
+ NewCallParams(OpBuilder &builder, Location loc)
+ : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}
+
+ /// Initializes all static parameters (i.e., those which indicate
+ /// type-level information such as the encoding and sizes), generating
+ /// MLIR buffers as needed, and returning `this` for method chaining.
+ /// This method does not set the action and pointer arguments, since
+ /// those are handled by `genNewCall` instead.
+ NewCallParams &genBuffers(SparseTensorEncodingAttr enc, ValueRange sizes,
+ ShapedType stp);
+
+ /// (Re)sets the C++ template type parameters, and returns `this`
+ /// for method chaining. This is already done as part of `genBuffers`,
+ /// but is factored out so that it can also be called independently
+ /// whenever subsequent `genNewCall` calls want to reuse the same
+ /// buffers but different type parameters.
+ //
+ // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`;
+ // is there a better way to handle that than this one-off setter method?
+ NewCallParams &setTemplateTypes(SparseTensorEncodingAttr enc,
+ ShapedType stp) {
+ params[kParamPtrTp] = constantPointerTypeEncoding(builder, loc, enc);
+ params[kParamIndTp] = constantIndexTypeEncoding(builder, loc, enc);
+ params[kParamValTp] =
+ constantPrimaryTypeEncoding(builder, loc, stp.getElementType());
+ return *this;
+ }
+
+ /// Checks whether all the static parameters have been initialized.
+ bool isInitialized() const {
+ for (unsigned i = 0; i < kNumStaticParams; ++i)
+ if (!params[i])
+ return false;
+ return true;
+ }
+
+ /// Gets the dimension-to-level mapping.
+ //
+ // TODO: This is only ever used for passing into `genAddEltCall`;
+ // is there a better way to encapsulate that pattern (both to avoid
+ // this one-off getter, and to avoid potential mixups)?
+ Value getDim2LvlMap() const {
+ assert(isInitialized() && "Must initialize before getDim2LvlMap");
+ return params[kParamDim2Lvl];
+ }
+
+ /// Generates a function call, with the current static parameters
+ /// and the given dynamic arguments.
+ Value genNewCall(Action action, Value ptr = Value()) {
+ assert(isInitialized() && "Must initialize before genNewCall");
+ StringRef name = "newSparseTensor";
+ params[kParamAction] = constantAction(builder, loc, action);
+ params[kParamPtr] = ptr ? ptr : builder.create<LLVM::NullOp>(loc, pTp);
+ return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
+ .getResult(0);
+ }
+
+private:
+ static constexpr unsigned kNumStaticParams = 6;
+ static constexpr unsigned kNumDynamicParams = 2;
+ static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
+ static constexpr unsigned kParamLvlTypes = 0;
+ static constexpr unsigned kParamDimSizes = 1;
+ static constexpr unsigned kParamDim2Lvl = 2;
+ static constexpr unsigned kParamPtrTp = 3;
+ static constexpr unsigned kParamIndTp = 4;
+ static constexpr unsigned kParamValTp = 5;
+ static constexpr unsigned kParamAction = 6;
+ static constexpr unsigned kParamPtr = 7;
+
+ OpBuilder &builder;
+ Location loc;
+ Type pTp;
+ Value params[kNumParams];
+};
+
+// TODO: see the note at `_mlir_ciface_newSparseTensor` about how
+// the meaning of the various arguments (e.g., "sizes" vs "shapes")
+// is inconsistent between the different actions.
+NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
+ ValueRange dimSizes, ShapedType stp) {
+ const unsigned lvlRank = enc.getDimLevelType().size();
+ const unsigned dimRank = stp.getRank();
// Sparsity annotations.
- SmallVector<Value, 4> attrs;
- for (unsigned i = 0; i < sz; i++)
- attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
- params.push_back(genBuffer(builder, loc, attrs));
- // Dimension sizes array of the enveloping tensor. Useful for either
+ SmallVector<Value, 4> lvlTypes;
+ for (auto dlt : enc.getDimLevelType())
+ lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
+ assert(lvlTypes.size() == lvlRank && "Level-rank mismatch");
+ params[kParamLvlTypes] = genBuffer(builder, loc, lvlTypes);
+ // Dimension-sizes array of the enveloping tensor. Useful for either
// verification of external data, or for construction of internal data.
- params.push_back(genBuffer(builder, loc, szs));
- // Dimension order permutation array. This is the "identity" permutation by
- // default, or otherwise the "reverse" permutation of a given ordering, so
- // that indices can be mapped quickly to the right position.
- SmallVector<Value, 4> rev(sz);
- for (unsigned i = 0; i < sz; i++)
- rev[toOrigDim(enc, i)] = constantIndex(builder, loc, i);
- params.push_back(genBuffer(builder, loc, rev));
+ assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
+ params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
+ // The dimension-to-level mapping. We must preinitialize `dim2lvl`
+ // so that the true branch below can perform random-access `operator[]`
+ // assignment.
+ SmallVector<Value, 4> dim2lvl(dimRank);
+ auto dimOrder = enc.getDimOrdering();
+ if (dimOrder) {
+ assert(dimOrder.isPermutation());
+ for (unsigned l = 0; l < lvlRank; l++) {
+ // The `d`th source variable occurs in the `l`th result position.
+ uint64_t d = dimOrder.getDimPosition(l);
+ dim2lvl[d] = constantIndex(builder, loc, l);
+ }
+ } else {
+ assert(dimRank == lvlRank && "Rank mismatch");
+ for (unsigned i = 0; i < lvlRank; i++)
+ dim2lvl[i] = constantIndex(builder, loc, i);
+ }
+ params[kParamDim2Lvl] = genBuffer(builder, loc, dim2lvl);
// Secondary and primary types encoding.
- Type elemTp = stp.getElementType();
- params.push_back(constantPointerTypeEncoding(builder, loc, enc));
- params.push_back(constantIndexTypeEncoding(builder, loc, enc));
- params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
- // User action.
- params.push_back(constantAction(builder, loc, action));
- // Payload pointer.
- if (!ptr)
- ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
- params.push_back(ptr);
+ setTemplateTypes(enc, stp);
+ // Finally, make note that initialization is complete.
+ assert(isInitialized() && "Initialization failed");
+ // And return `this` for method chaining.
+ return *this;
}
/// Generates a call to obtain the values array.
@@ -387,14 +468,12 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
SmallVector<Value, 4> srcSizes;
- SmallVector<Value, 8> params;
sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc());
- newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, srcSizes,
- adaptor.getSrc());
- Value iter = genNewCall(rewriter, loc, params);
+ NewCallParams params(rewriter, loc);
+ Value iter = params.genBuffers(noPerm, srcSizes, srcTp)
+ .genNewCall(Action::kToIterator, adaptor.getSrc());
// Start a new COO for the destination tensor.
SmallVector<Value, 4> dstSizes;
- params.clear();
if (dstTp.hasStaticShape()) {
sizesFromType(rewriter, dstSizes, loc, dstTp);
} else {
@@ -402,9 +481,9 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
op.getReassociationIndices());
}
- newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, dstSizes);
- Value coo = genNewCall(rewriter, loc, params);
- Value dstPerm = params[2];
+ Value coo =
+ params.genBuffers(encDst, dstSizes, dstTp).genNewCall(Action::kEmptyCOO);
+ Value dstPerm = params.getDim2LvlMap();
// Construct a while loop over the iterator.
Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
@@ -426,9 +505,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
rewriter.create<scf::YieldOp>(loc);
// Final call to construct sparse tensor storage and free temporary resources.
rewriter.setInsertionPointAfter(whileOp);
- params[6] = constantAction(rewriter, loc, Action::kFromCOO);
- params[7] = coo;
- Value dst = genNewCall(rewriter, loc, params);
+ Value dst = params.genNewCall(Action::kFromCOO, coo);
genDelCOOCall(rewriter, loc, elemTp, coo);
genDelIteratorCall(rewriter, loc, elemTp, iter);
rewriter.replaceOp(op, dst);
@@ -458,11 +535,10 @@ static void genSparseCOOIterationLoop(
rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
enc.getPointerBitWidth(), enc.getIndexBitWidth());
SmallVector<Value, 4> sizes;
- SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, loc, noPerm, tensorTp, t);
- newParams(rewriter, params, loc, tensorTp, noPerm, Action::kToIterator, sizes,
- t);
- Value iter = genNewCall(rewriter, loc, params);
+ Value iter = NewCallParams(rewriter, loc)
+ .genBuffers(noPerm, sizes, tensorTp)
+ .genNewCall(Action::kToIterator, t);
// Construct a while loop over the iterator.
Value srcIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
@@ -611,12 +687,12 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
// Generate the call to construct tensor from ptr. The sizes are
// inferred from the result type of the new operator.
SmallVector<Value, 4> sizes;
- SmallVector<Value, 8> params;
ShapedType stp = resType.cast<ShapedType>();
sizesFromType(rewriter, sizes, loc, stp);
Value ptr = adaptor.getOperands()[0];
- newParams(rewriter, params, loc, stp, enc, Action::kFromFile, sizes, ptr);
- rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
+ rewriter.replaceOp(op, NewCallParams(rewriter, loc)
+ .genBuffers(enc, sizes, stp)
+ .genNewCall(Action::kFromFile, ptr));
return success();
}
};
@@ -650,10 +726,10 @@ class SparseTensorAllocConverter
}
// Generate the call to construct empty tensor. The sizes are
// explicitly defined by the arguments to the alloc operator.
- SmallVector<Value, 8> params;
- ShapedType stp = resType.cast<ShapedType>();
- newParams(rewriter, params, loc, stp, enc, Action::kEmpty, sizes);
- rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
+ rewriter.replaceOp(op,
+ NewCallParams(rewriter, loc)
+ .genBuffers(enc, sizes, resType.cast<ShapedType>())
+ .genNewCall(Action::kEmpty));
return success();
}
};
@@ -690,7 +766,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
return success();
}
SmallVector<Value, 4> sizes;
- SmallVector<Value, 8> params;
+ NewCallParams params(rewriter, loc);
ShapedType stp = srcType.cast<ShapedType>();
sizesFromPtr(rewriter, sizes, loc, encSrc, stp, src);
bool useDirectConversion;
@@ -708,9 +784,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
break;
}
if (useDirectConversion) {
- newParams(rewriter, params, loc, stp, encDst, Action::kSparseToSparse,
- sizes, src);
- rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
+ rewriter.replaceOp(op, params.genBuffers(encDst, sizes, stp)
+ .genNewCall(Action::kSparseToSparse, src));
} else { // use via-COO conversion.
// Set up encoding with right mix of src and dst so that the two
// method calls can share most parameters, while still providing
@@ -719,13 +794,13 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
encDst.getHigherOrdering(), encSrc.getPointerBitWidth(),
encSrc.getIndexBitWidth());
- newParams(rewriter, params, loc, stp, enc, Action::kToCOO, sizes, src);
- Value coo = genNewCall(rewriter, loc, params);
- params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
- params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
- params[6] = constantAction(rewriter, loc, Action::kFromCOO);
- params[7] = coo;
- Value dst = genNewCall(rewriter, loc, params);
+ // TODO: This is the only place where `kToCOO` (or `kToIterator`)
+ // is called with a non-identity permutation. Is there any clean
+ // way to push the permutation over to the `kFromCOO` side instead?
+ Value coo =
+ params.genBuffers(enc, sizes, stp).genNewCall(Action::kToCOO, src);
+ Value dst = params.setTemplateTypes(encDst, stp)
+ .genNewCall(Action::kFromCOO, coo);
genDelCOOCall(rewriter, loc, stp.getElementType(), coo);
rewriter.replaceOp(op, dst);
}
@@ -743,7 +818,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
unsigned rank = dstTensorTp.getRank();
Type elemTp = dstTensorTp.getElementType();
- // Fabricate a no-permutation encoding for newParams().
+ // Fabricate a no-permutation encoding for NewCallParams
// The pointer/index types must be those of `src`.
// The dimLevelTypes aren't actually used by Action::kToIterator.
encDst = SparseTensorEncodingAttr::get(
@@ -751,11 +826,10 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
SmallVector<Value, 4> sizes;
- SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src);
- newParams(rewriter, params, loc, dstTensorTp, encDst, Action::kToIterator,
- sizes, src);
- Value iter = genNewCall(rewriter, loc, params);
+ Value iter = NewCallParams(rewriter, loc)
+ .genBuffers(encDst, sizes, dstTensorTp)
+ .genNewCall(Action::kToIterator, src);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
Block *insertionBlock = rewriter.getInsertionBlock();
@@ -817,12 +891,12 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
ShapedType stp = resType.cast<ShapedType>();
unsigned rank = stp.getRank();
SmallVector<Value, 4> sizes;
- SmallVector<Value, 8> params;
sizesFromSrc(rewriter, sizes, loc, src);
- newParams(rewriter, params, loc, stp, encDst, Action::kEmptyCOO, sizes);
- Value coo = genNewCall(rewriter, loc, params);
+ NewCallParams params(rewriter, loc);
+ Value coo =
+ params.genBuffers(encDst, sizes, stp).genNewCall(Action::kEmptyCOO);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
- Value perm = params[2];
+ Value perm = params.getDim2LvlMap();
Type eltType = stp.getElementType();
Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
genDenseTensorOrSparseConstantIterLoop(
@@ -836,9 +910,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
genAddEltCall(builder, loc, eltType, coo, elemPtr, ind, perm);
});
// Final call to construct sparse tensor storage.
- params[6] = constantAction(rewriter, loc, Action::kFromCOO);
- params[7] = coo;
- Value dst = genNewCall(rewriter, loc, params);
+ Value dst = params.genNewCall(Action::kFromCOO, coo);
genDelCOOCall(rewriter, loc, eltType, coo);
rewriter.replaceOp(op, dst);
return success();
@@ -1117,15 +1189,15 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
Value offset = constantIndex(rewriter, loc, 0);
SmallVector<Value, 4> sizes;
- SmallVector<Value, 8> params;
+ NewCallParams params(rewriter, loc);
concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(),
concatDim);
if (encDst) {
// Start a new COO for the destination tensor.
- newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
- dst = genNewCall(rewriter, loc, params);
- dstPerm = params[2];
+ dst =
+ params.genBuffers(encDst, sizes, dstTp).genNewCall(Action::kEmptyCOO);
+ dstPerm = params.getDim2LvlMap();
elemPtr = genAllocaScalar(rewriter, loc, elemTp);
dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
} else {
@@ -1188,11 +1260,9 @@ class SparseTensorConcatConverter : public OpConversionPattern<ConcatenateOp> {
offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
}
if (encDst) {
- params[6] = constantAction(rewriter, loc, Action::kFromCOO);
// In sparse output case, the destination holds the COO.
Value coo = dst;
- params[7] = coo;
- dst = genNewCall(rewriter, loc, params);
+ dst = params.genNewCall(Action::kFromCOO, coo);
// Release resources.
genDelCOOCall(rewriter, loc, elemTp, coo);
rewriter.replaceOp(op, dst);
@@ -1216,27 +1286,25 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
Value src = adaptor.getOperands()[0];
auto encSrc = getSparseTensorEncoding(srcType);
SmallVector<Value, 4> sizes;
- SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src);
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
- newParams(rewriter, params, loc, srcType, enc, Action::kToCOO, sizes, src);
- Value coo = genNewCall(rewriter, loc, params);
+ Value coo = NewCallParams(rewriter, loc)
+ .genBuffers(enc, sizes, srcType)
+ .genNewCall(Action::kToCOO, src);
// Then output the tensor to external file with indices in the externally
// visible lexicographic index order. A sort is required if the source was
// not in that order yet (note that the sort can be dropped altogether if
// external format does not care about the order at all, but here we assume
// it does).
- bool sort =
- encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
- params.clear();
- params.push_back(coo);
- params.push_back(adaptor.getOperands()[1]);
- params.push_back(constantI1(rewriter, loc, sort));
+ Value sort = constantI1(rewriter, loc,
+ encSrc.getDimOrdering() &&
+ !encSrc.getDimOrdering().isIdentity());
+ SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
Type eltType = srcType.getElementType();
SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
- createFuncCall(rewriter, loc, name, {}, params, EmitCInterface::Off);
+ createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
genDelCOOCall(rewriter, loc, eltType, coo);
rewriter.eraseOp(op);
return success();
More information about the Mlir-commits
mailing list