[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)

Aart Bik llvmlistbot at llvm.org
Tue Nov 7 09:37:22 PST 2023


================
@@ -73,128 +353,57 @@ static bool hasNonIdentityOperandsOrResults(Operation *op) {
          llvm::any_of(op->getResults(), hasNonIdentityMap);
 }
 
-// Generates a clone of the given linalg generic operation, but with
-// remapped arguments, index maps, and iteration types.
-//
-// TODO: As decribed below, this is proof-of-concept code which makes a lot
-//       of simplifying assumptions for now.
-//
-static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
-                                          linalg::GenericOp linalgOp,
-                                          SparseTensorType stt, Value out) {
-  unsigned dimRank = stt.getDimRank();
-  unsigned lvlRank = stt.getLvlRank();
-  SmallVector<Value> inputOps = linalgOp.getInputs();
-  SmallVector<Value> outputOps = {out};
-  SmallVector<AffineMap> indexMaps;
-  SmallVector<utils::IteratorType> iterTypes;
-  // Translate the index maps, except output map, which is lvl-identity.
-  auto maps = linalgOp.getIndexingMapsArray();
-  for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
-    indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
-  indexMaps.push_back(
-      AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
-  // Add additional "parallel" iteration types at the top.
-  for (unsigned i = 0, diff = lvlRank = dimRank; i < diff; i++)
-    iterTypes.push_back(utils::IteratorType::parallel);
-  for (auto &i : linalgOp.getIteratorTypesArray())
-    iterTypes.push_back(i);
-  // Generate the new linalg generic operation and clone body.
-  auto newOp = rewriter.create<linalg::GenericOp>(
-      linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
-      iterTypes);
-  rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
-                             newOp.getRegion().begin());
-  return newOp;
-}
-
 namespace {
 
 //===----------------------------------------------------------------------===//
 // Rewriting rules for linalg generic ops.
 //===----------------------------------------------------------------------===//
 
 /// Sparse rewriting rule for the generic `linalg` operation.
-struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+struct GenericOpReinterpretMap
+    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
 public:
-  GenericOpReinterpretMap(MLIRContext *context)
-      : OpRewritePattern<linalg::GenericOp>(context) {}
+  using DemapInsRewriter::DemapInsRewriter;
+  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    // Only rewrite single output operations with pure (sparse) tensor
+    // semantics.
+    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+        !hasAnySparseOperandOrResult(linalgOp) ||
+        !hasNonIdentityOperandsOrResults(linalgOp))
+      return failure();
 
-  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
-                                PatternRewriter &rewriter) const override {
-    // Only rewrite single output operations with pure tensor semantics.
-    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
+    // Try translating the index map.
+    if (failed(translateMap(linalgOp, rewriter)))
       return failure();
-    // Scan all operands, inspect sparse tensors.
-    //
-    // TODO: generalize this proof-of-concept algorithm, since the current
-    //       implementation accepts only simple indexing maps, and one
-    //       non-permutation sparse tensor, which must have an identity
-    //       indexing map and be the output.
-    //
-    OpOperand *tx = nullptr;
-    for (OpOperand &t : linalgOp->getOpOperands()) {
-      // Ensure every index map is "simple".
-      const auto map = linalgOp.getMatchingIndexingMap(&t);
-      for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
-        if (map.getResult(i).getKind() != AffineExprKind::DimId)
-          return failure();
-      // Inspect sparse operands.
-      auto stt = tryGetSparseTensorType(t.get());
-      if (stt && stt->hasEncoding()) {
-        if (stt->isPermutation())
-          continue;
-        assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
-        if (tx)
-          return failure(); // more than one non-perm
-        if (!map.isIdentity())
-          return failure(); // no ID indexing map on the non-perm
-        tx = &t;
-      }
-    }
-    // Found a non-permutation, rewrite when this is the output.
-    if (tx && tx == linalgOp.getDpsInitOperand(0)) {
-      auto stt = getSparseTensorType(tx->get());
-      auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
-      auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
-      auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
-      rewriter.replaceOp(linalgOp, remap);
-      return success();
+
+    // Must only have one result
----------------
aartbik wrote:

A "must" comment is never very useful. Is it assumed, or is it enforced?
I would stick with a descriptive comment, like

// On success, update the linalg op fields, including the single result, in place.

https://github.com/llvm/llvm-project/pull/71448


More information about the Mlir-commits mailing list