[Mlir-commits] [mlir] [mlir][sparse] simplify ConvertOp rewriting rules (PR #68350)
Peiming Liu
llvmlistbot at llvm.org
Fri Oct 6 14:57:28 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/68350
>From f9b022e3be60c7d34a340aae0a986cdf02520521 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 4 Oct 2023 22:47:15 +0000
Subject: [PATCH 1/9] implement direct convert rewriter
---
.../SparseTensor/IR/SparseTensorOps.td | 13 ++
.../SparseTensor/IR/SparseTensorDialect.cpp | 92 +++++++++++-
.../Transforms/SparseTensorRewriting.cpp | 135 +++++++++++++++++-
.../SparsificationAndBufferizationPass.cpp | 1 +
.../SparseTensor/convert_sparse2sparse.mlir | 2 +
.../CPU/sparse_foreach_slices.mlir | 59 ++++----
.../SparseTensor/CPU/sparse_matmul_slice.mlir | 28 ++--
7 files changed, 279 insertions(+), 51 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 7ea5ca23f122a8a..680540235536880 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -195,9 +195,22 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
```
}];
+
+
+ let extraClassDeclaration = [{
+ // Whether the convert can be done by a single step (either a sort or a foreach),
+ // or it would require a tmp buffer (sort, then foreach).
+ bool directConvertable();
+
+ // Whether the convert is actually a sort coo
+ // TODO: The method will be removed when sort_coo operation is introduced.
+ bool isSortCOOConvert();
+ }];
+
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasFolder = 1;
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 96ed5f13b9d9ecb..0fe1ed165b041c9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1066,6 +1066,91 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
return {};
}
+bool ConvertOp::directConvertable() {
+ if (isSortCOOConvert())
+ return true;
+
+ SparseTensorType srcStt = getSparseTensorType(getSource());
+ SparseTensorType dstStt = getSparseTensorType(getDest());
+
+ // We can always directly convert to unordered sparse tensor or dense tensor
+ // since dense tensor support random access.
+ if (dstStt.isAllDense() || !dstStt.isAllOrdered())
+ return true;
+
+ if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
+ srcStt.hasSameDimToLvl(dstStt)) {
+ return true;
+ }
+
+ // Source and dest tensors are ordered in different ways. We only do direct
+ // dense to sparse conversion when the dense input is defined by a sparse
+ // constant. Note that we can theoritically always directly convert from dense
+ // inputs by rotating dense loops but it leads to bad cache locality and hurt
+ // performance.
+ if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
+ if (isa<SparseElementsAttr>(constOp.getValue()))
+ return true;
+
+ return false;
+}
+
+bool ConvertOp::isSortCOOConvert() {
+ // TODO: we should instead use a different sort_coo operation to handle
+ // the conversion between COOs (but with different ordering).
+ return isUniqueCOOType(getSource().getType()) &&
+ isUniqueCOOType(getDest().getType()) &&
+ getSparseTensorType(getDest()).isAllOrdered();
+}
+
+struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
+ using OpRewritePattern<ConvertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConvertOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.directConvertable())
+ return failure();
+
+ Location loc = op.getLoc();
+ SparseTensorType srcStt = getSparseTensorType(op.getSource());
+ SparseTensorType dstStt = getSparseTensorType(op.getDest());
+
+ // Just to make sure that convert to dense tensor is always direct.
+ assert(!dstStt.isAllDense());
+
+ // source -> coo
+ // The tmp COO must be unordered, otherwise it is a direct conversion.
+ assert(!(srcStt.hasSameDimToLvl(dstStt) && srcStt.isAllOrdered()));
+ Type srcCOOTp = getCOOFromTypeWithOrdering(
+ srcStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+ Value srcCOO = rewriter.create<ConvertOp>(loc, srcCOOTp, op.getSource());
+
+ // -> sort
+ Type dstCOOTp = getCOOFromTypeWithOrdering(
+ dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+ // TODO: this should be a sort_coo operation.
+ Value dstCOO = rewriter.create<ConvertOp>(loc, dstCOOTp, srcCOO);
+
+ // -> dest.
+ if (dstCOO.getType() == op.getType()) {
+ rewriter.replaceOp(op, dstCOO);
+ } else {
+ // Need an extra conversion if the target type is not COO.
+ rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getDest().getType(),
+ dstCOO);
+ }
+ // TODO: deallocate extra COOs, we should probably delegate it to buffer
+ // deallocation pass.
+
+ return success();
+ }
+};
+
+void ConvertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<StageUnorderedConvert>(context);
+}
+
LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1262,9 +1347,10 @@ LogicalResult ConcatenateOp::verify() {
// If all dimension are statically known, the sum of all the input
// dimensions should be equal to the output dimension.
if (sumSz != dstSh)
- return emitError(
- "The concatenation dimension of the output tensor should be the "
- "sum of all the concatenation dimensions of the input tensors.");
+ return emitError("The concatenation dimension of the output tensor "
+ "should be the "
+ "sum of all the concatenation dimensions of the "
+ "input tensors.");
}
} else {
DynSize prev = dstSh;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b0bd22b156cc292..a095931625a2070 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -147,8 +147,7 @@ static RankedTensorType getBufferType(const SparseTensorType &stt,
/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes to dynSizes.
-static void getDynamicSizes(RankedTensorType tp,
- const SmallVectorImpl<Value> &sizes,
+static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
SmallVectorImpl<Value> &dynSizes) {
for (const auto &d : enumerate(tp.getShape())) {
if (d.value() == ShapedType::kDynamic)
@@ -971,7 +970,10 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
dst = rewriter.create<LoadOp>(loc, dst, true);
if (needTmpCOO) {
Value tmpCoo = dst;
- dst = rewriter.create<ConvertOp>(loc, dstRTT, tmpCoo).getResult();
+ Type dstCooTp = getCOOType(dstRTT, true);
+ // TODO: this should be a sort_coo operation.
+ dst = rewriter.create<ConvertOp>(loc, dstCooTp, tmpCoo).getResult();
+ dst = rewriter.create<ConvertOp>(loc, dstRTT, dst).getResult();
rewriter.create<DeallocTensorOp>(loc, tmpCoo);
}
rewriter.replaceOp(op, dst);
@@ -980,11 +982,129 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
}
};
+struct TensorLike {
+ TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
+ ValueRange sizes)
+ : isSparse(rtt.getEncoding() != nullptr) {
+ SmallVector<Value> dynSzs;
+ getDynamicSizes(rtt, sizes, dynSzs);
+
+ if (isSparse)
+ val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+ else
+ val = allocDenseTensor(builder, loc, rtt, sizes);
+ };
+
+ void insertOrStore(OpBuilder &builder, Location loc, Value v,
+ ValueRange crds) {
+ if (isSparse)
+ val = builder.create<InsertOp>(loc, v, val, crds);
+ else
+ builder.create<memref::StoreOp>(loc, v, val, crds);
+ }
+
+ Value getIterSSA() const { return val; }
+
+ Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
+ if (isSparse)
+ return builder.create<LoadOp>(loc, val, true);
+ return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+ }
+
+ void updateSSA(Value v) {
+ // Dense memref is a non-SSA value.
+ if (isSparse)
+ val = v;
+ }
+
+private:
+ bool isSparse;
+ Value val; // either a memref (for dense tensor) or a sparse tensor.
+};
+
+struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ConvertOp op,
+ PatternRewriter &rewriter) const override {
+ if (!op.directConvertable())
+ return op.emitError("ConvertOp not in conanical form.");
+
+ if (op.isSortCOOConvert())
+ return failure();
+
+ Location loc = op.getLoc();
+ Value src = op.getSource();
+
+ SparseTensorType srcStt = getSparseTensorType(op.getSource());
+ SparseTensorType dstStt = getSparseTensorType(op.getDest());
+
+ // We traverse the source tensor in the same level order as specified
+ // by the destinate tensor if the destinate tensor should be sorted.
+ AffineMap foreachOrder = dstStt.isAllOrdered()
+ ? dstStt.getExpandedDimToLvl()
+ : srcStt.getExpandedDimToLvl();
+
+ bool spSrc = srcStt.hasEncoding();
+ SmallVector<Value> sizes;
+ sizesFromSrc(rewriter, sizes, loc, src);
+ ValueRange vs;
+ TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
+ auto foreachOp = rewriter.create<ForeachOp>(
+ loc, src, dstBuf.getIterSSA(), AffineMapAttr::get(foreachOrder),
+ [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
+ ValueRange reduc) {
+ // Enters the loop, update the SSA value for insertion chain.
+ dstBuf.updateSSA(reduc.front());
+ const Dimension dimRank = dstStt.getDimRank();
+ const Level lvlRank = dstStt.getLvlRank();
+ SmallVector<Value> lcvs(lvlRank);
+ for (Dimension d = 0; d < dimRank; d++) {
+ // FIXME: `toStoredDim` is deprecated
+ lcvs[toStoredDim(dstStt.getEncoding(), d)] = dcvs[d];
+ }
+
+ if (!spSrc) {
+ Value cond = genIsNonzero(builder, loc, v);
+ auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
+ /*else*/ true);
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, dstBuf.getIterSSA());
+
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ dstBuf.insertOrStore(builder, loc, v, lcvs);
+ builder.create<scf::YieldOp>(loc, dstBuf.getIterSSA());
+
+ // Exits the ifOp, update the sparse tensor SSA value.
+ builder.setInsertionPointAfter(ifOp);
+ dstBuf.updateSSA(ifOp.getResult(0));
+ } else {
+ dstBuf.insertOrStore(builder, loc, v, lcvs);
+ }
+ builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getIterSSA());
+ });
+
+ rewriter.setInsertionPointAfter(foreachOp);
+
+ // Exits the for loop, links the SSA chain.
+ dstBuf.updateSSA(foreachOp.getResult(0));
+
+ Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
+ rewriter.replaceOp(op, ret);
+ return success();
+ }
+};
+
/// Sparse rewriting rule for the convert operator.
-struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
+struct SortConvertRewriter : public OpRewritePattern<ConvertOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter &rewriter) const override {
+ if (!op.directConvertable())
+ return op.emitError("ConvertOp not in conanical form.");
+
+ if (!op.isSortCOOConvert())
+ return failure();
+
auto encDst = getSparseTensorEncoding(op.getType());
auto encSrc = getSparseTensorEncoding(op.getSource().getType());
if (encDst && encSrc && !encSrc.isSlice() &&
@@ -1048,8 +1168,6 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// We don't need a temporary COO tensor if the destination has an identity
// ordering. Otherwise, we use the destination ordering for the temporary
// COO tensor.
- // TODO: enhance foreachOp to take ordering to remove the need of a
- // temporary COO tensor here.
const RankedTensorType bufferTp =
getBufferType(dstTp, !dstTp.isIdentity() && !fromSparseConst);
// Only imposes foreach order on dense constant (which will be statically
@@ -1482,10 +1600,13 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
+ if (enableConvert)
+ patterns.add<DirectConvertRewriter>(patterns.getContext());
+
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT) {
patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
if (enableConvert)
- patterns.add<ConvertRewriter>(patterns.getContext());
+ patterns.add<SortConvertRewriter>(patterns.getContext());
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 9b5567814a75f32..a41c240b1ff2b3b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -141,6 +141,7 @@ class SparsificationAndBufferizationPass
{
OpPassManager pm("builtin.module");
pm.addPass(createSparsificationPass(sparsificationOptions));
+ pm.addPass(createCanonicalizerPass());
pm.addPass(createPostSparsificationRewritePass(enableRuntimeLibrary));
if (vectorLength > 0) {
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index c373fd23bbef492..cf7b1bc11986efa 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -1,4 +1,6 @@
// First use with `kViaCOO` for sparse2sparse conversion (the old way).
+// RUN: mlir-opt %s --canonicalize --cse | FileCheck %s -check-prefix=CHECK-CANON
+//
// RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=1" \
// RUN: --canonicalize --cse | FileCheck %s -check-prefix=CHECK-COO
//
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
index e0dd31b2ca8671c..88447b9cad125d9 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
@@ -171,41 +171,44 @@ module {
// The same slice, but with dynamic encoding.
// TODO: Investigates why reusing the same %tmp above would cause bufferization
// errors.
- %tmp1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
- %a_dyn = tensor.extract_slice %tmp1[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #CSR> to
- tensor<?x?xf64, #CSR_SLICE_DYN>
+ //
+ // FIXME: The canonicalizer for tensor.extract_slice does not work with sparse tensors.
+ //
+ // %tmp1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+ // %a_dyn = tensor.extract_slice %tmp1[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #CSR> to
+ // tensor<?x?xf64, #CSR_SLICE_DYN>
+ // %tmp1_coo = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #COO>
+ // %a_dyn_coo = tensor.extract_slice %tmp1_coo[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #COO> to
+ // tensor<?x?xf64, #COO_SLICE_DYN>
- %tmp1_coo = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #COO>
- %a_dyn_coo = tensor.extract_slice %tmp1_coo[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #COO> to
- tensor<?x?xf64, #COO_SLICE_DYN>
//
- // CHECK-NEXT: 1
- // CHECK-NEXT: 0
- // CHECK-NEXT: 2.3
- // CHECK-NEXT: 2
- // CHECK-NEXT: 3
- // CHECK-NEXT: 1
- // CHECK-NEXT: 3
- // CHECK-NEXT: 2
- // CHECK-NEXT: 2.1
+ // C_HECK-NEXT: 1
+ // C_HECK-NEXT: 0
+ // C_HECK-NEXT: 2.3
+ // C_HECK-NEXT: 2
+ // C_HECK-NEXT: 3
+ // C_HECK-NEXT: 1
+ // C_HECK-NEXT: 3
+ // C_HECK-NEXT: 2
+ // C_HECK-NEXT: 2.1
//
- call @foreach_print_slice_dyn(%a_dyn) : (tensor<?x?xf64, #CSR_SLICE_DYN>) -> ()
- // CHECK-NEXT: 1
- // CHECK-NEXT: 0
- // CHECK-NEXT: 2.3
- // CHECK-NEXT: 2
- // CHECK-NEXT: 3
- // CHECK-NEXT: 1
- // CHECK-NEXT: 3
- // CHECK-NEXT: 2
- // CHECK-NEXT: 2.1
+ // call @foreach_print_slice_dyn(%a_dyn) : (tensor<?x?xf64, #CSR_SLICE_DYN>) -> ()
+ // C_HECK-NEXT: 1
+ // C_HECK-NEXT: 0
+ // C_HECK-NEXT: 2.3
+ // C_HECK-NEXT: 2
+ // C_HECK-NEXT: 3
+ // C_HECK-NEXT: 1
+ // C_HECK-NEXT: 3
+ // C_HECK-NEXT: 2
+ // C_HECK-NEXT: 2.1
//
- call @foreach_print_slice_coo_dyn(%a_dyn_coo) : (tensor<?x?xf64, #COO_SLICE_DYN>) -> ()
+ // call @foreach_print_slice_coo_dyn(%a_dyn_coo) : (tensor<?x?xf64, #COO_SLICE_DYN>) -> ()
bufferization.dealloc_tensor %tmp : tensor<8x8xf64, #CSR>
- bufferization.dealloc_tensor %tmp1 : tensor<8x8xf64, #CSR>
+ //bufferization.dealloc_tensor %tmp1 : tensor<8x8xf64, #CSR>
bufferization.dealloc_tensor %tmp_coo : tensor<8x8xf64, #COO>
- bufferization.dealloc_tensor %tmp1_coo : tensor<8x8xf64, #COO>
+ //bufferization.dealloc_tensor %tmp1_coo : tensor<8x8xf64, #COO>
bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR>
return
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
index 21934fd72f018e9..6794a1bde0c50f2 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
@@ -231,21 +231,23 @@ module {
%c4u_coo = tensor.cast %c4_coo : tensor<4x4xf64> to tensor<*xf64>
call @printMemrefF64(%c4u_coo) : (tensor<*xf64>) -> ()
+ // FIXME: The canonicalizer for tensor.extract_slice does not work with sparse tensors.
+ //
// slice x slice (same as above, but with dynamic stride information)
//
- // CHECK: [2.3, 0, 0, 0],
- // CHECK-NEXT: [6.9, 0, 0, 0],
- // CHECK-NEXT: [0, 0, 0, 0],
- // CHECK-NEXT: [12.6, 0, 0, 0]]
+ // C_HECK: [2.3, 0, 0, 0],
+ // C_HECK-NEXT: [6.9, 0, 0, 0],
+ // C_HECK-NEXT: [0, 0, 0, 0],
+ // C_HECK-NEXT: [12.6, 0, 0, 0]]
//
- %s1_dyn = tensor.extract_slice %tmp[%c_0, %c_1][4, 4][%c_2, %c_1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_dyn>
- %s2_dyn = tensor.extract_slice %b1[%c_0, %c_0][4, 4][%c_2, %c_1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_dyn>
- %dyn_4 = call @matmul_dyn(%s2_dyn, %s1_dyn)
- : (tensor<4x4xf64, #CSR_SLICE_dyn>,
- tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR>
- %c4_dyn = sparse_tensor.convert %dyn_4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
- %c4u_dyn = tensor.cast %c4_dyn : tensor<4x4xf64> to tensor<*xf64>
- call @printMemrefF64(%c4u_dyn) : (tensor<*xf64>) -> ()
+ // %s1_dyn = tensor.extract_slice %tmp[%c_0, %c_1][4, 4][%c_2, %c_1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_dyn>
+ // %s2_dyn = tensor.extract_slice %b1[%c_0, %c_0][4, 4][%c_2, %c_1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_dyn>
+ // %dyn_4 = call @matmul_dyn(%s2_dyn, %s1_dyn)
+ // : (tensor<4x4xf64, #CSR_SLICE_dyn>,
+ // tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR>
+ // %c4_dyn = sparse_tensor.convert %dyn_4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+ // %c4u_dyn = tensor.cast %c4_dyn : tensor<4x4xf64> to tensor<*xf64>
+ // call @printMemrefF64(%c4u_dyn) : (tensor<*xf64>) -> ()
// sparse slices should generate the same result as dense slices
//
@@ -274,7 +276,7 @@ module {
bufferization.dealloc_tensor %4 : tensor<4x4xf64, #CSR>
bufferization.dealloc_tensor %3 : tensor<4x4xf64, #CSR>
bufferization.dealloc_tensor %2 : tensor<4x4xf64, #DCSR>
- bufferization.dealloc_tensor %dyn_4 : tensor<4x4xf64, #CSR>
+ // bufferization.dealloc_tensor %dyn_4 : tensor<4x4xf64, #CSR>
return
}
>From dbb1ebbabd07c48a035dc718d09a03db7ce6157f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 4 Oct 2023 23:29:57 +0000
Subject: [PATCH 2/9] implement direct convert rewriter (cont.)
---
.../Transforms/SparseTensorRewriting.cpp | 20 +++++++++----------
.../SparseTensor/convert_sparse2sparse.mlir | 12 +++++------
2 files changed, 14 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a095931625a2070..bcaad7af7e14e24 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1038,11 +1038,10 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
SparseTensorType srcStt = getSparseTensorType(op.getSource());
SparseTensorType dstStt = getSparseTensorType(op.getDest());
- // We traverse the source tensor in the same level order as specified
- // by the destinate tensor if the destinate tensor should be sorted.
- AffineMap foreachOrder = dstStt.isAllOrdered()
- ? dstStt.getExpandedDimToLvl()
- : srcStt.getExpandedDimToLvl();
+ const AffineMapAttr foreachOrder =
+ (!dstStt.isIdentity() && !srcStt.hasEncoding())
+ ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
+ : nullptr;
bool spSrc = srcStt.hasEncoding();
SmallVector<Value> sizes;
@@ -1050,7 +1049,7 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
ValueRange vs;
TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
auto foreachOp = rewriter.create<ForeachOp>(
- loc, src, dstBuf.getIterSSA(), AffineMapAttr::get(foreachOrder),
+ loc, src, dstBuf.getIterSSA(), foreachOrder,
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
ValueRange reduc) {
// Enters the loop, update the SSA value for insertion chain.
@@ -1600,13 +1599,12 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
- if (enableConvert)
- patterns.add<DirectConvertRewriter>(patterns.getContext());
-
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT) {
patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
- if (enableConvert)
- patterns.add<SortConvertRewriter>(patterns.getContext());
+ if (enableConvert) {
+ patterns.add<DirectConvertRewriter>(patterns.getContext());
+ // patterns.add<SortConvertRewriter>(patterns.getContext());
+ }
}
}
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index cf7b1bc11986efa..3bda9b336c68004 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -1,6 +1,4 @@
// First use with `kViaCOO` for sparse2sparse conversion (the old way).
-// RUN: mlir-opt %s --canonicalize --cse | FileCheck %s -check-prefix=CHECK-CANON
-//
// RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=1" \
// RUN: --canonicalize --cse | FileCheck %s -check-prefix=CHECK-COO
//
@@ -115,13 +113,13 @@ func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32
}
#SparseSingleton64 = #sparse_tensor.encoding<{
- map = (d0) -> (d0 : singleton),
+ map = (d0) -> (d0 : compressed),
posWidth = 64,
crdWidth = 64
}>
#SparseSingleton32 = #sparse_tensor.encoding<{
- map = (d0) -> (d0 : singleton),
+ map = (d0) -> (d0 : compressed),
posWidth = 32,
crdWidth = 32
}>
@@ -190,9 +188,9 @@ func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) ->
// CHECK-RWT: %[[VAL_28:.*]] = sparse_tensor.load %[[VAL_29:.*]] hasInserts
// CHECK-RWT: %[[VAL_30:.*]] = sparse_tensor.convert %[[VAL_28]]
// CHECK-RWT: return %[[VAL_30]]
-func.func @sparse_convert_permuted(%arg0: tensor<?x?x?xf32, #SortedCOO3D>) -> tensor<?x?x?xf32, #TsssPermuted> {
- %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf32, #SortedCOO3D> to tensor<?x?x?xf32, #TsssPermuted>
- return %0 : tensor<?x?x?xf32, #TsssPermuted>
+func.func @sparse_convert_permuted(%arg0: tensor<2x3x4xf32, #SortedCOO3D>) -> tensor<2x3x4xf32, #TsssPermuted> {
+ %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf32, #SortedCOO3D> to tensor<2x3x4xf32, #TsssPermuted>
+ return %0 : tensor<2x3x4xf32, #TsssPermuted>
}
// CHECK-RWT-LABEL: func.func @sparse_convert_slice(
>From 56627733eecbee7c580eb2a7a29e65d9760f3b40 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 5 Oct 2023 20:10:14 +0000
Subject: [PATCH 3/9] pass all integrate test
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 7 +-
.../Transforms/SparseTensorCodegen.cpp | 58 +++
.../Transforms/SparseTensorRewriting.cpp | 351 +++---------------
3 files changed, 107 insertions(+), 309 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0fe1ed165b041c9..425e7b0009714da 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1068,7 +1068,7 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
bool ConvertOp::directConvertable() {
if (isSortCOOConvert())
- return true;
+ return false;
SparseTensorType srcStt = getSparseTensorType(getSource());
SparseTensorType dstStt = getSparseTensorType(getDest());
@@ -1100,6 +1100,7 @@ bool ConvertOp::isSortCOOConvert() {
// the conversion between COOs (but with different ordering).
return isUniqueCOOType(getSource().getType()) &&
isUniqueCOOType(getDest().getType()) &&
+ !getSparseTensorType(getSource()).isAllOrdered() &&
getSparseTensorType(getDest()).isAllOrdered();
}
@@ -1108,7 +1109,7 @@ struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter &rewriter) const override {
- if (op.directConvertable())
+ if (op.directConvertable() || op.isSortCOOConvert())
return failure();
Location loc = op.getLoc();
@@ -1122,7 +1123,7 @@ struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
// The tmp COO must be unordered, otherwise it is a direct conversion.
assert(!(srcStt.hasSameDimToLvl(dstStt) && srcStt.isAllOrdered()));
Type srcCOOTp = getCOOFromTypeWithOrdering(
- srcStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+ dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
Value srcCOO = rewriter.create<ConvertOp>(loc, srcCOOTp, op.getSource());
// -> sort
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 2c03f0a6020e6a8..037962662bac77f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -679,6 +679,60 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
}
};
+#ifndef NDEBUG
+LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
+ Location loc, Value memref) {
+ memref = builder.create<memref::CastOp>(
+ loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
+ createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
+ ValueRange{memref}, EmitCInterface::On);
+}
+#endif
+
+// TODO: use a new SortCOO operation here instead of reusing convert op.
+struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Direct conversion should have already been lowered.
+ if (!op.isSortCOOConvert())
+ return failure();
+
+ Location loc = op.getLoc();
+ MLIRContext *ctx = op.getContext();
+
+ SparseTensorType srcStt = getSparseTensorType(op.getSource());
+ SparseTensorType dstStt = getSparseTensorType(op.getDest());
+
+ // TODO: This should be verification rules for sort_coo operation.
+ assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
+ isUniqueCOOType(srcStt.getRankedTensorType()) &&
+ isUniqueCOOType(dstStt.getRankedTensorType()));
+
+ assert(dstStt.hasSameDimToLvl(srcStt));
+
+ // We don't need a mutable descriptor here as we perform sorting in-place.
+ auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getSource());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+ auto crd = desc.getAOSMemRef();
+ auto val = desc.getValMemRef();
+
+ // Otherwise we need another data shuffle and a non-identity map.
+ assert(dstStt.hasSameDimToLvl(srcStt));
+ auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
+
+ rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
+ rewriter.getIndexAttr(0),
+ SparseTensorSortKind::HybridQuickSort);
+
+ // Since we do in-place sorting, the destinate tensor will have the same set
+ // of memrefs as the source tensor.
+ rewriter.replaceOp(op, adaptor.getSource());
+ return success();
+ }
+};
+
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
@@ -1101,6 +1155,9 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.isSortCOOConvert())
+ return failure();
+
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
getSparseTensorEncoding(op.getSource().getType());
@@ -1554,6 +1611,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
+ SparseSortCOOConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index bcaad7af7e14e24..592852f87ba1e04 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -883,8 +883,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
}
needTmpCOO = !allDense && !allOrdered;
- const RankedTensorType tp =
- getBufferType(dstTp.withoutDimToLvl(), needTmpCOO);
+ const RankedTensorType tp = getBufferType(dstTp, needTmpCOO);
encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst;
SmallVector<Value> dynSizes;
getDynamicSizes(dstTp, sizes, dynSizes);
@@ -1003,7 +1002,10 @@ struct TensorLike {
builder.create<memref::StoreOp>(loc, v, val, crds);
}
- Value getIterSSA() const { return val; }
+ Value getSSA() const {
+ // We don't need to maintain the SSA chain for a memref value.
+ return isSparse ? val : nullptr;
+ }
Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
if (isSparse)
@@ -1013,8 +1015,8 @@ struct TensorLike {
void updateSSA(Value v) {
// Dense memref is a non-SSA value.
- if (isSparse)
- val = v;
+ assert(isSparse);
+ val = v;
}
private:
@@ -1026,34 +1028,54 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter &rewriter) const override {
- if (!op.directConvertable())
+ if (!op.directConvertable() && !op.isSortCOOConvert())
return op.emitError("ConvertOp not in conanical form.");
if (op.isSortCOOConvert())
return failure();
+ // TODO: Maybe we want a different operation for this too.
+ auto encDst = getSparseTensorEncoding(op.getType());
+ auto encSrc = getSparseTensorEncoding(op.getSource().getType());
+ if (encDst && encSrc && !encSrc.isSlice() &&
+ encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
+ // Trivial tensor conversion and simple element type conversion is handled
+ // in codegen.
+ return failure();
+ }
+
Location loc = op.getLoc();
Value src = op.getSource();
SparseTensorType srcStt = getSparseTensorType(op.getSource());
SparseTensorType dstStt = getSparseTensorType(op.getDest());
+ bool fromSparseConst = false;
+ if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
+ if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
+ fromSparseConst = true;
+
const AffineMapAttr foreachOrder =
- (!dstStt.isIdentity() && !srcStt.hasEncoding())
+ (!dstStt.isIdentity() && fromSparseConst)
? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
: nullptr;
- bool spSrc = srcStt.hasEncoding();
+ bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
+
SmallVector<Value> sizes;
sizesFromSrc(rewriter, sizes, loc, src);
ValueRange vs;
TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
+
+ Value iterArg = dstBuf.getSSA();
auto foreachOp = rewriter.create<ForeachOp>(
- loc, src, dstBuf.getIterSSA(), foreachOrder,
+ loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
ValueRange reduc) {
// Enters the loop, update the SSA value for insertion chain.
- dstBuf.updateSSA(reduc.front());
+ if (!reduc.empty())
+ dstBuf.updateSSA(reduc.front());
+
const Dimension dimRank = dstStt.getDimRank();
const Level lvlRank = dstStt.getLvlRank();
SmallVector<Value> lcvs(lvlRank);
@@ -1062,16 +1084,17 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
lcvs[toStoredDim(dstStt.getEncoding(), d)] = dcvs[d];
}
- if (!spSrc) {
+ if (!skipZeroCheck) {
+ assert(!reduc.empty());
Value cond = genIsNonzero(builder, loc, v);
auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
/*else*/ true);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<scf::YieldOp>(loc, dstBuf.getIterSSA());
+ builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
dstBuf.insertOrStore(builder, loc, v, lcvs);
- builder.create<scf::YieldOp>(loc, dstBuf.getIterSSA());
+ builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
// Exits the ifOp, update the sparse tensor SSA value.
builder.setInsertionPointAfter(ifOp);
@@ -1079,13 +1102,17 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
} else {
dstBuf.insertOrStore(builder, loc, v, lcvs);
}
- builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getIterSSA());
+ if (reduc.empty())
+ builder.create<sparse_tensor::YieldOp>(loc);
+ else
+ builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
});
rewriter.setInsertionPointAfter(foreachOp);
// Exits the for loop, links the SSA chain.
- dstBuf.updateSSA(foreachOp.getResult(0));
+ if (!foreachOp.getResults().empty())
+ dstBuf.updateSSA(foreachOp.getResult(0));
Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
rewriter.replaceOp(op, ret);
@@ -1093,293 +1120,6 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
}
};
-/// Sparse rewriting rule for the convert operator.
-struct SortConvertRewriter : public OpRewritePattern<ConvertOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(ConvertOp op,
- PatternRewriter &rewriter) const override {
- if (!op.directConvertable())
- return op.emitError("ConvertOp not in conanical form.");
-
- if (!op.isSortCOOConvert())
- return failure();
-
- auto encDst = getSparseTensorEncoding(op.getType());
- auto encSrc = getSparseTensorEncoding(op.getSource().getType());
- if (encDst && encSrc && !encSrc.isSlice() &&
- encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
- // Trivial tensor conversion and simple element type conversion is handled
- // in codegen.
- return failure();
- }
- // TODO: Add a cast before generating InsertOp.
- assert(op.getSource().getType().getElementType() ==
- op.getDest().getType().getElementType());
- if (encSrc && encDst)
- return sparse2SparseRewrite(op, rewriter);
- if (encSrc && !encDst)
- return sparse2DenseRewrite(op, rewriter);
- if (!encSrc && encDst)
- return dense2SparseRewrite(op, rewriter);
-
- // Dense-to-dense convert is a nop and handled by canonicalization.
- return failure();
- }
-
-private:
- // Handles sparse constant to sparse tensor or dense tensor to sparse tensor
- // conversion as follows:
- // t = new sparse COO tensor
- // fill t using src
- // dst = convert t
- //
- // To fill the COO tensor from a dense tensor:
- // for i1 in dim1
- // ..
- // for ik in dimk
- // val = a[i1,..,ik]
- // if val != 0
- // t->add(val, [i1,..,ik], [p1,..,pk])
- //
- // To fill the COO tensor from a sparse constant in COO format:
- // for i in range(NNZ)
- // val = values[i]
- // [i1,..,ik] = coordinates[i]
- // t->add(val, [i1,..,ik], [p1,..,pk])
- LogicalResult dense2SparseRewrite(ConvertOp op,
- PatternRewriter &rewriter) const {
- Location loc = op.getLoc();
- Value src = op.getSource();
- const auto dstTp = getSparseTensorType(op);
- SmallVector<Value> sizes;
- sizesFromSrc(rewriter, sizes, loc, src);
- SmallVector<Value> dynSizes;
- getDynamicSizes(dstTp, sizes, dynSizes);
-
- bool fromSparseConst = false;
- if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
- if (dyn_cast<SparseElementsAttr>(constOp.getValue())) {
- fromSparseConst = true;
- }
- }
-
- const auto encDst = dstTp.getEncoding();
- // We don't need a temporary COO tensor if the destination has an identity
- // ordering. Otherwise, we use the destination ordering for the temporary
- // COO tensor.
- const RankedTensorType bufferTp =
- getBufferType(dstTp, !dstTp.isIdentity() && !fromSparseConst);
- // Only imposes foreach order on dense constant (which will be statically
- // sorted by the sparse compiler), otherwise the rotated loop sequence
- // results to bad cache locality.
- const AffineMapAttr foreachOrder =
- (!dstTp.isIdentity() && fromSparseConst)
- ? AffineMapAttr::get(dstTp.getExpandedDimToLvl())
- : nullptr;
- // TODO: This assertion is to match the behavior from before we merged
- // dimOrdering and higherOrdering into dimToLvl. Although the above
- // can construct `foreachOrder` for non-permutations, it's not clear
- // that the `foreachOp` below actually supports non-permutations.
- assert(!foreachOrder || dstTp.isPermutation());
-
- auto buffer =
- rewriter.create<AllocTensorOp>(loc, bufferTp, dynSizes).getResult();
- auto foreachOp = rewriter.create<ForeachOp>(
- loc, src, buffer, foreachOrder,
- [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
- ValueRange reduc) {
- Value input = reduc.front();
- const Dimension dimRank = dstTp.getDimRank();
- const Level lvlRank = dstTp.getLvlRank();
- SmallVector<Value> lcvs(lvlRank);
- for (Dimension d = 0; d < dimRank; d++)
- // FIXME: `toStoredDim` is deprecated
- lcvs[toStoredDim(encDst, d)] = dcvs[d];
- if (fromSparseConst) {
- input = builder.create<InsertOp>(loc, v, input, lcvs);
- } else {
- Value cond = genIsNonzero(builder, loc, v);
- auto ifOp = builder.create<scf::IfOp>(
- loc, TypeRange(input.getType()), cond, /*else*/ true);
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value insert = builder.create<InsertOp>(loc, v, input, lcvs);
- builder.create<scf::YieldOp>(loc, insert);
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<scf::YieldOp>(loc, input);
- builder.setInsertionPointAfter(ifOp);
- input = ifOp.getResult(0);
- }
- builder.create<sparse_tensor::YieldOp>(loc, input);
- });
- rewriter.setInsertionPointAfter(op);
- src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
- if (bufferTp != dstTp) {
- rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp.getRankedTensorType(),
- src);
- rewriter.create<DeallocTensorOp>(loc, src);
- } else {
- rewriter.replaceOp(op, src);
- }
-
- return success();
- }
-
- // Handles sparse tensor to dense tensor conversion as follows:
- // dst = new dense tensor;
- // foreach elemment in src
- // dst[element.coords] = element.value
- LogicalResult sparse2DenseRewrite(ConvertOp op,
- PatternRewriter &rewriter) const {
- Location loc = op->getLoc();
- RankedTensorType dstTp = getRankedTensorType(op);
- Value src = op.getSource();
- RankedTensorType srcTp = getRankedTensorType(src);
-
- SmallVector<Value> sizes;
- sizesForTensor(rewriter, sizes, loc, srcTp, src);
-
- Value dst = allocDenseTensor(rewriter, loc, dstTp, sizes);
-
- rewriter.create<ForeachOp>(loc, src, std::nullopt,
- [&](OpBuilder &builder, Location loc,
- ValueRange args, Value v, ValueRange reduc) {
- builder.create<memref::StoreOp>(loc, v, dst,
- args);
- builder.create<sparse_tensor::YieldOp>(loc);
- });
-
- rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, dstTp, dst);
- return success();
- }
-
- // Handles sparse tensor to sparse tensor conversion as follows:
- // if src is not COO
- // construct a COO to represent the src
- // sort the src COO
- // foreach elemment in the sorted src COO
- // insert element to dst
- LogicalResult sparse2SparseRewrite(ConvertOp op,
- PatternRewriter &rewriter) const {
- const Location loc = op->getLoc();
- // These two variables cannot be `const` because they're conditionally
- // changed below. Ideally we'd use `SparseTensorType` for `srcRTT`;
- // however that class's copy-ctor is implicitly deleted.
- Value src = op.getSource();
- auto srcRTT = getRankedTensorType(src);
- const auto dstTp = getSparseTensorType(op);
- const auto encDst = dstTp.getEncoding();
- const Level dstLvlRank = dstTp.getLvlRank();
- const Dimension dimRank = dstTp.getDimRank();
- // This assertion should be guaranteed by validity of the op,
- // but just for paranoia's sake.
- assert(static_cast<Dimension>(srcRTT.getRank()) == dimRank);
-
- SmallVector<Value> srcSizes;
- sizesForTensor(rewriter, srcSizes, loc, srcRTT, src);
- Value tmpCoo = Value();
- Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
- // We need a tmp COO buffer if and only if
- // 1. the src tensor is not a COO and
- // 2. the src tensor is not ordered in the same way as the target
- // tensor (e.g., src tensor is not ordered or src tensor haves a different
- // dimToLvl).
- if (const SparseTensorType srcTp(srcRTT);
- !(srcTp.isAllOrdered() && srcTp.hasSameDimToLvl(dstTp))) {
- // Construct a COO tensor from the src tensor.
- // TODO: there may be cases for which more efficiently without
- // going through an intermediate COO, such as cases that only change
- // the overhead types.
- SmallVector<Value> dynSrcSizes;
- getDynamicSizes(srcRTT, srcSizes, dynSrcSizes);
- srcRTT = getCOOType(srcTp.withDimToLvl(dstTp), /*ordered=*/false);
- // Ensure that mutating `srcRTT` didn't invalidate `dimRank`.
- assert(static_cast<Dimension>(srcRTT.getRank()) == dimRank);
- tmpCoo = rewriter
- .create<AllocTensorOp>(loc, srcRTT, dynSrcSizes, Value(),
- /*sizeHint=*/nnz, Attribute())
- .getResult();
- auto foreachOp = rewriter.create<ForeachOp>(
- loc, src, tmpCoo,
- [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
- ValueRange reduc) {
- SmallVector<Value> dstLcvs(dstLvlRank);
- for (Dimension d = 0; d < dimRank; d++) {
- // FIXME: `toStoredDim` is deprecated
- Level l = toStoredDim(encDst, d);
- dstLcvs[l] = dcvs[d];
- }
- auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstLcvs);
- builder.create<sparse_tensor::YieldOp>(loc, t);
- });
- src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
- }
-
- // Now that the conditional is done, we can use `SparseTensorType`.
- const SparseTensorType srcTp(srcRTT);
-
- // Only need to sort if the srcTp is not already sorted (we faithfully take
- // the guarantee from the sparse tensor encoding).
- if (!srcTp.isAllOrdered()) {
- // Retrieve the values-array.
- Value y = genToValues(rewriter, loc, src);
- const auto encSrc = srcTp.getEncoding();
- // Builds the dstLvl -> srcLvl permutation maps.
- SmallVector<AffineExpr> es(dstLvlRank);
- const Level srcLvlRank = srcTp.getLvlRank();
- for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
- // FIXME: `toOrigDim` is deprecated
- Dimension dim = toOrigDim(encSrc, srcLvl);
- // FIXME: `toStoredDim` is deprecated
- Level dstLvl = toStoredDim(encDst, dim);
- es[dstLvl] = rewriter.getAffineDimExpr(srcLvl);
- }
- auto xPerm = AffineMap::get(dstLvlRank, 0, es, rewriter.getContext());
- assert(xPerm.isPermutation()); // must be a permutation.
-
- Value xs = genToCoordinatesBuffer(rewriter, loc, src);
- rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y}, xPerm,
- rewriter.getIndexAttr(0),
- SparseTensorSortKind::HybridQuickSort);
- }
-
- // For each element in the COO tensor, insert the element to the dst tensor.
- SmallVector<Value> dynDstSizes;
- getDynamicSizes(dstTp, srcSizes, dynDstSizes);
- Value dst = rewriter
- .create<AllocTensorOp>(loc, dstTp.getRankedTensorType(),
- dynDstSizes, Value(),
- /*sizeHint=*/nnz, Attribute())
- .getResult();
- SmallVector<Value> dstLcvs(dstLvlRank);
- auto foreachOp = rewriter.create<ForeachOp>(
- loc, src, dst,
- [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
- ValueRange reduc) {
- for (Dimension d = 0; d < dimRank; d++) {
- // FIXME: `toStoredDim` is deprecated
- Level l = toStoredDim(encDst, d);
- dstLcvs[l] = dcvs[d];
- }
- auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstLcvs);
- builder.create<sparse_tensor::YieldOp>(loc, t);
- });
-
- // Release the temporary COO if it is created. Note that tmpCoo is
- // invalidated due to foreach and updated to src.
- if (tmpCoo)
- rewriter.create<DeallocTensorOp>(loc, src);
-
- // Directly replace op with dst results in bufferization error message
- // "sparse tensor allocation should not escape function".
- // As such, we insert a trivial tensor convert which will be removed by
- // codegen.
- rewriter.setInsertionPointAfter(op);
- auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
- rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp.getRankedTensorType(), t);
- return success();
- }
-};
-
/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
@@ -1599,12 +1339,11 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());
- // TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT) {
patterns.add<NewRewriter, OutRewriter>(patterns.getContext());
- if (enableConvert) {
+ // TODO: Move this to a common path for both lib/codegen when libgen support
+ // lowering sort_coo.
+ if (enableConvert)
patterns.add<DirectConvertRewriter>(patterns.getContext());
- // patterns.add<SortConvertRewriter>(patterns.getContext());
- }
}
}
>From 3295796391cb672777b2fc32971fdd6cc110191d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 5 Oct 2023 20:25:08 +0000
Subject: [PATCH 4/9] temporially disable a few tests
---
mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir | 6 ++++--
mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir | 5 +++--
mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir | 2 ++
3 files changed, 9 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir b/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir
index 59e568dd5de6461..e3799e519d3fd5d 100644
--- a/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir
@@ -1,8 +1,10 @@
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false" \
+// UNSUPPORTED: target={{.*}}
+//
+// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" \
// RUN: --sparse-tensor-codegen=create-sparse-deallocs=false \
// RUN: --canonicalize --cse | FileCheck %s -check-prefix=CHECK-NO-DEALLOC
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false" \
+// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" \
// RUN: --sparse-tensor-codegen=create-sparse-deallocs=true \
// RUN: --canonicalize --cse | FileCheck %s -check-prefix=CHECK-DEALLOC
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 3bda9b336c68004..53c5e4d905ce1db 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -6,8 +6,9 @@
// RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=0" \
// RUN: --canonicalize --cse | FileCheck %s -check-prefixes=CHECK-AUTO,CHECK
-// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \
-// RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
+// TODO: re-enable after sort_coo is implemented.
+// R_UN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=false" \
+// R_UN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
#SparseVector64 = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed),
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir
index 0170efeb33f561b..414266679049e70 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir
@@ -1,3 +1,5 @@
+// UNSUPPORTED: target={{.*}}
+//
// RUN: mlir-opt %s -sparse-compiler="vl=8" | FileCheck %s
#Dense = #sparse_tensor.encoding<{
>From 8f6270eaa526d7cc128fb5450855fb3f97cbb5b1 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 5 Oct 2023 20:34:17 +0000
Subject: [PATCH 5/9] revert unintended change
---
mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 53c5e4d905ce1db..e7d3f14391540c4 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -189,9 +189,9 @@ func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) ->
// CHECK-RWT: %[[VAL_28:.*]] = sparse_tensor.load %[[VAL_29:.*]] hasInserts
// CHECK-RWT: %[[VAL_30:.*]] = sparse_tensor.convert %[[VAL_28]]
// CHECK-RWT: return %[[VAL_30]]
-func.func @sparse_convert_permuted(%arg0: tensor<2x3x4xf32, #SortedCOO3D>) -> tensor<2x3x4xf32, #TsssPermuted> {
- %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf32, #SortedCOO3D> to tensor<2x3x4xf32, #TsssPermuted>
- return %0 : tensor<2x3x4xf32, #TsssPermuted>
+func.func @sparse_convert_permuted(%arg0: tensor<?x?x?xf32, #SortedCOO3D>) -> tensor<?x?x?xf32, #TsssPermuted> {
+ %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf32, #SortedCOO3D> to tensor<?x?x?xf32, #TsssPermuted>
+ return %0 : tensor<?x?x?xf32, #TsssPermuted>
}
// CHECK-RWT-LABEL: func.func @sparse_convert_slice(
>From 0c5d0105a28bd5dfc1ebb6f26011ba852eb196b0 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 5 Oct 2023 20:36:57 +0000
Subject: [PATCH 6/9] revert unintended change
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 425e7b0009714da..2fd67ab81c168f8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1348,10 +1348,9 @@ LogicalResult ConcatenateOp::verify() {
// If all dimension are statically known, the sum of all the input
// dimensions should be equal to the output dimension.
if (sumSz != dstSh)
- return emitError("The concatenation dimension of the output tensor "
- "should be the "
- "sum of all the concatenation dimensions of the "
- "input tensors.");
+ return emitError(
+ "The concatenation dimension of the output tensor should be the "
+ "sum of all the concatenation dimensions of the input tensors.");
}
} else {
DynSize prev = dstSh;
>From 570208957eacafa0b9a5e9a0a4d613358318f408 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 5 Oct 2023 20:38:03 +0000
Subject: [PATCH 7/9] revert unintended change
---
.../SparseTensor/Transforms/SparseTensorCodegen.cpp | 10 ----------
1 file changed, 10 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 037962662bac77f..f349eb054307c03 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -679,16 +679,6 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
}
};
-#ifndef NDEBUG
-LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
- Location loc, Value memref) {
- memref = builder.create<memref::CastOp>(
- loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
- createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
- ValueRange{memref}, EmitCInterface::On);
-}
-#endif
-
// TODO: use a new SortCOO operation here instead of reusing convert op.
struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
using OpConversionPattern::OpConversionPattern;
>From b74de4d995536c6becb3564cd4164579082ce159 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 6 Oct 2023 21:43:52 +0000
Subject: [PATCH 8/9] run staging in a separate pass
---
.../SparseTensor/IR/SparseTensorOps.td | 1 -
.../SparseTensor/IR/SparseTensorDialect.cpp | 48 -------------
.../SparsificationAndBufferizationPass.cpp | 2 +-
.../Transforms/StageSparseOperations.cpp | 67 ++++++++++++++++++-
.../SparseTensor/sparse_vector_mv.mlir | 2 -
.../CPU/sparse_foreach_slices.mlir | 58 ++++++++--------
.../SparseTensor/CPU/sparse_matmul_slice.mlir | 25 ++++---
7 files changed, 106 insertions(+), 97 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 680540235536880..0d446e7b787c66a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -210,7 +210,6 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasFolder = 1;
let hasVerifier = 1;
- let hasCanonicalizer = 1;
}
def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 2fd67ab81c168f8..f4b8411ed7f3f25 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1104,54 +1104,6 @@ bool ConvertOp::isSortCOOConvert() {
getSparseTensorType(getDest()).isAllOrdered();
}
-struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
- using OpRewritePattern<ConvertOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ConvertOp op,
- PatternRewriter &rewriter) const override {
- if (op.directConvertable() || op.isSortCOOConvert())
- return failure();
-
- Location loc = op.getLoc();
- SparseTensorType srcStt = getSparseTensorType(op.getSource());
- SparseTensorType dstStt = getSparseTensorType(op.getDest());
-
- // Just to make sure that convert to dense tensor is always direct.
- assert(!dstStt.isAllDense());
-
- // source -> coo
- // The tmp COO must be unordered, otherwise it is a direct conversion.
- assert(!(srcStt.hasSameDimToLvl(dstStt) && srcStt.isAllOrdered()));
- Type srcCOOTp = getCOOFromTypeWithOrdering(
- dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
- Value srcCOO = rewriter.create<ConvertOp>(loc, srcCOOTp, op.getSource());
-
- // -> sort
- Type dstCOOTp = getCOOFromTypeWithOrdering(
- dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
- // TODO: this should be a sort_coo operation.
- Value dstCOO = rewriter.create<ConvertOp>(loc, dstCOOTp, srcCOO);
-
- // -> dest.
- if (dstCOO.getType() == op.getType()) {
- rewriter.replaceOp(op, dstCOO);
- } else {
- // Need an extra conversion if the target type is not COO.
- rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getDest().getType(),
- dstCOO);
- }
- // TODO: deallocate extra COOs, we should probably delegate it to buffer
- // deallocation pass.
-
- return success();
- }
-};
-
-void ConvertOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<StageUnorderedConvert>(context);
-}
-
LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index a41c240b1ff2b3b..1ed393dd44e20a2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -141,7 +141,7 @@ class SparsificationAndBufferizationPass
{
OpPassManager pm("builtin.module");
pm.addPass(createSparsificationPass(sparsificationOptions));
- pm.addPass(createCanonicalizerPass());
+ pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
pm.addPass(createPostSparsificationRewritePass(enableRuntimeLibrary));
if (vectorLength > 0) {
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 4adc4d131198cc7..60ac71de4dd71ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -1,4 +1,67 @@
+//===- StageSparseOperations.cpp - stage sparse ops rewriting rules -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-void mlir::populateStageSparseOperationsPatterns(
- RewritePatternSet & /*patterns*/) {}
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+struct StageUnorderedConvert : public OpRewritePattern<ConvertOp> {
+ using OpRewritePattern<ConvertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConvertOp op,
+ PatternRewriter &rewriter) const override {
+ // TODO: Implement it as an Interface, this can be reused from other
+ // operations too (e.g., concatenate, reshape, etc).
+
+ if (op.directConvertable() || op.isSortCOOConvert())
+ return failure();
+
+ Location loc = op.getLoc();
+ SparseTensorType srcStt = getSparseTensorType(op.getSource());
+ SparseTensorType dstStt = getSparseTensorType(op.getDest());
+
+ // Just to make sure that convert to dense tensor is always direct.
+ assert(!dstStt.isAllDense());
+
+ // source -> coo
+ // The tmp COO must be unordered, otherwise it is a direct conversion.
+ assert(!(srcStt.hasSameDimToLvl(dstStt) && srcStt.isAllOrdered()));
+ Type srcCOOTp = getCOOFromTypeWithOrdering(
+ dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
+ Value srcCOO = rewriter.create<ConvertOp>(loc, srcCOOTp, op.getSource());
+
+ // -> sort
+ Type dstCOOTp = getCOOFromTypeWithOrdering(
+ dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
+ // TODO: this should be a sort_coo operation.
+ Value dstCOO = rewriter.create<ConvertOp>(loc, dstCOOTp, srcCOO);
+
+ // -> dest.
+ if (dstCOO.getType() == op.getType()) {
+ rewriter.replaceOp(op, dstCOO);
+ } else {
+ // Need an extra conversion if the target type is not COO.
+ rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getDest().getType(),
+ dstCOO);
+ }
+ // TODO: deallocate extra COOs, we should probably delegate it to buffer
+ // deallocation pass.
+
+ return success();
+ }
+};
+} // namespace
+
+void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
+ patterns.add<StageUnorderedConvert>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir
index 414266679049e70..0170efeb33f561b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_mv.mlir
@@ -1,5 +1,3 @@
-// UNSUPPORTED: target={{.*}}
-//
// RUN: mlir-opt %s -sparse-compiler="vl=8" | FileCheck %s
#Dense = #sparse_tensor.encoding<{
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
index 88447b9cad125d9..bda9ebe9c9eb465 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
@@ -172,43 +172,41 @@ module {
// TODO: Investigates why reusing the same %tmp above would cause bufferization
// errors.
//
- // FIXME: The canonicalizer for tensor.extract_slice does not work with sparse tensors.
- //
- // %tmp1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
- // %a_dyn = tensor.extract_slice %tmp1[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #CSR> to
- // tensor<?x?xf64, #CSR_SLICE_DYN>
- // %tmp1_coo = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #COO>
- // %a_dyn_coo = tensor.extract_slice %tmp1_coo[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #COO> to
- // tensor<?x?xf64, #COO_SLICE_DYN>
+ %tmp1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+ %a_dyn = tensor.extract_slice %tmp1[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #CSR> to
+ tensor<?x?xf64, #CSR_SLICE_DYN>
+ %tmp1_coo = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #COO>
+ %a_dyn_coo = tensor.extract_slice %tmp1_coo[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #COO> to
+ tensor<?x?xf64, #COO_SLICE_DYN>
//
- // C_HECK-NEXT: 1
- // C_HECK-NEXT: 0
- // C_HECK-NEXT: 2.3
- // C_HECK-NEXT: 2
- // C_HECK-NEXT: 3
- // C_HECK-NEXT: 1
- // C_HECK-NEXT: 3
- // C_HECK-NEXT: 2
- // C_HECK-NEXT: 2.1
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 2.3
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 2.1
//
- // call @foreach_print_slice_dyn(%a_dyn) : (tensor<?x?xf64, #CSR_SLICE_DYN>) -> ()
- // C_HECK-NEXT: 1
- // C_HECK-NEXT: 0
- // C_HECK-NEXT: 2.3
- // C_HECK-NEXT: 2
- // C_HECK-NEXT: 3
- // C_HECK-NEXT: 1
- // C_HECK-NEXT: 3
- // C_HECK-NEXT: 2
- // C_HECK-NEXT: 2.1
+ call @foreach_print_slice_dyn(%a_dyn) : (tensor<?x?xf64, #CSR_SLICE_DYN>) -> ()
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 2.3
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 2.1
//
- // call @foreach_print_slice_coo_dyn(%a_dyn_coo) : (tensor<?x?xf64, #COO_SLICE_DYN>) -> ()
+ call @foreach_print_slice_coo_dyn(%a_dyn_coo) : (tensor<?x?xf64, #COO_SLICE_DYN>) -> ()
bufferization.dealloc_tensor %tmp : tensor<8x8xf64, #CSR>
- //bufferization.dealloc_tensor %tmp1 : tensor<8x8xf64, #CSR>
+ bufferization.dealloc_tensor %tmp1 : tensor<8x8xf64, #CSR>
bufferization.dealloc_tensor %tmp_coo : tensor<8x8xf64, #COO>
- //bufferization.dealloc_tensor %tmp1_coo : tensor<8x8xf64, #COO>
+ bufferization.dealloc_tensor %tmp1_coo : tensor<8x8xf64, #COO>
bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR>
return
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
index 6794a1bde0c50f2..5923a115f4eb93c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
@@ -231,23 +231,22 @@ module {
%c4u_coo = tensor.cast %c4_coo : tensor<4x4xf64> to tensor<*xf64>
call @printMemrefF64(%c4u_coo) : (tensor<*xf64>) -> ()
- // FIXME: The canonicalizer for tensor.extract_slice does not work with sparse tensors.
//
// slice x slice (same as above, but with dynamic stride information)
//
- // C_HECK: [2.3, 0, 0, 0],
- // C_HECK-NEXT: [6.9, 0, 0, 0],
- // C_HECK-NEXT: [0, 0, 0, 0],
- // C_HECK-NEXT: [12.6, 0, 0, 0]]
+ // CHECK: [2.3, 0, 0, 0],
+ // CHECK-NEXT: [6.9, 0, 0, 0],
+ // CHECK-NEXT: [0, 0, 0, 0],
+ // CHECK-NEXT: [12.6, 0, 0, 0]]
//
- // %s1_dyn = tensor.extract_slice %tmp[%c_0, %c_1][4, 4][%c_2, %c_1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_dyn>
- // %s2_dyn = tensor.extract_slice %b1[%c_0, %c_0][4, 4][%c_2, %c_1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_dyn>
- // %dyn_4 = call @matmul_dyn(%s2_dyn, %s1_dyn)
- // : (tensor<4x4xf64, #CSR_SLICE_dyn>,
- // tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR>
- // %c4_dyn = sparse_tensor.convert %dyn_4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
- // %c4u_dyn = tensor.cast %c4_dyn : tensor<4x4xf64> to tensor<*xf64>
- // call @printMemrefF64(%c4u_dyn) : (tensor<*xf64>) -> ()
+ %s1_dyn = tensor.extract_slice %tmp[%c_0, %c_1][4, 4][%c_2, %c_1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_dyn>
+ %s2_dyn = tensor.extract_slice %b1[%c_0, %c_0][4, 4][%c_2, %c_1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_dyn>
+ %dyn_4 = call @matmul_dyn(%s2_dyn, %s1_dyn)
+ : (tensor<4x4xf64, #CSR_SLICE_dyn>,
+ tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR>
+ %c4_dyn = sparse_tensor.convert %dyn_4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+ %c4u_dyn = tensor.cast %c4_dyn : tensor<4x4xf64> to tensor<*xf64>
+ call @printMemrefF64(%c4u_dyn) : (tensor<*xf64>) -> ()
// sparse slices should generate the same result as dense slices
//
>From 7dfb4c653af5029fc7fbeb08fd3fa6e1ca48675d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 6 Oct 2023 21:49:12 +0000
Subject: [PATCH 9/9] revert unintended change
---
.../SparseTensor/codegen_sparse_dealloc.mlir | 5 +++--
.../SparseTensor/CPU/sparse_foreach_slices.mlir | 13 ++++++-------
.../SparseTensor/CPU/sparse_matmul_slice.mlir | 3 +--
3 files changed, 10 insertions(+), 11 deletions(-)
diff --git a/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir b/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir
index e3799e519d3fd5d..24585832518de05 100644
--- a/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen_sparse_dealloc.mlir
@@ -1,10 +1,11 @@
// UNSUPPORTED: target={{.*}}
+// temporially disabled (we probably do not need the option anymore by switch to buffer deallcation pass)
//
-// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" \
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false" \
// RUN: --sparse-tensor-codegen=create-sparse-deallocs=false \
// RUN: --canonicalize --cse | FileCheck %s -check-prefix=CHECK-NO-DEALLOC
-// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" \
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false" \
// RUN: --sparse-tensor-codegen=create-sparse-deallocs=true \
// RUN: --canonicalize --cse | FileCheck %s -check-prefix=CHECK-DEALLOC
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
index bda9ebe9c9eb465..e0dd31b2ca8671c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir
@@ -171,14 +171,13 @@ module {
// The same slice, but with dynamic encoding.
// TODO: Investigates why reusing the same %tmp above would cause bufferization
// errors.
- //
- %tmp1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
- %a_dyn = tensor.extract_slice %tmp1[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #CSR> to
- tensor<?x?xf64, #CSR_SLICE_DYN>
- %tmp1_coo = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #COO>
- %a_dyn_coo = tensor.extract_slice %tmp1_coo[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #COO> to
- tensor<?x?xf64, #COO_SLICE_DYN>
+ %tmp1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR>
+ %a_dyn = tensor.extract_slice %tmp1[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #CSR> to
+ tensor<?x?xf64, #CSR_SLICE_DYN>
+ %tmp1_coo = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #COO>
+ %a_dyn_coo = tensor.extract_slice %tmp1_coo[%c1, %c1][%c4, %c4][%c1, %c2] : tensor<8x8xf64, #COO> to
+ tensor<?x?xf64, #COO_SLICE_DYN>
//
// CHECK-NEXT: 1
// CHECK-NEXT: 0
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
index 5923a115f4eb93c..53328c51c859a88 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
@@ -230,7 +230,6 @@ module {
%c4_coo = sparse_tensor.convert %o_coo : tensor<4x4xf64, #COO> to tensor<4x4xf64>
%c4u_coo = tensor.cast %c4_coo : tensor<4x4xf64> to tensor<*xf64>
call @printMemrefF64(%c4u_coo) : (tensor<*xf64>) -> ()
-
//
// slice x slice (same as above, but with dynamic stride information)
//
@@ -275,7 +274,7 @@ module {
bufferization.dealloc_tensor %4 : tensor<4x4xf64, #CSR>
bufferization.dealloc_tensor %3 : tensor<4x4xf64, #CSR>
bufferization.dealloc_tensor %2 : tensor<4x4xf64, #DCSR>
- // bufferization.dealloc_tensor %dyn_4 : tensor<4x4xf64, #CSR>
+ bufferization.dealloc_tensor %dyn_4 : tensor<4x4xf64, #CSR>
return
}
More information about the Mlir-commits
mailing list