[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)
Yinying Li
llvmlistbot at llvm.org
Tue Nov 7 09:02:04 PST 2023
================
@@ -73,128 +353,57 @@ static bool hasNonIdentityOperandsOrResults(Operation *op) {
llvm::any_of(op->getResults(), hasNonIdentityMap);
}
-// Generates a clone of the given linalg generic operation, but with
-// remapped arguments, index maps, and iteration types.
-//
-// TODO: As described below, this is proof-of-concept code which makes a lot
-// of simplifying assumptions for now.
-//
-static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
- linalg::GenericOp linalgOp,
- SparseTensorType stt, Value out) {
- unsigned dimRank = stt.getDimRank();
- unsigned lvlRank = stt.getLvlRank();
- SmallVector<Value> inputOps = linalgOp.getInputs();
- SmallVector<Value> outputOps = {out};
- SmallVector<AffineMap> indexMaps;
- SmallVector<utils::IteratorType> iterTypes;
- // Translate the index maps, except output map, which is lvl-identity.
- auto maps = linalgOp.getIndexingMapsArray();
- for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
- indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
- indexMaps.push_back(
- AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
- // Add additional "parallel" iteration types at the top.
- for (unsigned i = 0, diff = lvlRank = dimRank; i < diff; i++)
- iterTypes.push_back(utils::IteratorType::parallel);
- for (auto &i : linalgOp.getIteratorTypesArray())
- iterTypes.push_back(i);
- // Generate the new linalg generic operation and clone body.
- auto newOp = rewriter.create<linalg::GenericOp>(
- linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
- iterTypes);
- rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
- newOp.getRegion().begin());
- return newOp;
-}
-
namespace {
//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//
/// Sparse rewriting rule for the generic `linalg` operation.
-struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+struct GenericOpReinterpretMap
+ : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
- GenericOpReinterpretMap(MLIRContext *context)
- : OpRewritePattern<linalg::GenericOp>(context) {}
+ using DemapInsRewriter::DemapInsRewriter;
+ LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
+ // Only rewrite single output operations with pure (sparse) tensor
----------------
yinying-lisa-li wrote:
Nit: s/rewrite/rewrites/ in the comment above ("Only rewrites single output operations ...").
https://github.com/llvm/llvm-project/pull/71448
More information about the Mlir-commits mailing list