[Mlir-commits] [mlir] c81a2c0 - [mlir][sparse] add helper class to implement common rewriter to re/demap sparse tensors. (#70750)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 31 12:35:56 PDT 2023


Author: Peiming Liu
Date: 2023-10-31T12:35:52-07:00
New Revision: c81a2c058e2b8a0bb8426cfaedc90e00d9290eb0

URL: https://github.com/llvm/llvm-project/commit/c81a2c058e2b8a0bb8426cfaedc90e00d9290eb0
DIFF: https://github.com/llvm/llvm-project/commit/c81a2c058e2b8a0bb8426cfaedc90e00d9290eb0.diff

LOG: [mlir][sparse] add helper class to implement common rewriter to re/demap sparse tensors. (#70750)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 51aacd8cc2438e4..9776361da480920 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -97,7 +97,10 @@ template <typename T>
 inline RankedTensorType getRankedTensorType(T &&t) {
   assert(static_cast<bool>(std::forward<T>(t)) &&
          "getRankedTensorType got null argument");
-  return cast<RankedTensorType>(std::forward<T>(t).getType());
+  // NOTE: returns a null RankedTensorType (instead of asserting) when the type
+  // of `t` is not a RankedTensorType; callers must check the result whenever
+  // `t` may be of some other type.
+  return dyn_cast<RankedTensorType>(std::forward<T>(t).getType());
 }
 
 /// Convenience method to abbreviate casting `getType()`.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 0761cbee5240733..1fd91d0c02e4d1b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -336,11 +336,18 @@ class SparseTensorType {
   const AffineMap lvlToDim;
 };
 
-/// Convenience method to abbreviate wrapping `getRankedTensorType`.
+/// Convenience methods to abbreviate wrapping `getRankedTensorType`.
 template <typename T>
 inline SparseTensorType getSparseTensorType(T t) {
   return SparseTensorType(getRankedTensorType(t));
 }
+template <typename T>
+inline std::optional<SparseTensorType> tryGetSparseTensorType(T t) {
+  RankedTensorType rtp = getRankedTensorType(t);
+  if (rtp)
+    return SparseTensorType(rtp);
+  return std::nullopt;
+}
 
 } // namespace sparse_tensor
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 66fd2e4d94a28bd..5880f2158b8cd05 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -23,8 +23,55 @@ namespace {
 //   (2) rewrite linalg.generic ops traits on level crds
 //   (3) compute topsort, and resolve cyles with sparse_tensor.convert ops

+// CRTP helper for implementing a rewriter that demaps all of its inputs and
+// remaps all of its outputs. `SubClass` must provide:
+//   bool matchOp(SourceOp op) const;
+//   ValueRange rewriteOp(SourceOp op, OpAdaptor adaptor,
+//                        PatternRewriter &rewriter) const;
+// `rewriteOp` receives an adaptor over the demapped operands and must return
+// exactly one value per result of `op`; each returned value whose type differs
+// from the corresponding original result type is remapped back to that type.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!static_cast<const SubClass *>(this)->matchOp(op))
+      return failure();
+
+    Location loc = op.getLoc();
+    // Demaps non-trivial inputs: each ranked tensor operand whose
+    // SparseTensorType is not an identity mapping is reinterpreted.
+    SmallVector<Value> deMappedIns(op->getOperands());
+    for (Value &in : deMappedIns)
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+
+    // CRTP call. The adaptor is built from `op` as well so that the op's
+    // attributes and properties are carried over; only the operand values are
+    // replaced by their demapped counterparts.
+    OpAdaptor adaptor(deMappedIns, op);
+    ValueRange outs =
+        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+    assert(outs.size() == op->getNumResults() &&
+           "rewriteOp must return one value per result of the source op");
+
+    // Remaps outputs. `llvm::zip` yields tuples holding a reference into
+    // `reMappedOuts`, so assigning to `r` updates the vector in place.
+    SmallVector<Value> reMappedOuts(outs);
+    for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
+      if (r.getType() != a.getType())
+        r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
+
+    rewriter.replaceOp(op, reMappedOuts);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
-// Reiterpret Map Rewriters for operations other than linalg.generics
+// Reinterpret Map Rewriters for operations other than linalg.generics
 //===----------------------------------------------------------------------===//

 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
@@ -34,6 +70,7 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
     AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
                         ? op.getEncoder().getDimToLvl()
                         : op.getEncoder().getLvlToDim();
+
     SmallVector<Value> outCrds;
     for (AffineExpr result : map.getResults()) {
       // TODO: we should probably expand the affine map to IR using our own
@@ -49,24 +86,30 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
   }
 };

-struct TensorInsertRewriter : public OpRewritePattern<tensor::InsertOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tensor::InsertOp op,
-                                PatternRewriter &rewriter) const override {
+// Rewrites `tensor.insert` on a sparse tensor into `sparse_tensor.insert` on
+// level coordinates. Operands are read from `adaptor` (already demapped where
+// needed); the base class remaps the result back to the original result type.
+struct TensorInsertRewriter
+    : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
+  using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;

-    if (!op.getResult().getType().getEncoding())
-      return failure();
+  // Matches only inserts whose result carries a sparse tensor encoding; other
+  // (non-sparse) encodings are left untouched.
+  bool matchOp(tensor::InsertOp op) const {
+    return static_cast<bool>(
+        getSparseTensorEncoding(op.getResult().getType()));
+  }
+
+  ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+                       PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getResult());
-    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
-                                          CrdTransDirectionKind::dim2lvl);
-
-    Value t = rewriter.create<ReinterpretMapOp>(
-        loc, stt.getEncoding().withoutDimToLvl(), op.getDest());
-    t = rewriter.create<sparse_tensor::InsertOp>(loc, op.getScalar(), t,
-                                                 lvlCrd);
-    rewriter.replaceOpWithNewOp<ReinterpretMapOp>(op, op.getType(), t);
-    return success();
+    // Translate the dimension coordinates into level coordinates.
+    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, adaptor.getIndices(),
+                                          CrdTransDirectionKind::dim2lvl);
+    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
+        loc, adaptor.getScalar(), adaptor.getDest(), lvlCrd);
+    return insertOp->getResults();
   }
 };



        


More information about the Mlir-commits mailing list