[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)
Aart Bik
llvmlistbot at llvm.org
Tue Nov 7 09:37:22 PST 2023
================
@@ -73,128 +353,57 @@ static bool hasNonIdentityOperandsOrResults(Operation *op) {
llvm::any_of(op->getResults(), hasNonIdentityMap);
}
-// Generates a clone of the given linalg generic operation, but with
-// remapped arguments, index maps, and iteration types.
-//
-// TODO: As described below, this is proof-of-concept code which makes a lot
-// of simplifying assumptions for now.
-//
-static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
- linalg::GenericOp linalgOp,
- SparseTensorType stt, Value out) {
- unsigned dimRank = stt.getDimRank();
- unsigned lvlRank = stt.getLvlRank();
- SmallVector<Value> inputOps = linalgOp.getInputs();
- SmallVector<Value> outputOps = {out};
- SmallVector<AffineMap> indexMaps;
- SmallVector<utils::IteratorType> iterTypes;
- // Translate the index maps, except output map, which is lvl-identity.
- auto maps = linalgOp.getIndexingMapsArray();
- for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
- indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
- indexMaps.push_back(
- AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
- // Add additional "parallel" iteration types at the top.
- for (unsigned i = 0, diff = lvlRank - dimRank; i < diff; i++)
- iterTypes.push_back(utils::IteratorType::parallel);
- for (auto &i : linalgOp.getIteratorTypesArray())
- iterTypes.push_back(i);
- // Generate the new linalg generic operation and clone body.
- auto newOp = rewriter.create<linalg::GenericOp>(
- linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
- iterTypes);
- rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
- newOp.getRegion().begin());
- return newOp;
-}
-
namespace {
//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//
/// Sparse rewriting rule for the generic `linalg` operation.
-struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+struct GenericOpReinterpretMap
+ : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
- GenericOpReinterpretMap(MLIRContext *context)
- : OpRewritePattern<linalg::GenericOp>(context) {}
+ using DemapInsRewriter::DemapInsRewriter;
+ LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
+ // Only rewrite single output operations with pure (sparse) tensor
+ // semantics.
+ if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+ !hasAnySparseOperandOrResult(linalgOp) ||
+ !hasNonIdentityOperandsOrResults(linalgOp))
+ return failure();
- LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
- PatternRewriter &rewriter) const override {
- // Only rewrite single output operations with pure tensor semantics.
- if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
+ // Try translating the index map.
+ if (failed(translateMap(linalgOp, rewriter)))
return failure();
- // Scan all operands, inspect sparse tensors.
- //
- // TODO: generalize this proof-of-concept algorithm, since the current
- // implementation accepts only simple indexing maps, and one
- // non-permutation sparse tensor, which must have an identity
- // indexing map and be the output.
- //
- OpOperand *tx = nullptr;
- for (OpOperand &t : linalgOp->getOpOperands()) {
- // Ensure every index map is "simple".
- const auto map = linalgOp.getMatchingIndexingMap(&t);
- for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
- if (map.getResult(i).getKind() != AffineExprKind::DimId)
- return failure();
- // Inspect sparse operands.
- auto stt = tryGetSparseTensorType(t.get());
- if (stt && stt->hasEncoding()) {
- if (stt->isPermutation())
- continue;
- assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
- if (tx)
- return failure(); // more than one non-perm
- if (!map.isIdentity())
- return failure(); // no ID indexing map on the non-perm
- tx = &t;
- }
- }
- // Found a non-permutation, rewrite when this is the output.
- if (tx && tx == linalgOp.getDpsInitOperand(0)) {
- auto stt = getSparseTensorType(tx->get());
- auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
- auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
- auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
- rewriter.replaceOp(linalgOp, remap);
- return success();
+
+ // Must only have one result
----------------
aartbik wrote:
A "must" comment is never very useful.Is it assumed, or it is enforced.
I would stick with descriptive comment, like
// On success, update the linalg op fields, including the single result, in place.
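To make the suggestion concrete, a minimal sketch of how such a descriptive comment could read in place; the names (linalgOp, rewriter) follow the diff above, while the body itself is an assumption rather than the PR's actual code:

  // On success, update the linalg op fields, including the single result,
  // in place (sketch only; the real patch decides exactly what to update).
  rewriter.updateRootInPlace(linalgOp, [&] {
    // E.g., install the demapped operands plus the translated indexing
    // maps and iterator types produced by translateMap().
  });
  return success();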
https://github.com/llvm/llvm-project/pull/71448
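As background for the translateMap() call in the diff above, the standalone sketch below (hypothetical, not code from this PR) shows how a matmul output access map lifts from dim space to level space under a 2x2 BSR dimToLvl map, using standard AffineMap utilities; the real pattern additionally rewrites the kernel's loops to iterate over levels.

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/MLIRContext.h"

  using namespace mlir;

  int main() {
    MLIRContext ctx;
    // BSR with 2x2 blocks: dims (i, j) map to levels
    // (i floordiv 2, j floordiv 2, i mod 2, j mod 2).
    AffineExpr i = getAffineDimExpr(0, &ctx);
    AffineExpr j = getAffineDimExpr(1, &ctx);
    AffineMap dimToLvl =
        AffineMap::get(2, 0, {i.floorDiv(2), j.floorDiv(2), i % 2, j % 2}, &ctx);
    // Matmul output access over loops (d0, d1, d2): (d0, d1, d2) -> (d0, d1).
    AffineMap outMap = AffineMap::get(
        3, 0, {getAffineDimExpr(0, &ctx), getAffineDimExpr(1, &ctx)}, &ctx);
    // Level-space view of the same access:
    // (d0, d1, d2) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2).
    AffineMap lvlMap = dimToLvl.compose(outMap);
    lvlMap.dump();
    return 0;
  }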
More information about the Mlir-commits mailing list