[Mlir-commits] [mlir] [mlir][sparse] Implement rewriters to reinterpret maps on foreach (PR #70868)
Yinying Li
llvmlistbot at llvm.org
Wed Nov 1 11:30:20 PDT 2023
================
@@ -166,61 +178,125 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
// CRTP call.
OpAdaptor adaptor(deMappedIns);
- ValueRange outs =
- static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
- assert(outs.size() == op->getResults().size());
-
- // Remap outputs.
- SmallVector<Value> reMappedOuts(outs);
- for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
- if (r.getType() != a.getType())
- r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
-
- rewriter.replaceOp(op, reMappedOuts);
- return success();
+ return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+ rewriter);
}
};
-struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(CrdTranslateOp op,
- PatternRewriter &rewriter) const override {
- AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
- ? op.getEncoder().getDimToLvl()
- : op.getEncoder().getLvlToDim();
-
- SmallVector<Value> outCrds;
- for (AffineExpr result : map.getResults()) {
- // TODO: we should probably expand the affine map to IR using our own
- // rules, since affine.apply assume signed value, while the cooridinates
- // we provided must always be signless.
- Value trans = rewriter.create<affine::AffineApplyOp>(
- op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
- op.getInCrds());
- outCrds.push_back(trans);
- }
- rewriter.replaceOp(op, outCrds);
- return success();
- }
-};
+//===----------------------------------------------------------------------===//
+// Reinterpret Map Rewriters for operations other than linalg.generics
+//===----------------------------------------------------------------------===//
-struct TensorInsertRewriter
- : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
+struct TensorInsertDemapper
+ : public DemapInsRemapOutsRewriter<TensorInsertDemapper, tensor::InsertOp> {
using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+ LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
+ if (!hasAnySparseResult(op))
+ return failure();
- bool matchOp(tensor::InsertOp op) const {
- return op.getResult().getType().getEncoding() != nullptr;
- }
-
- ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
- PatternRewriter &rewriter) const {
Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
CrdTransDirectionKind::dim2lvl);
Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
loc, op.getScalar(), adaptor.getDest(), lvlCrd);
- return insertOp->getResults();
+
+ SmallVector<Value> outs(insertOp->getResults());
+ remapValueRange(rewriter, op->getResultTypes(), outs);
+ rewriter.replaceOp(op, outs);
+ return success();
+ }
+};
+
+struct ForeachOpDemapper
+ : public DemapInsRemapOutsRewriter<ForeachOpDemapper, ForeachOp> {
+ using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+ LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
+ // Only handles operations with sparse input/output.
----------------
yinying-lisa-li wrote:
For consistency, maybe "handle" instead of "handles"?
https://github.com/llvm/llvm-project/pull/70868
More information about the Mlir-commits mailing list