[Mlir-commits] [mlir] c81a2c0 - [mlir][sparse] add helper class to implement common rewriter to re/demap sparse tensors. (#70750)
llvmlistbot at llvm.org
Tue Oct 31 12:35:56 PDT 2023
Author: Peiming Liu
Date: 2023-10-31T12:35:52-07:00
New Revision: c81a2c058e2b8a0bb8426cfaedc90e00d9290eb0
URL: https://github.com/llvm/llvm-project/commit/c81a2c058e2b8a0bb8426cfaedc90e00d9290eb0
DIFF: https://github.com/llvm/llvm-project/commit/c81a2c058e2b8a0bb8426cfaedc90e00d9290eb0.diff
LOG: [mlir][sparse] add helper class to implement common rewriter to re/demap sparse tensors. (#70750)
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 51aacd8cc2438e4..9776361da480920 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -97,7 +97,7 @@ template <typename T>
inline RankedTensorType getRankedTensorType(T &&t) {
assert(static_cast<bool>(std::forward<T>(t)) &&
"getRankedTensorType got null argument");
- return cast<RankedTensorType>(std::forward<T>(t).getType());
+ return dyn_cast<RankedTensorType>(std::forward<T>(t).getType());
}
/// Convenience method to abbreviate casting `getType()`.
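Note on the hunk above: switching from `cast` to `dyn_cast` changes the contract of `getRankedTensorType`. Instead of asserting when `getType()` is not a `RankedTensorType`, it can now return a null type, which callers are expected to test. A minimal sketch of the new calling convention (the value `v` and the surrounding check are illustrative, not part of the commit):

    // Hypothetical caller; `v` may or may not be a ranked tensor.
    if (RankedTensorType rtp = getRankedTensorType(v)) {
      // `v` is a ranked tensor; use `rtp`.
    } else {
      // Not a ranked tensor (e.g. a scalar operand); nothing to do.
    }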
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 0761cbee5240733..1fd91d0c02e4d1b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -336,11 +336,18 @@ class SparseTensorType {
const AffineMap lvlToDim;
};
-/// Convenience method to abbreviate wrapping `getRankedTensorType`.
+/// Convenience methods to abbreviate wrapping `getRankedTensorType`.
template <typename T>
inline SparseTensorType getSparseTensorType(T t) {
return SparseTensorType(getRankedTensorType(t));
}
+template <typename T>
+inline std::optional<SparseTensorType> tryGetSparseTensorType(T t) {
+ RankedTensorType rtp = getRankedTensorType(t);
+ if (rtp)
+ return SparseTensorType(rtp);
+ return std::nullopt;
+}
} // namespace sparse_tensor
} // namespace mlir
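The new `tryGetSparseTensorType` is the optional-returning counterpart of `getSparseTensorType`; it relies on the relaxed `getRankedTensorType` above and is what lets the rewriter below skip operands that are not ranked tensors at all. A hedged usage sketch (the operand name `in` mirrors the rewriter further down; the branches are illustrative):

    // Illustrative only: inspect an operand that may be a scalar, a dense
    // tensor, or a sparse tensor.
    if (std::optional<SparseTensorType> stt = tryGetSparseTensorType(in)) {
      if (!stt->isIdentity()) {
        // Non-trivial dim-to-lvl mapping: a demapping reinterpret_map is needed.
      }
    } else {
      // Not a ranked tensor; nothing to demap.
    }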
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 66fd2e4d94a28bd..5880f2158b8cd05 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -23,8 +23,44 @@ namespace {
// (2) rewrite linalg.generic ops traits on level crds
// (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
+// CRTP to help implementing a rewriter that demaps all its inputs and remaps
+// all its outputs.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
+ using OpRewritePattern<SourceOp>::OpRewritePattern;
+ using OpAdaptor = typename SourceOp::Adaptor;
+
+ LogicalResult matchAndRewrite(SourceOp op,
+ PatternRewriter &rewriter) const override {
+ if (!static_cast<const SubClass *>(this)->matchOp(op))
+ return failure();
+
+ Location loc = op.getLoc();
+ // Demaps non-trivial inputs.
+ SmallVector<Value> deMappedIns(op->getOperands());
+ for (Value &in : deMappedIns)
+ if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+ in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+
+ // CRTP call.
+ OpAdaptor adaptor(deMappedIns);
+ ValueRange outs =
+ static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+ assert(outs.size() == op->getResults().size());
+
+ // Remap outputs.
+ SmallVector<Value> reMappedOuts(outs);
+ for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
+ if (r.getType() != a.getType())
+ r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
+
+ rewriter.replaceOp(op, reMappedOuts);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
-// Reiterpret Map Rewriters for operations other than linalg.generics
+// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//
struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
@@ -34,6 +70,7 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
? op.getEncoder().getDimToLvl()
: op.getEncoder().getLvlToDim();
+
SmallVector<Value> outCrds;
for (AffineExpr result : map.getResults()) {
// TODO: we should probably expand the affine map to IR using our own
@@ -49,24 +86,23 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
}
};
-struct TensorInsertRewriter : public OpRewritePattern<tensor::InsertOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(tensor::InsertOp op,
- PatternRewriter &rewriter) const override {
+struct TensorInsertRewriter
+ : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
+ using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
- if (!op.getResult().getType().getEncoding())
- return failure();
+ bool matchOp(tensor::InsertOp op) const {
+ return op.getResult().getType().getEncoding() != nullptr;
+ }
+
+ ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
CrdTransDirectionKind::dim2lvl);
-
- Value t = rewriter.create<ReinterpretMapOp>(
- loc, stt.getEncoding().withoutDimToLvl(), op.getDest());
- t = rewriter.create<sparse_tensor::InsertOp>(loc, op.getScalar(), t,
- lvlCrd);
- rewriter.replaceOpWithNewOp<ReinterpretMapOp>(op, op.getType(), t);
- return success();
+ Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
+ loc, op.getScalar(), adaptor.getDest(), lvlCrd);
+ return insertOp->getResults();
}
};
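Taken together, the CRTP base factors the demap/remap boilerplate out of each pattern: a subclass supplies only `matchOp` and `rewriteOp`, receives already-demapped operands through the adaptor, and the base class remaps any results whose type differs from the original. A hedged sketch of what a further client of the helper could look like (`FooOp` and the pattern name are purely hypothetical, not part of this commit):

    // Hypothetical pattern reusing the helper for some sparse op `FooOp`.
    struct FooRewriter : public DemapInsRemapOutsRewriter<FooRewriter, FooOp> {
      using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;

      // Only fire on sparse results; the base class handles the rest.
      bool matchOp(FooOp op) const {
        return getSparseTensorType(op.getResult()).hasEncoding();
      }

      // Operands in `adaptor` are already demapped; values returned here are
      // remapped back to the original result types by the base class.
      ValueRange rewriteOp(FooOp op, OpAdaptor adaptor,
                           PatternRewriter &rewriter) const {
        // ... build the replacement ops on level coordinates here ...
        return adaptor.getOperands();
      }
    };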