[Mlir-commits] [mlir] [mlir][sparse] Implement rewriters to reinterpret maps on foreach (PR #70868)

Yinying Li llvmlistbot at llvm.org
Wed Nov 1 11:30:20 PDT 2023


================
@@ -166,61 +178,125 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
 
     // CRTP call.
     OpAdaptor adaptor(deMappedIns);
-    ValueRange outs =
-        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
-    assert(outs.size() == op->getResults().size());
-
-    // Remap  outputs.
-    SmallVector<Value> reMappedOuts(outs);
-    for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
-      if (r.getType() != a.getType())
-        r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
-
-    rewriter.replaceOp(op, reMappedOuts);
-    return success();
+    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+                                                          rewriter);
   }
 };
 
-struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(CrdTranslateOp op,
-                                PatternRewriter &rewriter) const override {
-    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
-                        ? op.getEncoder().getDimToLvl()
-                        : op.getEncoder().getLvlToDim();
-
-    SmallVector<Value> outCrds;
-    for (AffineExpr result : map.getResults()) {
-      // TODO: we should probably expand the affine map to IR using our own
-      // rules, since affine.apply assume signed value, while the cooridinates
-      // we provided must always be signless.
-      Value trans = rewriter.create<affine::AffineApplyOp>(
-          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
-          op.getInCrds());
-      outCrds.push_back(trans);
-    }
-    rewriter.replaceOp(op, outCrds);
-    return success();
-  }
-};
+//===----------------------------------------------------------------------===//
+// Reinterpret Map Rewriters for operations other than linalg.generics
+//===----------------------------------------------------------------------===//
 
-struct TensorInsertRewriter
-    : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
+struct TensorInsertDemapper
+    : public DemapInsRemapOutsRewriter<TensorInsertDemapper, tensor::InsertOp> {
   using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    if (!hasAnySparseResult(op))
+      return failure();
 
-  bool matchOp(tensor::InsertOp op) const {
-    return op.getResult().getType().getEncoding() != nullptr;
-  }
-
-  ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
-                       PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getResult());
     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                           CrdTransDirectionKind::dim2lvl);
     Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
-    return insertOp->getResults();
+
+    SmallVector<Value> outs(insertOp->getResults());
+    remapValueRange(rewriter, op->getResultTypes(), outs);
+    rewriter.replaceOp(op, outs);
+    return success();
+  }
+};
+
+struct ForeachOpDemapper
+    : public DemapInsRemapOutsRewriter<ForeachOpDemapper, ForeachOp> {
+  using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    // Only handles operations with sparse input/output.
----------------
yinying-lisa-li wrote:

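For consistency with the other comments, maybe use "handle" instead of "handles" here (i.e., "// Only handle operations with sparse input/output.")?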

https://github.com/llvm/llvm-project/pull/70868


More information about the Mlir-commits mailing list