[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)

Yinying Li llvmlistbot at llvm.org
Tue Nov 7 09:02:05 PST 2023


================
@@ -16,27 +16,307 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+namespace {
+
+//===----------------------------------------------------------------------===//
+// File Local Helper classes.
+//===----------------------------------------------------------------------===//
+
+// CRTP base for rewriters that demap all of their inputs. Every operand for
+// which `tryGetSparseTensorType` yields a type that is not an identity
+// (`!isIdentity()`) is first wrapped in a ReinterpretMapOp to its demapped
+// type; all other operands are forwarded unchanged. The resulting operand list
+// is handed to `SubClass::rewriteOp(op, adaptor, rewriter)` through an op
+// adaptor.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Demaps non-trivial inputs.
+    SmallVector<Value> deMappedIns(op->getOperands());
+    for (Value &in : deMappedIns)
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+
+    // CRTP call: dispatch to the subclass with the demapped operands.
+    OpAdaptor adaptor(deMappedIns, op);
+    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+                                                          rewriter);
+  }
+};
+
+// Collects the positions of all AffineDimExprs visited in an affine expression
+// into the bit vector `dims` (sized by `dimNum`). The same collector may be
+// walked over several expressions; the collected positions accumulate.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
+  BitVector dims;
+};
+
+// Checks whether an affine expression is admissible. After walking an
+// expression (e.g. via walkPostOrder), the visitor converts to `false` if any
+// visited sub-expression was inadmissible:
+//  - mod, floordiv and ceildiv are never admissible;
+//  - add and mul are additionally inadmissible when `isOutput` is set.
+struct AffineExprAdmissibleVisitor
+    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
+  explicit AffineExprAdmissibleVisitor(bool isOutput)
+      : admissible(true), isOutput(isOutput){};
+
+  // Add and mul are only rejected when `isOutput` is set.
+  void visitAddExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+  void visitMulExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+
+  // Mod, floor div and ceil div are rejected for both inputs and outputs.
+  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  // True iff no inadmissible sub-expression has been visited so far.
+  // NOTE(review): consider making this `explicit` and `const`.
+  operator bool() { return admissible; }
+
+private:
+  bool admissible;
+  bool isOutput;
+};
+
+// Summary of the inadmissible level expressions of an index map, as computed
+// by collectInadmissInfo:
+//  - first:  the levels (map results) whose expressions are inadmissible;
+//  - second: the dimensions (AffineDimExpr positions) used by those
+//            inadmissible expressions.
+using InadmissInfo = std::pair<BitVector, BitVector>;
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // File Local Helper methods.
 //===----------------------------------------------------------------------===//
 
-// Translates a "simple" map according to an identity lvl-map.
-static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
-                              AffineMap map) {
-  unsigned lvlRank = stt.getLvlRank();
-  AffineMap lvl2dim = stt.getLvlToDim();
-  assert(lvl2dim.getNumInputs() == lvlRank);
-  SmallVector<AffineExpr> exps;
-  for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
-    unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
-    exps.push_back(lvl2dim.getResult(pos));
+// Walks every result expression (one per level) of `map` with an
+// AffineExprAdmissibleVisitor. Returns the set of inadmissible levels and the
+// set of dimensions used by those inadmissible level expressions.
+// `isOutput` selects the stricter admissibility rules used for outputs.
+static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
+  auto ret = std::make_pair(BitVector(map.getNumResults()),
+                            BitVector(map.getNumDims()));
+  // Shared across all levels so that the used dims accumulate.
+  AffineDimCollector collector(map.getNumDims());
+  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
+    AffineExprAdmissibleVisitor admissible(isOutput);
+    admissible.walkPostOrder(map.getResult(lvl));
+    if (!admissible) {
+      // Record the inadmissible level.
+      ret.first.set(lvl);
+      // Record the AffineDimExprs that are used in the inadmissible expr.
+      collector.walkPostOrder(map.getResult(lvl));
+    }
+  }
+  ret.second = collector.dims;
+  return ret;
+}
+
+// Builds the AffineMap that replaces the idx in idxMap with lvl such that all
+// the inadmissible affine expressions can be eliminated.
+// For example, we can rewrite
+// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
+// to
+// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
+// by composing inverse(idxMap), that is
+// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
+//                         -> ((l0 * 2 + l2) floordiv 2,
+//                             (l1 * 3 + l3) floordiv 3,
+//                             (l0 * 2 + l2) mod 2,
+//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
+//
+// This function builds the inverse(idxMap) that replaces every dimension used
+// in `info` with levels, and updates the iterator type array `itTps` for the
+// new index variables introduced.
+//
+// Note that the returned affine map does not retain the order of the input
+// affine map. Instead, it always uses the first `inAdLvls.count()` dimensions
+// (inAdLvls == info.first) for the replaced levels, and the remaining ones for
+// the dimensions that are not replaced.
+// For example, to handle
+// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
+// which is a typical map for block_2to4. The function returns:
+// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
+// in which (l0, l1) together replace `d1`, yet they appear
+// before `d0` in the resulting affine map.
+// The index (loop) order can later be canonicalized by a topological sort.
+//
+// Parameters:
+//   info   - inadmissible levels and the dims they use (collectInadmissInfo).
+//   idxMap - the idx->lvl map of one operand.
+//   itTps  - iterator types of the current indices; overwritten in place with
+//            the iterator types in the new index order.
+static AffineMap
+genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
+                      SmallVector<utils::IteratorType> &itTps) {
+  MLIRContext *ctx = idxMap.getContext();
+  auto [inAdLvls, usedDims] = info;
+  // Note that idxMap is not equal to the dim2Lvl map; it is computed by
+  // composing idx2Dim(dim2Lvl), they are only equal when idx2Dim is an
+  // ID map.
+  // TODO: we might fail here, in those cases we should really return
+  // failure instead of assertion error.
+  auto lvl2Idx = inferLvlToDim(idxMap, ctx);
+
+  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
+  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
+    // This could happen when some dimensions are projected.
+    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
+    //   ==> lvl2Idx = (j, k) -> (j, k)
+    // In this case, we append the unused dimensions at the end.
+    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
+    SmallVector<AffineExpr> results;
+    AffineDimCollector usedInLvl(idxMap.getNumDims());
+    for (auto e : idxMap.getResults())
+      usedInLvl.walkPostOrder(e);
+
+    // Used dims take the existing lvl2Idx results in order; unused dims get
+    // fresh trailing dimension positions.
+    unsigned curUsedDimID = 0;
+    unsigned curUnusedDimID = lvl2Idx.getNumDims();
+
+    BitVector unused = usedInLvl.dims.flip();
+    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
+      if (unused.test(i))
+        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
+      else
+        results.push_back(lvl2Idx.getResult(curUsedDimID++));
+    }
+    lvl2Idx =
+        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
   }
-  return AffineMap::get(lvlRank, 0, exps, builder.getContext());
+  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
+
+  // We do not need to replace the DimExprs that are not used in inadmissible
+  // level expressions. We use the first inAdLvls.count() dims to represent the
+  // replaced levels; the remaining ones are used for the unchanged dims.
+  unsigned curRepID = 0;
+  unsigned curOriID = inAdLvls.count();
+  // Since we changed the order of the AffineMap's dimensions, we need to
+  // remap the dimensions here.
+  SmallVector<AffineExpr> results;
+  // dimRep[l] is the new leading dimension that stands for level `l`.
+  // NOTE(review): entries of admissible levels stay null; this assumes the
+  // lvl2Idx results for `usedDims` only reference inadmissible levels.
+  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
+  SmallVector<utils::IteratorType> transItTps;
+
+  for (unsigned l : inAdLvls.set_bits()) {
+    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
+    AffineExpr lvlExp = idxMap.getResult(l);
+    AffineDimCollector collector(idxMap.getNumDims());
+    collector.walkPostOrder(lvlExp);
+    assert(collector.dims.count() == 1);
+    // Inherit the iterator type from the used idx.
+    transItTps.push_back(itTps[collector.dims.find_first()]);
+  }
+
+  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
+    if (usedDims.test(d))
+      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
+    else {
+      results.push_back(getAffineDimExpr(curOriID++, ctx));
+      transItTps.push_back(itTps[d]);
+    }
+  }
+  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
+  // Update iterator type.
+  itTps.assign(transItTps.begin(), transItTps.end());
+  return AffineMap::get(numDim, 0, results, ctx);
+}
+
+// Translates the index maps in the linalg::GenericOp from idx->dim maps to
+// idx->lvl maps. Returns failure if an index map cannot be translated to an
+// admissible form.
+// The function also updates the GenericOp's index maps and iterator type array
+// *in-place*.
+static LogicalResult translateMap(linalg::GenericOp op,
+                                  PatternRewriter &rewriter) {
+  // idxMap is a idx2dim map before reinterpretation.
+  MLIRContext *ctx = op.getContext();
+  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
+  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
+    Value tensor = op->getOpOperand(i).get();
+    auto stt = tryGetSparseTensorType(tensor);
+    if (stt && !stt->isIdentity()) {
+      AffineMap dim2Lvl = stt->getDimToLvl();
+      // By composing the idx2dim(dim2lvl), we got a idx2lvl Map
+      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
+    }
+  }
+
+  // A naive way to handle common constant expressions that arise during dim2lvl
+  // translation.
+  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
+                                  unsigned pos, int64_t lvlSz) {
+    if (!ShapedType::isDynamic(lvlSz)) {
+      auto c0 = getAffineConstantExpr(0, ctx);
+      auto lvlExp = getAffineDimExpr(pos, ctx);
+      auto szExp = getAffineConstantExpr(lvlSz, ctx);
+
+      // lvl floordiv lvlSz = 0
+      auto divExp =
+          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
+      cstMapping.try_emplace(divExp, c0);
+
+      // lvl mod lvlSz = lvl
+      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
+      cstMapping.try_emplace(modExp, lvlExp);
+    }
+  };
+
+  unsigned boundedNum = 0;
+  // A fixed-point algorithm.
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    for (OpOperand &operand : op->getOpOperands()) {
+      auto stt = tryGetSparseTensorType(operand.get());
+      // Skip on dense operands.
+      if (!stt || !stt->getEncoding())
+        continue;
+
+      unsigned tid = operand.getOperandNumber();
+      bool isOutput = &operand == op.getDpsInitOperand(0);
+      AffineMap idxMap = idxMapArray[tid];
+      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
+      auto [inAdLvls, dimExprs] = inAdInfo;
+      for (unsigned d : dimExprs.set_bits()) {
+        // The first `boundedNum` used in the AffineMap is introduced to
+        // resolve previous inadmissible expressions. We can not replace them
+        // to bring back the inadmissible expressions.
+        if (d < boundedNum)
+          return failure();
+      }
+
+      if (inAdLvls.count() != 0) {
+        // Naive constant progagation, should be sufficient to handle block
+        // sparsity in our cases.
+        SmallVector<int64_t> lvlShape = stt->getLvlShape();
+        DenseMap<AffineExpr, AffineExpr> cstMapping;
+        unsigned position = 0;
+        for (unsigned lvl : inAdLvls.set_bits()) {
+          int64_t lvlSz = lvlShape[lvl];
+          populateCstMapping(cstMapping, position, lvlSz);
+          position++;
+        }
+
+        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
+        // Compose the lvl2Idx Map to all AffineIdxMap to eliminate
----------------
yinying-lisa-li wrote:

Composes

https://github.com/llvm/llvm-project/pull/71448


More information about the Mlir-commits mailing list