[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)

Yinying Li llvmlistbot at llvm.org
Tue Nov 7 09:02:04 PST 2023


================
@@ -73,128 +353,57 @@ static bool hasNonIdentityOperandsOrResults(Operation *op) {
          llvm::any_of(op->getResults(), hasNonIdentityMap);
 }
 
-// Generates a clone of the given linalg generic operation, but with
-// remapped arguments, index maps, and iteration types.
-//
-// All input index maps are translated through `translateMap` for the sparse
-// tensor type `stt`. The output map is the identity map over `lvlRank`
-// loops. One extra "parallel" iterator is prepended per level beyond the
-// dimension rank (`lvlRank - dimRank` of them), followed by the original
-// iterator types. The body region of `linalgOp` is cloned into the new op,
-// whose single output operand is `out`.
-//
-// TODO: As described below, this is proof-of-concept code which makes a lot
-//       of simplifying assumptions for now.
-//
-static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
-                                          linalg::GenericOp linalgOp,
-                                          SparseTensorType stt, Value out) {
-  const unsigned dimRank = stt.getDimRank();
-  const unsigned lvlRank = stt.getLvlRank();
-  // The number of extra iterators is computed as lvlRank - dimRank below;
-  // make sure that unsigned subtraction cannot wrap around.
-  assert(lvlRank >= dimRank && "expected at least as many levels as dims");
-  SmallVector<Value> inputOps = linalgOp.getInputs();
-  SmallVector<Value> outputOps = {out};
-  SmallVector<AffineMap> indexMaps;
-  SmallVector<utils::IteratorType> iterTypes;
-  // Translate the index maps, except output map, which is lvl-identity.
-  auto maps = linalgOp.getIndexingMapsArray();
-  for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
-    indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
-  indexMaps.push_back(
-      AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
-  // Add additional "parallel" iteration types at the top: one for every
-  // level in excess of the dimension rank, so that the iterator count
-  // matches the lvlRank-dimensional identity output map above.
-  iterTypes.append(lvlRank - dimRank, utils::IteratorType::parallel);
-  for (auto &i : linalgOp.getIteratorTypesArray())
-    iterTypes.push_back(i);
-  // Generate the new linalg generic operation and clone body.
-  auto newOp = rewriter.create<linalg::GenericOp>(
-      linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
-      iterTypes);
-  rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
-                             newOp.getRegion().begin());
-  return newOp;
-}
-
 namespace {
 
 //===----------------------------------------------------------------------===//
 // Rewriting rules for linalg generic ops.
 //===----------------------------------------------------------------------===//
 
 /// Sparse rewriting rule for the generic `linalg` operation.
-struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+struct GenericOpReinterpretMap
+    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
 public:
-  GenericOpReinterpretMap(MLIRContext *context)
-      : OpRewritePattern<linalg::GenericOp>(context) {}
+  using DemapInsRewriter::DemapInsRewriter;
+  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    // Only rewrite single output operations with pure (sparse) tensor
----------------
yinying-lisa-li wrote:

rewrites (presumably a wording nit on the comment above, i.e. "Only rewrites single output operations ...")

https://github.com/llvm/llvm-project/pull/71448


More information about the Mlir-commits mailing list