[Mlir-commits] [mlir] [mlir][sparse] Implement rewriters to reinterpret maps on foreach (PR #70868)

Yinying Li llvmlistbot at llvm.org
Wed Nov 1 11:30:20 PDT 2023


================
@@ -166,61 +178,125 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
 
     // CRTP call.
     OpAdaptor adaptor(deMappedIns);
-    ValueRange outs =
-        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
-    assert(outs.size() == op->getResults().size());
-
-    // Remap outputs.
-    SmallVector<Value> reMappedOuts(outs);
-    for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
-      if (r.getType() != a.getType())
-        r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
-
-    rewriter.replaceOp(op, reMappedOuts);
-    return success();
+    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+                                                          rewriter);
   }
 };
 
-struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(CrdTranslateOp op,
-                                PatternRewriter &rewriter) const override {
-    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
-                        ? op.getEncoder().getDimToLvl()
-                        : op.getEncoder().getLvlToDim();
-
-    SmallVector<Value> outCrds;
-    for (AffineExpr result : map.getResults()) {
-      // TODO: we should probably expand the affine map to IR using our own
-      // rules, since affine.apply assumes signed values, while the coordinates
-      // we provide must always be signless.
-      Value trans = rewriter.create<affine::AffineApplyOp>(
-          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
-          op.getInCrds());
-      outCrds.push_back(trans);
-    }
-    rewriter.replaceOp(op, outCrds);
-    return success();
-  }
-};
+//===----------------------------------------------------------------------===//
+// Reinterpret Map Rewriters for operations other than linalg.generics
+//===----------------------------------------------------------------------===//
 
-struct TensorInsertRewriter
-    : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
+struct TensorInsertDemapper
+    : public DemapInsRemapOutsRewriter<TensorInsertDemapper, tensor::InsertOp> {
   using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    if (!hasAnySparseResult(op))
+      return failure();
 
-  bool matchOp(tensor::InsertOp op) const {
-    return op.getResult().getType().getEncoding() != nullptr;
-  }
-
-  ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
-                       PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getResult());
     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                           CrdTransDirectionKind::dim2lvl);
     Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
-    return insertOp->getResults();
+
+    SmallVector<Value> outs(insertOp->getResults());
+    remapValueRange(rewriter, op->getResultTypes(), outs);
+    rewriter.replaceOp(op, outs);
+    return success();
+  }
+};
+
+struct ForeachOpDemapper
+    : public DemapInsRemapOutsRewriter<ForeachOpDemapper, ForeachOp> {
+  using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    // Only handles operations with sparse input/output.
+    if (!hasNonIdentityOperandsOrResults(op))
+      return failure();
+
+    // TODO: demap constant as well.
+    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
+      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
+        return failure();
+
+    Location loc = op.getLoc();
+    // Cache the type information since we update the foreach op in-place.
+    auto srcStt = getSparseTensorType(op.getTensor());
+    SmallVector<Type> prevRetTps(op.getResultTypes());
+
+    rewriter.startRootUpdate(op);
+    op.getTensorMutable().assign(adaptor.getTensor());
+    op.getInitArgsMutable().assign(adaptor.getInitArgs());
+    // Update results' types.
+    for (auto r : op.getResults())
+      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
+        r.setType(stt->getDemappedType());
+
+    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
+    // Update the foreach body.
+    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
+    blockArgTps.push_back(srcStt.getElementType());
+    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
+                       adaptor.getInitArgs().getTypes().end());
+    Block *body = op.getBody();
+    // Block Args: [dimCrd, val, initArgs]
+    unsigned preArgNum = body->getNumArguments();
+    for (Type t : blockArgTps)
+      body->addArgument(t, loc);
+
+    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
+    rewriter.setInsertionPointToStart(body);
+    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
+
+    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
+                                              CrdTransDirectionKind::lvl2dim);
+    rewriter.replaceAllUsesWith(
+        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
+    body->eraseArguments(0, srcStt.getDimRank());
+    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
+    unsigned numInitArgs = op.getInitArgs().size();
+    rewriter.replaceAllUsesWith(body->getArgument(0),
+                                body->getArgument(lvlRank + numInitArgs + 1));
+    body->eraseArgument(0);
+    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
+    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
+    SmallVector<Value> dstArgs(body->getArguments().take_back(numInitArgs));
+    // Remap back before replacement;
----------------
yinying-lisa-li wrote:

nit: .

https://github.com/llvm/llvm-project/pull/70868
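
For readers skimming the archive, here is a minimal standalone C++ sketch (not part of the PR) of the CRTP dispatch that the reworked DemapInsRemapOutsRewriter relies on. The Rewriter, LogicalResult, and ForeachOpStub types below are illustrative stand-ins for the real MLIR types; only the SubClass::rewriteOp forwarding mirrors the patch.

#include <cassert>
#include <cstdio>

// Illustrative stand-ins for PatternRewriter / LogicalResult, just to keep
// the sketch self-contained; the real types live in MLIR.
struct Rewriter {};
enum class LogicalResult { Success, Failure };

// CRTP base: does the shared "demap sparse inputs" work, then forwards to the
// concrete pattern, which now performs the op replacement itself and reports
// success/failure (mirroring the updated DemapInsRemapOutsRewriter).
template <typename SubClass, typename SourceOp>
struct DemapInsRemapOutsRewriterSketch {
  LogicalResult matchAndRewrite(SourceOp op, Rewriter &rewriter) const {
    // ... demap sparse inputs into an adaptor here ...
    return static_cast<const SubClass *>(this)->rewriteOp(op, rewriter);
  }
};

struct ForeachOpStub {};

// Concrete pattern, analogous to ForeachOpDemapper in the patch.
struct ForeachDemapperSketch
    : DemapInsRemapOutsRewriterSketch<ForeachDemapperSketch, ForeachOpStub> {
  LogicalResult rewriteOp(ForeachOpStub, Rewriter &) const {
    // ... update the op in place, remap results, replace the op ...
    return LogicalResult::Success;
  }
};

int main() {
  Rewriter rewriter;
  ForeachDemapperSketch pattern;
  assert(pattern.matchAndRewrite(ForeachOpStub{}, rewriter) ==
         LogicalResult::Success);
  std::puts("CRTP dispatch reached the subclass rewriteOp");
}

The design point the diff makes is that rewriteOp now returns a LogicalResult and performs the replacement itself, so in-place rewrites such as ForeachOpDemapper's startRootUpdate path no longer have to round-trip their results through the base class's output remapping.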

