[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)

Yinying Li llvmlistbot at llvm.org
Tue Nov 7 09:02:04 PST 2023


================
@@ -16,27 +16,307 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+namespace {
+
+//===----------------------------------------------------------------------===//
+// File Local Helper classes.
+//===----------------------------------------------------------------------===//
+
+// CRTP helper for rewriters that first demap all of their inputs.
+//
+// Every operand that has a sparse tensor type which is not an identity
+// (`!isIdentity()`) is replaced by a ReinterpretMapOp to its demapped type.
+// The resulting values are wrapped in an OpAdaptor and handed to
+// `SubClass::rewriteOp(op, adaptor, rewriter)`. SubClass must provide that
+// method as a const member returning LogicalResult.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Demaps non-trivial inputs, remembering whether any IR was created.
+    bool changed = false;
+    SmallVector<Value> deMappedIns(op->getOperands());
+    for (Value &in : deMappedIns) {
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
+        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+        changed = true;
+      }
+    }
+
+    // CRTP call.
+    OpAdaptor adaptor(deMappedIns, op);
+    LogicalResult status =
+        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+    // A pattern must not report failure after it has modified the IR. If
+    // demapping inserted ReinterpretMapOps, the IR changed even when
+    // rewriteOp declined to rewrite, so report success in that case.
+    return changed ? success() : status;
+  }
+};
+
+// Collects the positions of all AffineDimExprs occurring in the walked affine
+// expression(s) into a bit set that has one bit per dimension.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {}
+  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
+  BitVector dims;
+};
+
+// Checks whether an affine expression is admissible. Walk the expression
+// (e.g. with walkPostOrder), then convert the visitor to bool to query the
+// result. Only the expression kinds handled below affect admissibility:
+//  - on outputs, any add or mul subexpression is inadmissible;
+//  - on both inputs and outputs, mod, floordiv and ceildiv are inadmissible.
+struct AffineExprAdmissibleVisitor
+    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
+  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {}
+
+  // Outputs are meant to use plain AffineDimExprs, so add/mul are rejected
+  // there.
+  void visitAddExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+  void visitMulExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+
+  // Mod, floor div and ceil div are rejected regardless of input/output.
+  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+
+  // True iff no inadmissible subexpression has been visited so far.
+  operator bool() const { return admissible; }
+
+private:
+  bool admissible = true;
+  bool isOutput;
+};
+
+// Admissibility summary of an affine map:
+//  - first:  one bit per level (map result), set for each level whose
+//            expression is inadmissible;
+//  - second: one bit per dimension, set for each AffineDimExpr used by those
+//            inadmissible expressions.
+using InadmissInfo = std::pair<BitVector, BitVector>;
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // File Local Helper methods.
 //===----------------------------------------------------------------------===//
 
-// Translates a "simple" map according to an identity lvl-map.
-static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
-                              AffineMap map) {
-  unsigned lvlRank = stt.getLvlRank();
-  AffineMap lvl2dim = stt.getLvlToDim();
-  assert(lvl2dim.getNumInputs() == lvlRank);
-  SmallVector<AffineExpr> exps;
-  for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
-    unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
-    exps.push_back(lvl2dim.getResult(pos));
+static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
+  auto ret = std::make_pair(BitVector(map.getNumResults()),
+                            BitVector(map.getNumDims()));
+  AffineDimCollector collector(map.getNumDims());
+  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
+    AffineExprAdmissibleVisitor admissible(isOutput);
+    admissible.walkPostOrder(map.getResult(lvl));
+    if (!admissible) {
+      // Record the inadmissible level.
----------------
yinying-lisa-li wrote:

nit: Records

https://github.com/llvm/llvm-project/pull/71448


More information about the Mlir-commits mailing list