[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)

Aart Bik llvmlistbot at llvm.org
Tue Nov 7 09:37:21 PST 2023


================
@@ -16,27 +16,307 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+namespace {
+
+//===----------------------------------------------------------------------===//
+// File Local Helper classes.
+//===----------------------------------------------------------------------===//
+
+// CRTP base class that helps implement a rewriter which demaps all its inputs.
+// `SubClass` is expected to provide `rewriteOp(op, adaptor, rewriter)`; it is
+// invoked with an adaptor built over the demapped operands.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Demaps non-trivial inputs: every operand that has a sparse tensor type
+    // which is not an identity is wrapped in a ReinterpretMapOp that yields
+    // its demapped type. All other operands are passed through unchanged.
+    SmallVector<Value> deMappedIns(op->getOperands());
+    for (Value &in : deMappedIns)
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+
+    // CRTP call into the subclass with the (possibly) demapped operands.
+    OpAdaptor adaptor(deMappedIns, op);
+    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+                                                          rewriter);
+  }
+};
+
+// Visitor that records the position of every AffineDimExpr it visits by
+// setting the corresponding bit in `dims` (constructed with `dimNum` bits).
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
+  BitVector dims;
+};
+
+// Visitor that decides whether an affine expression is admissible. It starts
+// out admissible and is marked inadmissible as soon as an unsupported
+// sub-expression is visited; the verdict is read through `operator bool`.
+struct AffineExprAdmissibleVisitor
+    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
+  explicit AffineExprAdmissibleVisitor(bool isOutput)
+      : admissible(true), isOutput(isOutput){};
+
+  // We only allow AffineDimExpr on output, so add and mul are rejected there
+  // (they remain admissible when `isOutput` is false).
+  void visitAddExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+  void visitMulExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+
+  // Mod, floor div and ceil div are not supported, regardless of `isOutput`.
+  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  operator bool() { return admissible; }
+
+private:
+  bool admissible;
+  bool isOutput;
+};
+
+// The first BitVector stores the levels where inadmissible exprs are used.
+// The second BitVector stores the AffineDimExprs that are used by the
+// inadmissible expressions.
+using InadmissInfo = std::pair<BitVector, BitVector>;
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // File Local Helper methods.
 //===----------------------------------------------------------------------===//
 
-// Translates a "simple" map according to an identity lvl-map.
-static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
-                              AffineMap map) {
-  unsigned lvlRank = stt.getLvlRank();
-  AffineMap lvl2dim = stt.getLvlToDim();
-  assert(lvl2dim.getNumInputs() == lvlRank);
-  SmallVector<AffineExpr> exps;
-  for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
-    unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
-    exps.push_back(lvl2dim.getResult(pos));
+// Collects the inadmissible information of `map`: the set of results (levels)
+// whose expression is rejected by AffineExprAdmissibleVisitor for the given
+// `isOutput` flag, together with the set of dimensions that are used by those
+// rejected expressions.
+static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
+  auto ret = std::make_pair(BitVector(map.getNumResults()),
+                            BitVector(map.getNumDims()));
+  // A single collector accumulates the dimensions over all inadmissible levels.
+  AffineDimCollector collector(map.getNumDims());
+  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
+    AffineExprAdmissibleVisitor admissible(isOutput);
+    admissible.walkPostOrder(map.getResult(lvl));
+    if (!admissible) {
+      // Records the inadmissible level.
+      ret.first.set(lvl);
+      // Records the AffineDimExprs that are used in the inadmissible expr.
+      collector.walkPostOrder(map.getResult(lvl));
+    }
+  }
+  ret.second = collector.dims;
+  return ret;
+}
+
+// Builds the AffineMap that replaces the idx in idxMap with lvl such that all
+// the inadmissible affine expressions can be eliminated.
+// For example, we can rewrite
+// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
+// to
+// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
+// by composing inverse(idxMap), that is
+// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
+//                         -> ((l0 * 2 + l2) floordiv 2,
+//                             (l1 * 3 + l3) floordiv 3,
+//                             (l0 * 2 + l2) mod 2,
+//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
+//
+// This function builds the inverse(idxMap) that replaces every dimension used
+// in `info` with levels, and updates the iterator type array `itTps` for the
+// new index variables introduced.
+//
+// Note that the returned affine map does not retain the order of the input
+// affine map. Instead, it always uses the first `inAdLvls.count()` dimensions
+// for the replaced levels, and the remaining ones for unused dimensions.
+// For example, to handle
+// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
+// which is a typical map for block_2to4. The function returns:
+// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
+// in which, (l0, l1) together replace `d1`, yet they appear
+// before `d0` in the resulting affine map.
+// The index (loop) order can later be canonicalized by a topo sort.
+static AffineMap
+genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
+                      SmallVector<utils::IteratorType> &itTps) {
+  MLIRContext *ctx = idxMap.getContext();
+  auto [inAdLvls, usedDims] = info;
+  // Note that idxMap is not equal to the dim2Lvl map; it is computed by
+  // composing idx2Dim(dim2Lvl), and they are only equal when idx2Dim is an
+  // ID map.
+  // TODO: we might fail here, in those cases we should really return
+  // failure instead of an assertion error.
+  auto lvl2Idx = inferLvlToDim(idxMap, ctx);
+
+  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
+  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
+    // This could happen when some dimensions are projected.
+    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
+    //   ==> lvl2Idx = (j, k) -> (j, k)
+    // In this case, we append the unused dimension at the end.
+    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
+    SmallVector<AffineExpr> results;
+    AffineDimCollector usedInLvl(idxMap.getNumDims());
+    for (auto e : idxMap.getResults())
+      usedInLvl.walkPostOrder(e);
+
+    unsigned curUsedDimID = 0;
+    unsigned curUnusedDimID = lvl2Idx.getNumDims();
+
+    // After the flip, a set bit marks a dimension that no level expression
+    // uses.
+    BitVector unused = usedInLvl.dims.flip();
+    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
+      if (unused.test(i))
+        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
+      else
+        results.push_back(lvl2Idx.getResult(curUsedDimID++));
+    }
+    lvl2Idx =
+        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
   }
-  return AffineMap::get(lvlRank, 0, exps, builder.getContext());
+  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
+
+  // We do not need to replace the DimExprs that are not used in inadmissible
+  // level expressions. We use the first inAdLvls.count() dims to represent the
+  // replaced levels; the remaining ones are used for the unchanged dims.
+  unsigned curRepID = 0;
+  unsigned curOriID = inAdLvls.count();
+  // Since we changed the order of the AffineMap's dimensions, we need to
+  // update the dimensions here.
+  SmallVector<AffineExpr> results;
+  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
+  SmallVector<utils::IteratorType> transItTps;
+
+  for (unsigned l : inAdLvls.set_bits()) {
+    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
+    AffineExpr lvlExp = idxMap.getResult(l);
+    AffineDimCollector collector(idxMap.getNumDims());
+    collector.walkPostOrder(lvlExp);
+    // Each inadmissible level expression is expected to use exactly one idx.
+    assert(collector.dims.count() == 1);
+    // Inherit the iterator type from the used idx.
+    transItTps.push_back(itTps[collector.dims.find_first()]);
+  }
+
+  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
+    if (usedDims.test(d))
+      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
+    else {
+      // Dimensions untouched by inadmissible expressions keep their iterator
+      // type and are renumbered after the replaced levels.
+      results.push_back(getAffineDimExpr(curOriID++, ctx));
+      transItTps.push_back(itTps[d]);
+    }
+  }
+  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
+  // Update iterator type.
+  itTps.assign(transItTps.begin(), transItTps.end());
+  return AffineMap::get(numDim, 0, results, ctx);
+}
+
+// Translates the index maps in the linalg::GenericOp from idx->dim maps to
+// idx->lvl maps. Returns failure if an index map cannot be translated to an
+// admissible form.
+// The function also updates the GenericOp's index maps and iterator type array
+// *in-place*.
+static LogicalResult translateMap(linalg::GenericOp op,
+                                  PatternRewriter &rewriter) {
+  // idxMap is a idx2dim map before reinterpretation.
+  MLIRContext *ctx = op.getContext();
+  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
+  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
+    Value tensor = op->getOpOperand(i).get();
+    auto stt = tryGetSparseTensorType(tensor);
+    if (stt && !stt->isIdentity()) {
+      AffineMap dim2Lvl = stt->getDimToLvl();
+      // By composing the idx2dim(dim2lvl), we got a idx2lvl Map
+      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
+    }
+  }
+
+  // A naive way to handle common constant expressions that arise during dim2lvl
+  // translation.
+  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
+                                  unsigned pos, int64_t lvlSz) {
+    if (!ShapedType::isDynamic(lvlSz)) {
+      auto c0 = getAffineConstantExpr(0, ctx);
+      auto lvlExp = getAffineDimExpr(pos, ctx);
+      auto szExp = getAffineConstantExpr(lvlSz, ctx);
+
+      // lvl floordiv lvlSz = 0
+      auto divExp =
+          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
+      cstMapping.try_emplace(divExp, c0);
+
+      // lvl mod lvlSz = lvl
+      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
+      cstMapping.try_emplace(modExp, lvlExp);
+    }
+  };
+
+  unsigned boundedNum = 0;
+  // A fixed-point algorithm.
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    for (OpOperand &operand : op->getOpOperands()) {
+      auto stt = tryGetSparseTensorType(operand.get());
+      // Skip on dense operands.
+      if (!stt || !stt->getEncoding())
+        continue;
+
+      unsigned tid = operand.getOperandNumber();
+      bool isOutput = &operand == op.getDpsInitOperand(0);
+      AffineMap idxMap = idxMapArray[tid];
+      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
+      auto [inAdLvls, dimExprs] = inAdInfo;
+      for (unsigned d : dimExprs.set_bits()) {
+        // The first `boundedNum` used in the AffineMap is introduced to
+        // resolve previous inadmissible expressions. We can not replace them
+        // to bring back the inadmissible expressions.
+        if (d < boundedNum)
+          return failure();
+      }
+
+      if (inAdLvls.count() != 0) {
+        // Naive constant propagation, should be sufficient to handle block
+        // sparsity in our cases.
+        SmallVector<int64_t> lvlShape = stt->getLvlShape();
+        DenseMap<AffineExpr, AffineExpr> cstMapping;
+        unsigned position = 0;
+        for (unsigned lvl : inAdLvls.set_bits()) {
+          int64_t lvlSz = lvlShape[lvl];
+          populateCstMapping(cstMapping, position, lvlSz);
+          position++;
+        }
+
+        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
+        // Compose the lvl2Idx Map to all AffineIdxMap to eliminate
+        // inadmissible expressions.
+        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
+          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
+          idxMapArray[tid] = transMap.replace(
+              cstMapping, /*numResultDims=*/transMap.getNumDims(),
+              /*numResultSyms=*/0);
+        }
+        changed = true;
+        boundedNum += inAdLvls.count();
+      }
+    }
+  };
+
+  SmallVector<Attribute> iterAttr =
+      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
+        return linalg::IteratorTypeAttr::get(ctx, itTp);
+      });
+
+  rewriter.startRootUpdate(op);
----------------
aartbik wrote:

I actually like the pattern that you started

if (!translate)
   return failure()

update linalg op in place

but this code here breaks the flow. Do you see a chance to make the "translate" function purer, and move ALL the replacement on success to the rewriting rule below?

https://github.com/llvm/llvm-project/pull/71448


More information about the Mlir-commits mailing list