[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)
Yinying Li
llvmlistbot at llvm.org
Tue Nov 7 09:02:04 PST 2023
================
@@ -16,27 +16,307 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// File Local Helper classes.
+//===----------------------------------------------------------------------===//
+
+// CRTP base for rewrite patterns that demap all of their inputs before
+// rewriting. Every operand whose sparse tensor type is not an identity
+// mapping (`!isIdentity()`) is wrapped in a ReinterpretMapOp yielding
+// `getDemappedType()`. The derived class `SubClass` must provide
+// `rewriteOp(SourceOp, OpAdaptor, PatternRewriter &)`, which receives an
+// adaptor over the demapped operands together with the original op.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
+ using OpRewritePattern<SourceOp>::OpRewritePattern;
+ using OpAdaptor = typename SourceOp::Adaptor;
+
+ LogicalResult matchAndRewrite(SourceOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ // Demaps non-trivial inputs, remembering whether any ReinterpretMapOp was
+ // inserted (i.e. whether the IR has already been modified).
+ bool changed = false;
+ SmallVector<Value> deMappedIns(op->getOperands());
+ for (Value &in : deMappedIns) {
+ if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
+ in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+ changed = true;
+ }
+ }
+
+ // CRTP call: dispatch to the derived class with the demapped operands.
+ OpAdaptor adaptor(deMappedIns, op);
+ LogicalResult status =
+ static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+ // A pattern must not report failure after it has modified the IR, so
+ // report success whenever ReinterpretMapOps were inserted above.
+ return changed ? success() : status;
+ }
+};
+
+// Collects the positions of all AffineDimExprs that appear in the walked
+// affine expression(s) into the `dims` bit vector (one bit per dimension).
+// Bits accumulate across successive walks with the same collector.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+ explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+ void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
+ BitVector dims;
+};
+
+// Checks whether an affine expression is admissible. The visitor starts out
+// admissible and is cleared (permanently) by any inadmissible sub-expression:
+// mod, floordiv and ceildiv are never admissible, and for outputs
+// (`isOutput == true`) add and mul are rejected as well.
+struct AffineExprAdmissibleVisitor
+ : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
+ explicit AffineExprAdmissibleVisitor(bool isOutput)
+ : admissible(true), isOutput(isOutput){};
+
+ // On outputs, add and mul are rejected (together with mod/floordiv/ceildiv
+ // below), so only leaf expressions such as AffineDimExpr remain admissible.
+ void visitAddExpr(AffineBinaryOpExpr expr) {
+ if (isOutput)
+ admissible = false;
+ }
+ void visitMulExpr(AffineBinaryOpExpr expr) {
+ if (isOutput)
+ admissible = false;
+ }
+
+ // mod, floordiv and ceildiv are rejected for both inputs and outputs.
+ void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
+ void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+ void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+ // Result of the walk(s) performed so far; usable in boolean contexts.
+ explicit operator bool() const { return admissible; }
+
+private:
+ bool admissible;
+ bool isOutput;
+};
+
+// Inadmissibility summary of an affine map:
+// - first: one bit per map result (level), set for levels whose expression
+// is inadmissible;
+// - second: one bit per map dimension, set for every AffineDimExpr used by
+// those inadmissible expressions.
+using InadmissInfo = std::pair<BitVector, BitVector>;
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//
-// Translates a "simple" map according to an identity lvl-map.
-static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
- AffineMap map) {
- unsigned lvlRank = stt.getLvlRank();
- AffineMap lvl2dim = stt.getLvlToDim();
- assert(lvl2dim.getNumInputs() == lvlRank);
- SmallVector<AffineExpr> exps;
- for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
- unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
- exps.push_back(lvl2dim.getResult(pos));
+// Scans every result (level) expression of `map` and records the inadmissible
+// ones, using AffineExprAdmissibleVisitor (`isOutput` selects the stricter
+// rules for outputs). Returns a pair whose first BitVector, sized by the number
+// of results, marks the inadmissible levels, and whose second BitVector, sized
+// by the number of dims, marks every dimension used by those expressions.
+static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
+ auto ret = std::make_pair(BitVector(map.getNumResults()),
+ BitVector(map.getNumDims()));
+ // A single collector is shared across levels so the used dims accumulate.
+ AffineDimCollector collector(map.getNumDims());
+ for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
+ // A fresh visitor per level: admissibility is judged per level expression.
+ AffineExprAdmissibleVisitor admissible(isOutput);
+ admissible.walkPostOrder(map.getResult(lvl));
+ if (!admissible) {
+ // Record the inadmissible level.
+ ret.first.set(lvl);
+ // Record the AffineDimExprs that are used in the inadmissible expr.
+ collector.walkPostOrder(map.getResult(lvl));
+ }
+ }
+ ret.second = collector.dims;
+ return ret;
+}
+
+// Builds the AffineMap that replaces the indices in idxMap with levels such
+// that all the inadmissible affine expressions can be eliminated.
+// For example, we can rewrite
+// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
+// to
+// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
+// by composing inverse(idxMap), that is
+// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
+// -> ((l0 * 2 + l2) floordiv 2,
+// (l1 * 3 + l3) floordiv 3,
+// (l0 * 2 + l2) mod 2,
+// (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
+//
+// This function builds the inverse(idxMap) that replaces every dimension used
+// in `info` with levels, and updates the iterator type array `itTps` for the
+// new index variables introduced.
+//
+// Note that the returned affine map does not retain the order of the input
+// affine map. Instead, it always uses the first `inAdLvls.count()` inputs for
+// the replaced levels, and the remaining ones for unused dimensions.
+// For example, to handle
+// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
+// which is a typical map for block_2to4. The function returns:
+// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
+// in which (l0, l1) together replace `d1`, yet they appear
+// before `d0` in the resulting affine map.
+// The index (loop) order can later be canonicalized by a topological sort.
+static AffineMap
+genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
+ SmallVector<utils::IteratorType> &itTps) {
+ MLIRContext *ctx = idxMap.getContext();
+ auto [inAdLvls, usedDims] = info;
+ // Note that idxMap is not equal dim2Lvl map, it is computed by
+ // composing idx2Dim(dim2Lvl), they are only equal when idx2Dim is an
----------------
yinying-lisa-li wrote:
. They
https://github.com/llvm/llvm-project/pull/71448
More information about the Mlir-commits
mailing list