[Mlir-commits] [mlir] [mlir][sparse] Implement rewriters to reinterpret maps on foreach (PR #70868)
Yinying Li
llvmlistbot at llvm.org
Wed Nov 1 11:30:20 PDT 2023
================
@@ -166,61 +178,125 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
// CRTP call.
OpAdaptor adaptor(deMappedIns);
- ValueRange outs =
- static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
- assert(outs.size() == op->getResults().size());
-
- // Remap outputs.
- SmallVector<Value> reMappedOuts(outs);
- for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
- if (r.getType() != a.getType())
- r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
-
- rewriter.replaceOp(op, reMappedOuts);
- return success();
+ return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+ rewriter);
}
};
-struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(CrdTranslateOp op,
- PatternRewriter &rewriter) const override {
- AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
- ? op.getEncoder().getDimToLvl()
- : op.getEncoder().getLvlToDim();
-
- SmallVector<Value> outCrds;
- for (AffineExpr result : map.getResults()) {
- // TODO: we should probably expand the affine map to IR using our own
- // rules, since affine.apply assumes signed values, while the coordinates
- // we provide must always be signless.
- Value trans = rewriter.create<affine::AffineApplyOp>(
- op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
- op.getInCrds());
- outCrds.push_back(trans);
- }
- rewriter.replaceOp(op, outCrds);
- return success();
- }
-};
+//===----------------------------------------------------------------------===//
+// Reinterpret Map Rewriters for operations other than linalg.generics
+//===----------------------------------------------------------------------===//
-struct TensorInsertRewriter
- : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
+// Demaps a tensor.insert whose result is a sparse tensor: the insert indices
+// are translated from dimension to level coordinates (dim2lvl), a
+// sparse_tensor.insert is built on the demapped destination operand, and its
+// results are remapped to the original result types before replacing the op.
+struct TensorInsertDemapper
+ : public DemapInsRemapOutsRewriter<TensorInsertDemapper, tensor::InsertOp> {
using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+ LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
+ // Only inserts producing a sparse tensor are rewritten; report no match
+ // otherwise.
+ if (!hasAnySparseResult(op))
+ return failure();
- bool matchOp(tensor::InsertOp op) const {
- return op.getResult().getType().getEncoding() != nullptr;
- }
-
- ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
- PatternRewriter &rewriter) const {
Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
+ // The tensor.insert indices are dimension coordinates; the sparse insert
+ // below is given the level coordinates produced by this translation.
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
CrdTransDirectionKind::dim2lvl);
Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
loc, op.getScalar(), adaptor.getDest(), lvlCrd);
- return insertOp->getResults();
+
+ // Remap the new op's results to the original op's result types (via the
+ // remapValueRange helper), then replace the tensor.insert with them.
+ SmallVector<Value> outs(insertOp->getResults());
+ remapValueRange(rewriter, op->getResultTypes(), outs);
+ rewriter.replaceOp(op, outs);
+ return success();
+ }
+};
+
+struct ForeachOpDemapper
+ : public DemapInsRemapOutsRewriter<ForeachOpDemapper, ForeachOp> {
+ using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+ LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
+ // Only handles operations with sparse input/output.
+ if (!hasNonIdentityOperandsOrResults(op))
+ return failure();
+
+ // TODO: demap constant as well.
+ if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
+ if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
+ return failure();
+
+ Location loc = op.getLoc();
+ // Cache the type information since we update the foreach op in-place.
+ auto srcStt = getSparseTensorType(op.getTensor());
+ SmallVector<Type> prevRetTps(op.getResultTypes());
+
+ rewriter.startRootUpdate(op);
+ op.getTensorMutable().assign(adaptor.getTensor());
+ op.getInitArgsMutable().assign(adaptor.getInitArgs());
+ // Update results' types.
+ for (auto r : op.getResults())
+ if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
+ r.setType(stt->getDemappedType());
+
+ Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
+ // Update the foreach body.
+ SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
+ blockArgTps.push_back(srcStt.getElementType());
+ blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
+ adaptor.getInitArgs().getTypes().end());
+ Block *body = op.getBody();
+ // Block Args: [dimCrd, val, initArgs]
+ unsigned preArgNum = body->getNumArguments();
+ for (Type t : blockArgTps)
+ body->addArgument(t, loc);
+
+ // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
+ rewriter.setInsertionPointToStart(body);
+ ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
+
+ ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
+ CrdTransDirectionKind::lvl2dim);
+ rewriter.replaceAllUsesWith(
+ body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
+ body->eraseArguments(0, srcStt.getDimRank());
+ // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
+ unsigned numInitArgs = op.getInitArgs().size();
+ rewriter.replaceAllUsesWith(body->getArgument(0),
+ body->getArgument(lvlRank + numInitArgs + 1));
+ body->eraseArgument(0);
+ // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
+ ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
+ SmallVector<Value> dstArgs(body->getArguments().take_back(numInitArgs));
+ // Remap back before replacement;
----------------
yinying-lisa-li wrote:
nit: end this comment with a period (`.`) instead of a semicolon.
https://github.com/llvm/llvm-project/pull/70868
More information about the Mlir-commits
mailing list