[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)

Peiming Liu llvmlistbot at llvm.org
Tue Nov 7 11:16:14 PST 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/71448

>From e3ec3c433ef08239acf8ec39a8d8e59cc0ab0a82 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 6 Nov 2023 22:03:00 +0000
Subject: [PATCH 1/3] [mlir][sparse] end-to-end matmul between Dense and BSR
 tensors

---
 .../Transforms/SparseReinterpretMap.cpp       | 444 +++++++++++++-----
 .../SparsificationAndBufferizationPass.cpp    |   6 +-
 .../SparseTensor/sparse_reinterpret_map.mlir  |   8 +-
 .../SparseTensor/CPU/sparse_block_matmul.mlir | 124 +++++
 4 files changed, 458 insertions(+), 124 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 307a609fd1b7746..964786e35f72321 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -16,27 +16,307 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+namespace {
+
+//===----------------------------------------------------------------------===//
+// File Local Helper classes.
+//===----------------------------------------------------------------------===//
+
+// CRTP helper for rewriters that first demap all their non-identity inputs.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Demaps non-trivial inputs.
+    SmallVector<Value> deMappedIns(op->getOperands());
+    for (Value &in : deMappedIns)
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+    bool changed = !llvm::equal(deMappedIns, op->getOperands());
+    // CRTP call; never report failure once demapping has modified the IR.
+    OpAdaptor adaptor(deMappedIns, op);
+    LogicalResult status =
+        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+    return changed ? success() : status;
+  }
+};
+
+// Collects the positions of all AffineDimExprs in the visited expressions.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {}
+  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
+  BitVector dims;
+};
+
+// Checks whether an affine expression is admissible for sparsification.
+struct AffineExprAdmissibleVisitor
+    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
+  explicit AffineExprAdmissibleVisitor(bool isOutput)
+      : admissible(true), isOutput(isOutput) {}
+
+  // On outputs, only plain AffineDimExprs (no add/mul) are allowed.
+  void visitAddExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+  void visitMulExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+
+  // For input, mod, floor div and ceil div are not supported.
+  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  operator bool() const { return admissible; }
+
+private:
+  bool admissible;
+  bool isOutput;
+};
+
+// The first BitVector stores levels where inadmissible exprs are used.
+// The second BitVector stores the AffineDimExprs that are used by the
+// inadmissible expressions.
+using InadmissInfo = std::pair<BitVector, BitVector>;
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // File Local Helper methods.
 //===----------------------------------------------------------------------===//

-// Translates a "simple" map according to an identity lvl-map.
-static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
-                              AffineMap map) {
-  unsigned lvlRank = stt.getLvlRank();
-  AffineMap lvl2dim = stt.getLvlToDim();
-  assert(lvl2dim.getNumInputs() == lvlRank);
-  SmallVector<AffineExpr> exps;
-  for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
-    unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
-    exps.push_back(lvl2dim.getResult(pos));
+static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
+  auto ret = std::make_pair(BitVector(map.getNumResults()),
+                            BitVector(map.getNumDims()));
+  AffineDimCollector collector(map.getNumDims());
+  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
+    AffineExprAdmissibleVisitor admissible(isOutput);
+    admissible.walkPostOrder(map.getResult(lvl));
+    if (!admissible) {
+      // Record the inadmissible level.
+      ret.first.set(lvl);
+      // Records the AffineDimExpr that is used in the inadmissible expr.
+      collector.walkPostOrder(map.getResult(lvl));
+    }
+  }
+  ret.second = collector.dims;
+  return ret;
+}
+
+// Build the AffineMap to replace the idx in idxMap to lvl such that all tht
+// inadmissible affine expressions can be eliminated.
+// For example, we can rewrite
+// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
+// to
+// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
+// by composing inverse(idxMap), that is
+// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
+//                         -> ((l0 * 2 + l2) floordiv 2,
+//                             (l1 * 3 + l3) floordiv 3,
+//                             (l0 * 2 + l2) mod 2,
+//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
+//
+// This function builds the inverse(idxMap) that replaces every dimension used
+// in `info` with levels, and updates the iterator type array `itTps` for the
+// newly introduced index variables.
+//
+// Note that the returned affine map does not retain the order of the input
+// affine map. Instead, it always uses the first `inAdLvls.count()` dims for
+// the replaced levels, and the remaining ones for the unchanged dimensions.
+// For example, to handle
+// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d2 mod 4)
+// which is a typical map for block_2to4. The function returns:
+// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
+// in which, (l0, l1) together replaces `d1`, yet they appears
+// before `d0` in the resulting affine map.
+// The the index (loop) order can later be canonicalized by a topo sort.
+static AffineMap
+genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
+                      SmallVector<utils::IteratorType> &itTps) {
+  MLIRContext *ctx = idxMap.getContext();
+  auto [inAdLvls, usedDims] = info;
+  // Note that idxMap is not equal to the dim2Lvl map; it is computed by
+  // composing idx2Dim(dim2Lvl), and the two are only equal when idx2Dim is
+  // an ID map.
+  // TODO: we might fail here, in those cases we should really return
+  // failure instead of an assertion error.
+  auto lvl2Idx = inferLvlToDim(idxMap, ctx);
+
+  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
+  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
+    // This could happen when some dimensions are projected.
+    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
+    //   ==> lvl2Idx = (j, k) -> (j, k)
+    // In this case, we append the unused dimensions at the end.
+    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
+    SmallVector<AffineExpr> results;
+    AffineDimCollector usedInLvl(idxMap.getNumDims());
+    for (auto e : idxMap.getResults())
+      usedInLvl.walkPostOrder(e);
+
+    unsigned curUsedDimID = 0;
+    unsigned curUnusedDimID = lvl2Idx.getNumDims();
+
+    BitVector unused = usedInLvl.dims.flip();
+    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
+      if (unused.test(i))
+        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
+      else
+        results.push_back(lvl2Idx.getResult(curUsedDimID++));
+    }
+    lvl2Idx =
+        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
   }
-  return AffineMap::get(lvlRank, 0, exps, builder.getContext());
+  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
+
+  // We do not need to replace the DimExpr that is not used in Inadmissible
+  // level expressions. We use the first inAdLvl.count() dim to represent the
+  // replaced level, the remainings are used for unchanged ones.
+  unsigned curRepID = 0;
+  unsigned curOriID = inAdLvls.count();
+  // Since we changed the ordered of the AffineMap's dimention, we need to
+  // update the dimension here.
+  SmallVector<AffineExpr> results;
+  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
+  SmallVector<utils::IteratorType> transItTps;
+
+  for (unsigned l : inAdLvls.set_bits()) {
+    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
+    AffineExpr lvlExp = idxMap.getResult(l);
+    AffineDimCollector collector(idxMap.getNumDims());
+    collector.walkPostOrder(lvlExp);
+    assert(collector.dims.count() == 1);
+    // Inherit the iterator type from the used idx.
+    transItTps.push_back(itTps[collector.dims.find_first()]);
+  }
+
+  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
+    if (usedDims.test(d))
+      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
+    else {
+      results.push_back(getAffineDimExpr(curOriID++, ctx));
+      transItTps.push_back(itTps[d]);
+    }
+  }
+  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
+  // Update iterator type.
+  itTps.assign(transItTps.begin(), transItTps.end());
+  return AffineMap::get(numDim, 0, results, ctx);
+}
+
+// Translates the index maps in the linalg::GenericOp from idx->dim maps to
+// idx->lvl maps. Returns failure if an index map cannot be translated to an
+// admissible form.
+// The function also updates the GenericOp's index maps and iterator type array
+// *in-place*.
+static LogicalResult translateMap(linalg::GenericOp op,
+                                  PatternRewriter &rewriter) {
+  // idxMap is an idx2dim map before reinterpretation.
+  MLIRContext *ctx = op.getContext();
+  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
+  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
+    Value tensor = op->getOpOperand(i).get();
+    auto stt = tryGetSparseTensorType(tensor);
+    if (stt && !stt->isIdentity()) {
+      AffineMap dim2Lvl = stt->getDimToLvl();
+      // Composing dim2lvl after idx2dim yields an idx2lvl map.
+      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
+    }
+  }
+
+  // A naive way to handle common constant expressions that arise during dim2lvl
+  // translation.
+  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
+                                  unsigned pos, int64_t lvlSz) {
+    if (!ShapedType::isDynamic(lvlSz)) {
+      auto c0 = getAffineConstantExpr(0, ctx);
+      auto lvlExp = getAffineDimExpr(pos, ctx);
+      auto szExp = getAffineConstantExpr(lvlSz, ctx);
+
+      // lvl floordiv lvlSz = 0
+      auto divExp =
+          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
+      cstMapping.try_emplace(divExp, c0);
+
+      // lvl mod lvlSz = lvl
+      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
+      cstMapping.try_emplace(modExp, lvlExp);
+    }
+  };
+
+  unsigned boundedNum = 0;
+  // A fixed-point algorithm.
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    for (OpOperand &operand : op->getOpOperands()) {
+      auto stt = tryGetSparseTensorType(operand.get());
+      // Skip on dense operands.
+      if (!stt || !stt->hasEncoding())
+        continue;
+
+      unsigned tid = operand.getOperandNumber();
+      bool isOutput = &operand == op.getDpsInitOperand(0);
+      AffineMap idxMap = idxMapArray[tid];
+      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
+      auto [inAdLvls, dimExprs] = inAdInfo;
+      for (unsigned d : dimExprs.set_bits()) {
+        // The first `boundedNum` dims used in the AffineMap were introduced to
+        // resolve previous inadmissible expressions; replacing them again
+        // would bring the inadmissible expressions back.
+        if (d < boundedNum)
+          return failure();
+      }
+
+      if (inAdLvls.count() != 0) {
+        // Naive constant propagation, should be sufficient to handle block
+        // sparsity in our cases.
+        SmallVector<int64_t> lvlShape = stt->getLvlShape();
+        DenseMap<AffineExpr, AffineExpr> cstMapping;
+        unsigned position = 0;
+        for (unsigned lvl : inAdLvls.set_bits()) {
+          int64_t lvlSz = lvlShape[lvl];
+          populateCstMapping(cstMapping, position, lvlSz);
+          position++;
+        }
+
+        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
+        // Compose the lvl2Idx Map to all AffineIdxMap to eliminate
+        // inadmissible expressions.
+        for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
+          AffineMap transMap = idxMapArray[i].compose(lvl2Idx);
+          idxMapArray[i] = transMap.replace(
+              cstMapping, /*numResultDims=*/transMap.getNumDims(),
+              /*numResultSyms=*/0);
+        }
+        changed = true;
+        boundedNum += inAdLvls.count();
+      }
+    }
+  }
+
+  SmallVector<Attribute> iterAttr =
+      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
+        return linalg::IteratorTypeAttr::get(ctx, itTp);
+      });
+
+  rewriter.startRootUpdate(op);
+  op.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMapArray));
+  op.setIteratorTypesAttr(rewriter.getArrayAttr(iterAttr));
+  rewriter.finalizeRootUpdate(op);
+
+  return success();
 }
 
 // Generates a "de"mapping reinterpretation of the map.
@@ -73,41 +353,6 @@ static bool hasNonIdentityOperandsOrResults(Operation *op) {
          llvm::any_of(op->getResults(), hasNonIdentityMap);
 }
 
-// Generates a clone of the given linalg generic operation, but with
-// remapped arguments, index maps, and iteration types.
-//
-// TODO: As decribed below, this is proof-of-concept code which makes a lot
-//       of simplifying assumptions for now.
-//
-static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
-                                          linalg::GenericOp linalgOp,
-                                          SparseTensorType stt, Value out) {
-  unsigned dimRank = stt.getDimRank();
-  unsigned lvlRank = stt.getLvlRank();
-  SmallVector<Value> inputOps = linalgOp.getInputs();
-  SmallVector<Value> outputOps = {out};
-  SmallVector<AffineMap> indexMaps;
-  SmallVector<utils::IteratorType> iterTypes;
-  // Translate the index maps, except output map, which is lvl-identity.
-  auto maps = linalgOp.getIndexingMapsArray();
-  for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
-    indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
-  indexMaps.push_back(
-      AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
-  // Add additional "parallel" iteration types at the top.
-  for (unsigned i = 0, diff = lvlRank = dimRank; i < diff; i++)
-    iterTypes.push_back(utils::IteratorType::parallel);
-  for (auto &i : linalgOp.getIteratorTypesArray())
-    iterTypes.push_back(i);
-  // Generate the new linalg generic operation and clone body.
-  auto newOp = rewriter.create<linalg::GenericOp>(
-      linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
-      iterTypes);
-  rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
-                             newOp.getRegion().begin());
-  return newOp;
-}
-
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -115,53 +360,39 @@ namespace {
 //===----------------------------------------------------------------------===//
 
 /// Sparse rewriting rule for the generic `linalg` operation.
-struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+struct GenericOpReinterpretMap
+    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
 public:
-  GenericOpReinterpretMap(MLIRContext *context)
-      : OpRewritePattern<linalg::GenericOp>(context) {}
+  using DemapInsRewriter::DemapInsRewriter;
+  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    // Only rewrite single output operations with pure (sparse) tensor
+    // semantics.
+    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+        !hasAnySparseOperandOrResult(linalgOp) ||
+        !hasNonIdentityOperandsOrResults(linalgOp))
+      return failure();
 
-  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
-                                PatternRewriter &rewriter) const override {
-    // Only rewrite single output operations with pure tensor semantics.
-    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
+    // Try translating the index map.
+    if (failed(translateMap(linalgOp, rewriter)))
       return failure();
-    // Scan all operands, inspect sparse tensors.
-    //
-    // TODO: generalize this proof-of-concept algorithm, since the current
-    //       implementation accepts only simple indexing maps, and one
-    //       non-permutation sparse tensor, which must have an identity
-    //       indexing map and be the output.
-    //
-    OpOperand *tx = nullptr;
-    for (OpOperand &t : linalgOp->getOpOperands()) {
-      // Ensure every index map is "simple".
-      const auto map = linalgOp.getMatchingIndexingMap(&t);
-      for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
-        if (map.getResult(i).getKind() != AffineExprKind::DimId)
-          return failure();
-      // Inspect sparse operands.
-      auto stt = tryGetSparseTensorType(t.get());
-      if (stt && stt->hasEncoding()) {
-        if (stt->isPermutation())
-          continue;
-        assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
-        if (tx)
-          return failure(); // more than one non-perm
-        if (!map.isIdentity())
-          return failure(); // no ID indexing map on the non-perm
-        tx = &t;
-      }
-    }
-    // Found a non-permutation, rewrite when this is the output.
-    if (tx && tx == linalgOp.getDpsInitOperand(0)) {
-      auto stt = getSparseTensorType(tx->get());
-      auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
-      auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
-      auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
-      rewriter.replaceOp(linalgOp, remap);
-      return success();
+
+    // Single result: the op has exactly one tensor DPS init (checked above).
+    Value res = linalgOp.getResult(0);
+    auto stt = tryGetSparseTensorType(res);
+
+    rewriter.startRootUpdate(linalgOp);
+    linalgOp.getInputsMutable().assign(adaptor.getInputs());
+    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
+    res.setType(adaptor.getOutputs()[0].getType());
+    rewriter.finalizeRootUpdate(linalgOp);
+    rewriter.setInsertionPointAfter(linalgOp);
+    // Remap the result to the original sparse encoding and redirect its uses.
+    if (stt && stt->hasEncoding()) {
+      Value t = genRemap(rewriter, stt->getEncoding(), res);
+      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
     }
-    return failure();
+    return success();
   }
 };
 
@@ -169,32 +400,10 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
 // Reinterpret Map Rewriters for operations other than linalg.generics
 //===----------------------------------------------------------------------===//
 
-// CRTP to help implementing a rewriter that demaps all its inputs.
-template <typename SubClass, typename SourceOp>
-struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
-  using OpRewritePattern<SourceOp>::OpRewritePattern;
-  using OpAdaptor = typename SourceOp::Adaptor;
-
-  LogicalResult matchAndRewrite(SourceOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    // Demaps non-trivial inputs.
-    SmallVector<Value> deMappedIns(op->getOperands());
-    for (Value &in : deMappedIns)
-      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
-        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
-
-    // CRTP call.
-    OpAdaptor adaptor(deMappedIns);
-    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
-                                                          rewriter);
-  }
-};
-
-struct TensorAllocDemapper
-    : public OpRewritePattern<bufferization::AllocTensorOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
+template <typename AllocOp>
+struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
+  using OpRewritePattern<AllocOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AllocOp op,
                                 PatternRewriter &rewriter) const override {
     if (!hasNonIdentityOperandsOrResults(op))
       return failure();
@@ -362,7 +571,8 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
-    patterns.add<TensorAllocDemapper, TensorInsertDemapper, ForeachOpDemapper>(
-        patterns.getContext());
+    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
+                 TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
+                 ForeachOpDemapper>(patterns.getContext());
   }
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 4a293f6819d0976..3d25fef09605348 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -139,13 +139,13 @@ class SparsificationAndBufferizationPass
     // of `bufferization.alloc_tensor` ops.
     {
       OpPassManager pm("builtin.module");
-      pm.addPass(
-          createSparseReinterpretMapPass(ReinterpretMapScope::kGenericOnly));
+      // We need to reinterpret maps on GenericOp, EmptyOp, and AllocTensorOp.
+      // empty BSR
+      pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
       pm.addPass(createSparsificationPass(sparsificationOptions));
       pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
       pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
                                                    /*enableConvert=*/true));
-      // Handle dim-to-lvl maps on operations other than linalg.generic.
       pm.addPass(
           createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
       pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 972364289ac2e2a..c4931c62c62633e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -19,15 +19,15 @@
     )
 }>
 
-// CHECK: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
-// CHECK: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
-// CHECK: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func @mul(
 // CHECK-SAME:  %[[A0:.*0]]: tensor<32x32xf32>,
 // CHECK-SAME:  %[[A1:.*1]]: tensor<32x32xf32>,
 // CHECK-SAME:  %[[A2:.*2]]: tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>)
 // CHECK:       %[[T0:.*]] = sparse_tensor.reinterpret_map %[[A2]]
-// CHECK:       %[[T1:.*]] = linalg.generic {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK:       %[[T1:.*]] = linalg.generic {doc = {{.*}} indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
 // CHECK:       %[[T2:.*]] = sparse_tensor.reinterpret_map %[[T1]]
 // CHECK:       return %[[T2]] : tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 func.func @mul(%arg0: tensor<32x32xf32>,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
new file mode 100644
index 000000000000000..e261e0ec80451fb
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -0,0 +1,124 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+//  do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
+// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-index-reduction=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-index-reduction=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
+// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
+
+#trait_mul = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j)>,  // A (in)
+    affine_map<(i,j,k) -> (j,k)>,  // B (in)
+    affine_map<(i,j,k) -> (i,k)>   // X (out)
+  ],
+  iterator_types = ["parallel", "reduction", "parallel"],
+  doc = "X(i,k) += A(i,j) * B(j,k)"
+}
+
+// BSR: 2x2 blocks; block rows are dense, block columns are compressed.
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 2 : compressed,
+    i mod 2      : dense,
+    j mod 2      : dense
+  )
+}>
+
+module {
+
+func.func @mul(%arg0: tensor<4x4xf64>,
+               %arg1: tensor<4x4xf64, #BSR>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<4x4xf64>, tensor<4x4xf64, #BSR>)
+    outs(%out: tensor<4x4xf64>) {
+      ^bb(%x: f64, %y : f64, %z : f64):
+        %1 = arith.mulf %x, %y : f64
+        %2 = arith.addf %1, %z : f64
+        linalg.yield %2 : f64
+  } -> tensor<4x4xf64>
+  return %0 : tensor<4x4xf64>
+}
+
+func.func @mul_dense(%arg0: tensor<4x4xf64>,
+                     %arg1: tensor<4x4xf64>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<4x4xf64>, tensor<4x4xf64>)
+    outs(%out: tensor<4x4xf64>) {
+      ^bb(%x: f64, %y : f64, %z : f64):
+        %1 = arith.mulf %x, %y : f64
+        %2 = arith.addf %1, %z : f64
+        linalg.yield %2 : f64
+  } -> tensor<4x4xf64>
+  return %0 : tensor<4x4xf64>
+}
+
+
+  //
+  // Output utilities.
+  //
+  func.func @dumpf64(%arg0: tensor<4x4xf64>) {
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = vector.transfer_read %arg0[%c0, %c0], %d0: tensor<4x4xf64>, vector<4x4xf64>
+    vector.print %0 : vector<4x4xf64>
+    return
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    // Multiplies the same matrix once with a dense and once with a BSR
+    // right-hand side; both products must print identically (hence the
+    // CHECK-COUNT-2 below).
+
+
+    %td = arith.constant dense<[[ 1.0, 2.0, 3.0, 4.0 ],
+                                [ 1.0, 2.0, 3.0, 4.0 ],
+                                [ 5.0, 6.0, 7.0, 8.0 ],
+                                [ 5.0, 6.0, 7.0, 8.0 ]]> : tensor<4x4xf64>
+
+
+    // constant -> BSR (either from SparseElementAttibutes or DenseElementAttribute)
+    %2 = sparse_tensor.convert %td : tensor<4x4xf64> to tensor<4x4xf64, #BSR>
+
+    %d = call @mul_dense(%td, %td)
+         : (tensor<4x4xf64>, tensor<4x4xf64>) -> tensor<4x4xf64>
+    %s = call @mul(%td, %2)
+         : (tensor<4x4xf64>, tensor<4x4xf64, #BSR>) -> tensor<4x4xf64>
+
+    // CHECK-COUNT-2: ( ( 38, 48, 58, 68 ), ( 38, 48, 58, 68 ), ( 86, 112, 138, 164 ), ( 86, 112, 138, 164 ) )
+    call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
+    call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
+    bufferization.dealloc_tensor %2 : tensor<4x4xf64, #BSR>
+    return
+  }
+}

>From 0ac69f0c573541beace74983de23c9d17b40cc02 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 6 Nov 2023 22:17:54 +0000
Subject: [PATCH 2/3] remove outdated comments

---
 .../Dialect/SparseTensor/CPU/sparse_block_matmul.mlir            | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index e261e0ec80451fb..16d13648ce44553 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -107,7 +107,6 @@ func.func @mul_dense(%arg0: tensor<4x4xf64>,
                                 [ 5.0, 6.0, 7.0, 8.0 ]]> : tensor<4x4xf64>
 
 
-    // constant -> BSR (either from SparseElementAttibutes or DenseElementAttribute)
     %2 = sparse_tensor.convert %td : tensor<4x4xf64> to tensor<4x4xf64, #BSR>
 
     %d = call @mul_dense(%td, %td)

>From bc55fbd9d66545789ff41fdb91eb468f21546824 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 7 Nov 2023 19:16:00 +0000
Subject: [PATCH 3/3] address comments.

---
 .../Transforms/SparseReinterpretMap.cpp       | 62 ++++++++++++-------
 .../SparsificationAndBufferizationPass.cpp    |  2 -
 .../SparseTensor/CPU/sparse_block_matmul.mlir | 10 +--
 3 files changed, 43 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 964786e35f72321..9aadcae36f562c6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -73,7 +73,7 @@ struct AffineExprAdmissibleVisitor
       admissible = false;
   }
 
-  // For input, mod, floor div and ceil div are not supported.
+  // We disallow mod, floor div, and ceil div on inputs.
   void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
   void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
   void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
@@ -105,7 +105,7 @@ static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
     if (!admissible) {
       // Record the inadmissible level.
       ret.first.set(lvl);
-      // Records the AffineDimExpr that is used in the inadmissible expr.
+      // Record the AffineDimExpr that is used in the inadmissible expr.
       collector.walkPostOrder(map.getResult(lvl));
     }
   }
@@ -113,7 +113,7 @@ static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
   return ret;
 }
 
-// Build the AffineMap to replace the idx in idxMap to lvl such that all tht
+// Builds the AffineMap to replace the idx in idxMap to lvl such that all the
 // inadmissible affine expressions can be eliminated.
 // For example, we can rewrite
 // idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
@@ -137,7 +137,7 @@ static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
 // idxMap = (d0, d1) -> (d0, d1 floordiv 4, d2 mod 4)
 // which is a typical map for block_2to4. The function returns:
 // inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
-// in which, (l0, l1) together replaces `d1`, yet they appears
+// in which (l0, l1) together replace `d1`, yet they appear
 // before `d0` in the resulting affine map.
 // The the index (loop) order can later be canonicalized by a topo sort.
 static AffineMap
@@ -182,28 +182,41 @@ genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
   // We do not need to replace the DimExpr that is not used in Inadmissible
   // level expressions. We use the first inAdLvl.count() dim to represent the
   // replaced level, the remainings are used for unchanged ones.
+  // Note that results from the inverse map computed previously do not follow
+  // the convention we use, so we fix the mismatch below.
   unsigned curRepID = 0;
   unsigned curOriID = inAdLvls.count();
-  // Since we changed the ordered of the AffineMap's dimention, we need to
-  // update the dimension here.
   SmallVector<AffineExpr> results;
   SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
   SmallVector<utils::IteratorType> transItTps;
 
   for (unsigned l : inAdLvls.set_bits()) {
+    // By our convention, the inadmissible level `l` always appears in the
+    // leading part (accumulated by curRepID) of the affine map. Record the
+    // mapping so that we can remap all uses of `l` to the correct
+    // position after the translation.
     dimRep[l] = getAffineDimExpr(curRepID++, ctx);
+    // A new index variable is introduced for the inadmissible level; it
+    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
+    // iterator type of l0 equals the iterator type of d0.
     AffineExpr lvlExp = idxMap.getResult(l);
     AffineDimCollector collector(idxMap.getNumDims());
     collector.walkPostOrder(lvlExp);
+    // We assume a level can only be derived from one dimension.
     assert(collector.dims.count() == 1);
-    // Inherit the iterator type from the used idx.
     transItTps.push_back(itTps[collector.dims.find_first()]);
   }
 
   for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
-    if (usedDims.test(d))
+    if (usedDims.test(d)) {
+      // The dimension is used in some of the inadmissible levels, and it needs
+      // to be inverted. Get the inversion from the inverse map, and fix the
+      // mismatch captured by the above loop.
       results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
-    else {
+    } else {
+      // The dimension is not used in any of the inadmissible levels, and it
+      // does not need to be inverted. Fix the mismatch by mapping it to the
+      // trailing part of the affine map (accumulated by curOriID).
       results.push_back(getAffineDimExpr(curOriID++, ctx));
       transItTps.push_back(itTps[d]);
     }
@@ -217,10 +230,9 @@ genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
 // Translates a the index map in the linalg::GenericOp from idx->dim map to
 // idx->lvl map. Returns failure if the index map can not be translated to an
 // admissible form.
-// The funciton also update the GenericOp's index map and iterator type array
-// *in-place*.
-static LogicalResult translateMap(linalg::GenericOp op,
-                                  PatternRewriter &rewriter) {
+// Returns the translated index map array and the iterator type array.
+static std::optional<std::pair<ArrayAttr, ArrayAttr>>
+translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
   // idxMap is a idx2dim map before reinterpretation.
   MLIRContext *ctx = op.getContext();
   SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
@@ -276,7 +288,7 @@ static LogicalResult translateMap(linalg::GenericOp op,
         // resolve previous inadmissible expressions. We can not replace them
         // to bring back the inadmissible expressions.
         if (d < boundedNum)
-          return failure();
+          return std::nullopt;
       }
 
       if (inAdLvls.count() != 0) {
@@ -311,12 +323,8 @@ static LogicalResult translateMap(linalg::GenericOp op,
         return linalg::IteratorTypeAttr::get(ctx, itTp);
       });
 
-  rewriter.startRootUpdate(op);
-  op.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMapArray));
-  op.setIteratorTypesAttr(rewriter.getArrayAttr(iterAttr));
-  rewriter.finalizeRootUpdate(op);
-
-  return success();
+  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
+                        rewriter.getArrayAttr(iterAttr));
 }
 
 // Generates a "de"mapping reinterpretation of the map.
@@ -374,20 +382,26 @@ struct GenericOpReinterpretMap
       return failure();
 
     // Try translating the index map.
-    if (failed(translateMap(linalgOp, rewriter)))
-      return failure();
+    auto transMap = translateMap(linalgOp, rewriter);
+    if (!transMap)
+      return rewriter.notifyMatchFailure(
+          linalgOp, "the sparse kernel can not be sparsified.");
 
-    // Must only have one result
+    // On success, update the linalg operands and maps in place.
     Value res = linalgOp.getResult(0);
     auto stt = tryGetSparseTensorType(res);
+    auto [idxMap, itTp] = *transMap;
 
     rewriter.startRootUpdate(linalgOp);
+    linalgOp.setIndexingMapsAttr(idxMap);
+    linalgOp.setIteratorTypesAttr(itTp);
+    // Use demapped arguments.
     linalgOp.getInputsMutable().assign(adaptor.getInputs());
     linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
     res.setType(adaptor.getOutputs()[0].getType());
     rewriter.finalizeRootUpdate(linalgOp);
-    rewriter.setInsertionPointAfter(linalgOp);
 
+    rewriter.setInsertionPointAfter(linalgOp);
     if (stt && stt->hasEncoding()) {
       Value t = genRemap(rewriter, stt->getEncoding(), res);
       rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 3d25fef09605348..e20b98add19adbf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -139,8 +139,6 @@ class SparsificationAndBufferizationPass
     // of `bufferization.alloc_tensor` ops.
     {
       OpPassManager pm("builtin.module");
-      // We need to reinterpret maps on GenericOp, EmptyOp, and AllocTensorOp.
-      // empty BSR
       pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
       pm.addPass(createSparsificationPass(sparsificationOptions));
       pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 16d13648ce44553..5eef4e58752fb40 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -101,10 +101,10 @@ func.func @mul_dense(%arg0: tensor<4x4xf64>,
     %c2 = arith.constant 2 : index
 
 
-    %td = arith.constant dense<[[ 1.0, 2.0, 3.0, 4.0 ],
-                                [ 1.0, 2.0, 3.0, 4.0 ],
-                                [ 5.0, 6.0, 7.0, 8.0 ],
-                                [ 5.0, 6.0, 7.0, 8.0 ]]> : tensor<4x4xf64>
+    %td = arith.constant dense<[[ 1.0,  2.0,  3.0,  4.0],
+                                [ 5.0,  6.0,  7.0,  8.0],
+                                [ 9.0, 10.0, 11.0, 12.0],
+                                [13.0, 14.0, 15.0, 16.0]]> : tensor<4x4xf64>
 
 
     %2 = sparse_tensor.convert %td : tensor<4x4xf64> to tensor<4x4xf64, #BSR>
@@ -114,7 +114,7 @@ func.func @mul_dense(%arg0: tensor<4x4xf64>,
     %s = call @mul(%td, %2)
          : (tensor<4x4xf64>, tensor<4x4xf64, #BSR>) -> tensor<4x4xf64>
 
-    // CHECK-COUNT-2: ( ( 38, 48, 58, 68 ), ( 38, 48, 58, 68 ), ( 86, 112, 138, 164 ), ( 86, 112, 138, 164 ) )
+    // CHECK-COUNT-2: ( ( 90, 100, 110, 120 ), ( 202, 228, 254, 280 ), ( 314, 356, 398, 440 ), ( 426, 484, 542, 600 ) )
     call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
     call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
 



More information about the Mlir-commits mailing list