[Mlir-commits] [mlir] [mlir][sparse] end-to-end matmul between Dense and BSR tensors (PR #71448)
Yinying Li
llvmlistbot at llvm.org
Tue Nov 7 09:02:01 PST 2023
================
@@ -16,27 +16,307 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// File Local Helper classes.
+//===----------------------------------------------------------------------===//
+
+// CRTP base: demaps all non-identity sparse operands, then calls SubClass::rewriteOp.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
+ using OpRewritePattern<SourceOp>::OpRewritePattern;
+ using OpAdaptor = typename SourceOp::Adaptor;
+
+ LogicalResult matchAndRewrite(SourceOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ // Wrap each operand whose sparse tensor type is non-identity in a ReinterpretMapOp.
+ SmallVector<Value> deMappedIns(op->getOperands());
+ for (Value &in : deMappedIns)
+ if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+ in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+
+ // CRTP call: forward the op and an adaptor over the demapped operands to SubClass.
+ OpAdaptor adaptor(deMappedIns, op);
+ return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+ rewriter);
+ }
+};
+
+// Visitor that records the position of each visited AffineDimExpr as a set bit in `dims`.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+ explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+ void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
+ BitVector dims;
+};
+
+// Visitor that tracks whether an affine expression is admissible (add/mul are rejected on outputs).
+struct AffineExprAdmissibleVisitor
+ : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
+ explicit AffineExprAdmissibleVisitor(bool isOutput)
+ : admissible(true), isOutput(isOutput){};
+
+ // We only allow AffineDimExpr on output.
+ void visitAddExpr(AffineBinaryOpExpr expr) {
+ if (isOutput)
+ admissible = false;
+ }
+ void visitMulExpr(AffineBinaryOpExpr expr) {
+ if (isOutput)
+ admissible = false;
+ }
+
+ // For input, mod, floor div and ceil div are not supported.
----------------
yinying-lisa-li wrote:
nit: Mod, floor div and ceil div are not supported for input.
It's a bit hard to tell where the "for" phrase ends in the original sentence.
https://github.com/llvm/llvm-project/pull/71448
More information about the Mlir-commits
mailing list