[Mlir-commits] [mlir] 330d48c - [mlir][sparse] Add rewrite rules for sparse-to-sparse reshape operators.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 6 08:50:39 PDT 2022
Author: bixia1
Date: 2022-10-06T08:50:30-07:00
New Revision: 330d48c4aaf0ccb0b28b9941cea609beb64bc27c
URL: https://github.com/llvm/llvm-project/commit/330d48c4aaf0ccb0b28b9941cea609beb64bc27c
DIFF: https://github.com/llvm/llvm-project/commit/330d48c4aaf0ccb0b28b9941cea609beb64bc27c.diff
LOG: [mlir][sparse] Add rewrite rules for sparse-to-sparse reshape operators.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D135077
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index fbdcc83902e8..9dfcaec93b98 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -111,6 +111,20 @@ static bool isZeroYield(GenericOp op) {
return isZeroValue(yieldOp.getOperand(0));
}
+/// Appends one size Value per dimension of `stp` to `sizes`: a constant index
+/// for static dimensions, a `tensor.dim` on `tensor` for dynamic ones.
+static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
+                           Location loc, ShapedType stp, Value tensor) {
+  for (const auto &d : enumerate(stp.getShape())) {
+    Value dim;
+    if (d.value() == ShapedType::kDynamicSize)
+      dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
+    else
+      dim = constantIndex(builder, loc, d.value());
+    sizes.push_back(dim);
+  }
+}
+
// TODO: The dim level property of the COO type relies on input tensors, the
// shape relies on the output tensor
// Helpers to setup a COO type.
@@ -119,8 +133,11 @@ static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
auto rank = src.getRank();
SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dims;
- // An unordered and non-unique compressed dim at beginning.
- dims.push_back(SparseTensorEncodingAttr::DimLevelType::CompressedNuNo);
+ // An unordered and non-unique compressed dim at beginning unless the tensor
+ // is a 1D tensor.
+ if (rank > 1)
+ dims.push_back(SparseTensorEncodingAttr::DimLevelType::CompressedNuNo);
+
// TODO: it is actually ordered at the level for ordered input.
// Followed by unordered non-unique n-2 singleton levels.
std::fill_n(std::back_inserter(dims), rank - 2,
@@ -281,7 +298,72 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
}
};
-/// Sparse rewriting rule for reshape operator.
+/// Sparse rewriting rule for reshape ops whose source and result are sparse.
+template <typename ReshapeOp>
+struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
+  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value srcTensor = op.getSrc();
+    auto srcTp = srcTensor.getType().template cast<RankedTensorType>();
+    auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+    SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
+    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
+    if (!encDst || !encSrc)
+      return failure();
+
+    // Generate code to represent the static dimension constants or compute
+    // the dynamic dimension values.
+    SmallVector<Value, 4> srcSizes;
+    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    SmallVector<Value, 4> dstSizes;
+    // Only the sizes of dynamic ('?') result dims; operands of alloc_tensor.
+    SmallVector<Value, 4> dstDynSizes;
+    if (dstTp.hasStaticShape()) {
+      for (auto d : dstTp.getShape())
+        dstSizes.push_back(constantIndex(rewriter, loc, d));
+    } else {
+      ArrayRef<int64_t> dstShape = dstTp.getShape();
+      genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
+                         op.getReassociationIndices());
+      for (const auto &d : llvm::enumerate(dstShape)) {
+        if (d.value() == ShapedType::kDynamicSize)
+          dstDynSizes.push_back(dstSizes[d.index()]);
+      }
+    }
+
+    // Implement the sparse2sparse reshape as follows:
+    //   %tmp = bufferization.alloc_tensor : unordered COO
+    //   foreach srcCoords %srcTensor
+    //     insert translateIndicesArray(srcCoords), %tmp
+    //   %t = sparse_tensor.convert %tmp
+    RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
+    auto cooBuffer =
+        rewriter.create<AllocTensorOp>(loc, cooTp, dstDynSizes).getResult();
+    rewriter.create<ForeachOp>(
+        loc, srcTensor, [&](OpBuilder &builder, Location loc, ValueRange args) {
+          // args: per-dimension coordinates first, element value last.
+          SmallVector<Value, 4> srcIndices;
+          SmallVector<Value, 4> dstIndices;
+          for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
+            uint64_t dim = toStoredDim(encSrc, i);
+            srcIndices.push_back(args[dim]);
+          }
+          translateIndicesArray(builder, loc, op.getReassociationIndices(),
+                                srcIndices, srcSizes, dstSizes, dstIndices);
+          builder.create<InsertOp>(loc, args.back(), cooBuffer, dstIndices);
+          builder.create<sparse_tensor::YieldOp>(loc);
+        });
+
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
+    return success();
+  }
+};
+
+/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
+/// operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
@@ -437,7 +519,6 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
-
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
bool enableRT) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
@@ -446,5 +527,8 @@ void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
patterns.getContext());
// TODO: If RT not enabled, rewrite concatenate ops, etc here.
if (!enableRT)
- patterns.add<ConcatenateRewriter>(patterns.getContext());
+ patterns.add<ConcatenateRewriter,
+ Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
+ Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>>(
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index 5848cbb44a01..420d732ce62a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
// RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV
+// RUN: mlir-opt %s --sparse-tensor-rewrite=enable-runtime-library=false --cse --canonicalize | FileCheck %s --check-prefix=CHECK-RWT
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
@@ -37,6 +38,29 @@
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
//
+// rewrite for codegen:
+//
+// CHECK-RWT-LABEL: func.func @sparse_expand(
+// CHECK-RWT-SAME: %[[S:.*]]:
+// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
+// CHECK-RWT: %[[DI0:.*]] = arith.divui %[[SI]], %[[C10]] : index
+// CHECK-RWT: %[[DI1:.*]] = arith.remui %[[SI]], %[[C10]] : index
+// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI0]], %[[DI1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+//
func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] :
tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
@@ -76,6 +100,37 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
//
+// rewrite for codegen:
+//
+// CHECK-RWT-LABEL: func.func @sparse_collapse(
+// CHECK-RWT-SAME: %[[S:.*]]:
+// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
+// CHECK-RWT: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] {
+// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
+// CHECK-RWT: %[[T:.*]] = arith.muli %[[SI0]], %[[C10]] : index
+// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
+// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI]]]
+// CHECK-RWT:      }
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<100xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+//
func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] :
tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector>
@@ -120,6 +175,35 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
//
+// rewrite for codegen:
+//
+// CHECK-RWT-LABEL: func.func @dynamic_sparse_expand(
+// CHECK-RWT-SAME: %[[S:.*]]:
+// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[SD:.*]] = tensor.dim %[[S]], %[[C0]]
+// CHECK-RWT: %[[DD0:.*]] = arith.divui %[[SD]], %[[C10]] : index
+// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
+// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
+// CHECK-RWT: %[[T1:.*]] = arith.muli %[[DD0]], %[[C10]] : index
+// CHECK-RWT: %[[T2:.*]] = arith.divui %[[T1]], %[[DD0]] : index
+// CHECK-RWT: %[[DI0:.*]] = arith.divui %[[SI]], %[[T2]] : index
+// CHECK-RWT: %[[T3:.*]] = arith.remui %[[SI]], %[[T2]] : index
+// CHECK-RWT: %[[T4:.*]] = arith.divui %[[T2]], %[[C10]] : index
+// CHECK-RWT: %[[DI1:.*]] = arith.divui %[[T3]], %[[T4]] : index
+// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI0]], %[[DI1]]]
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<?x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+//
func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] :
tensor<?xf64, #SparseVector> into tensor<?x10xf64, #SparseMatrix>
@@ -163,6 +247,42 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
// CHECK-CONV: call @delSparseTensorCOOF64
// CHECK-CONV: return %[[N]] : !llvm.ptr<i8>
//
+// rewrite for codegen:
+//
+// CHECK-RWT-LABEL: func.func @dynamic_sparse_collapse(
+// CHECK-RWT-SAME: %[[S:.*]]:
+// CHECK-RWT-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-RWT: %[[SD1:.*]] = tensor.dim %[[S]], %[[C1]]
+// CHECK-RWT: %[[DD0:.*]] = arith.muli %[[SD1]], %[[C10]] : index
+// CHECK-RWT: %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
+// CHECK-RWT: %[[P0:.*]] = sparse_tensor.pointers %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[S]] {dimension = 0 : index}
+// CHECK-RWT: %[[P1:.*]] = sparse_tensor.pointers %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[S]] {dimension = 1 : index}
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK-RWT: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK-RWT: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] {
+// CHECK-RWT: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
+// CHECK-RWT: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-RWT: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
+// CHECK-RWT: scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] {
+// CHECK-RWT: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
+// CHECK-RWT: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
+// CHECK-RWT: %[[T1:.*]] = arith.divui %[[DD0]], %[[C10]] : index
+// CHECK-RWT: %[[T2:.*]] = arith.muli %[[SI0]], %[[T1]] : index
+// CHECK-RWT: %[[T3:.*]] = arith.divui %[[T1]], %[[SD1]] : index
+// CHECK-RWT: %[[T4:.*]] = arith.muli %[[SI1]], %[[T3]] : index
+// CHECK-RWT: %[[DI:.*]] = arith.addi %[[T2]], %[[T4]] : index
+// CHECK-RWT: sparse_tensor.insert %[[SV]] into %[[B]]{{\[}}%[[DI]]]
+// CHECK-RWT:      }
+// CHECK-RWT: }
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.convert %[[B]]
+// CHECK-RWT: return %[[T]] : tensor<?xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>
+//
func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] :
tensor<10x?xf64, #SparseMatrix> into tensor<?xf64, #SparseVector>
More information about the Mlir-commits
mailing list