[Mlir-commits] [mlir] 6116ca6 - [mlir][sparse] Add sparse rewriting rules for tensor::ReshapeOp

Anlun Xu llvmlistbot at llvm.org
Tue May 16 14:56:55 PDT 2023

Author: Anlun Xu
Date: 2023-05-16T14:56:33-07:00
New Revision: 6116ca67ab0dc8c4ed6756b0e45bc7100efef0ea

URL: https://github.com/llvm/llvm-project/commit/6116ca67ab0dc8c4ed6756b0e45bc7100efef0ea
DIFF: https://github.com/llvm/llvm-project/commit/6116ca67ab0dc8c4ed6756b0e45bc7100efef0ea.diff

LOG: [mlir][sparse] Add sparse rewriting rules for tensor::ReshapeOp

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D149564




diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index ca27794b64c1f..a16ab660e931f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -385,6 +385,106 @@ struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
+/// Sparse rewriting rule for sparse-to-sparse tensor::ReshapeOp (static dst).
+struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
+  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value srcTensor = op.getSource();
+    const auto srcTp = getSparseTensorType(srcTensor);
+    const auto dstTp = getSparseTensorType(op.getResult());
+    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
+        !dstTp.hasStaticDimShape())
+      return failure();
+    SmallVector<Value> srcSizes;
+    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    SmallVector<Value> dstSizes;
+    for (Dimension d : dstTp.getDimShape())
+      dstSizes.push_back(constantIndex(rewriter, loc, d));
+    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
+    // An unordered COO buffer is only needed if input and output are not
+    // sorted in the same way; otherwise insert into the destination type.
+    Type bufferTp =
+        srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
+            ? dstTp.getRankedTensorType()
+            : getUnorderedCOOFromType(dstTp);
+    SmallVector<Value> dynSizes;
+    Value buffer = rewriter
+                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
+                                              nnz, Attribute())
+                       .getResult();
+    // Convert src coordinates to dst coordinates by first collapsing them to
+    // 1D and then expanding the result to match the rank of the destination.
+    // Implemented as follows:
+    //   foreach srcCoords %srcTensor
+    //     collapsedCoords = reshapeCvs(srcCoords, [0, ..., srcRank-1])
+    //     expandedCoords = reshapeCvs(collapsedCoords, [0, ..., dstRank-1])
+    //     insert expandedCoords, %buffer
+    //
+    // followed by an optional
+    //   %t = sparse_tensor.convert %tmp
+    // when the buffer type differs from the destination type.
+    const auto encSrc = srcTp.getEncoding();
+    ForeachOp foreachOp = rewriter.create<ForeachOp>(
+        loc, srcTensor, buffer,
+        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
+            ValueRange reduc) {
+          const Dimension srcRank = srcTp.getDimRank();
+          SmallVector<Value> srcDcvs;
+          srcDcvs.reserve(srcRank);
+          for (Dimension d = 0; d < srcRank; d++) {
+            // FIXME: `toStoredDim` is deprecated
+            Level lvl = toStoredDim(encSrc, d);
+            srcDcvs.push_back(srcLcvs[lvl]);
+          }
+          Value collapsed_size = constantIndex(builder, loc, 1);
+          for (Dimension d = 0; d < srcRank; d++)
+            collapsed_size =
+                builder.create<arith::MulIOp>(loc, collapsed_size, srcSizes[d]);
+          SmallVector<Value, 1> collapsedSizes = {collapsed_size};
+          ReassociationIndices collapse_indices;
+          for (Dimension i = 0; i < srcRank; i++)
+            collapse_indices.push_back(i);
+          SmallVector<ReassociationIndices, 1> collapse_reassociation = {
+              collapse_indices};
+          SmallVector<Value, 1> collapsedDcvs;
+          reshapeCvs(builder, loc, collapse_reassociation, srcSizes, srcDcvs,
+                     collapsedSizes, collapsedDcvs);
+          ReassociationIndices expand_indices;
+          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+            expand_indices.push_back(i);
+          SmallVector<ReassociationIndices, 1> expand_reassociation = {
+              expand_indices};
+          SmallVector<Value> dstDcvs;
+          reshapeCvs(builder, loc, expand_reassociation, collapsedSizes,
+                     collapsedDcvs, dstSizes, dstDcvs);
+          auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+          builder.create<sparse_tensor::YieldOp>(loc, t);
+        });
+    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+    if (bufferTp != dstTp) {
+      auto dstRTT = dstTp.getRankedTensorType();
+      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
+      rewriter.create<DeallocTensorOp>(loc, t);
+      t = converted;
+    }
+    rewriter.replaceOp(op, t);
+    return success();
+  }
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 template <typename ReshapeOp>
 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
@@ -1169,7 +1269,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                                                bool enableForeach,
                                                bool enableConvert) {
-               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+               ReshapeRewriter<tensor::CollapseShapeOp>, TensorReshapeRewriter>(
+      patterns.getContext());
   if (enableForeach)
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
new file mode 100644
index 0000000000000..369044c38f764
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: --cse --canonicalize  | FileCheck %s
+#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+// CHECK:         func.func @sparse_reshape(
+// CHECK-SAME:    %[[S:.*]]:
+// CHECK-DAG:     %[[C25:.*]] = arith.constant 25 : index
+// CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[B:.*]] = bufferization.alloc_tensor()
+// CHECK:         %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
+// CHECK:         %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
+// CHECK:         %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
+// CHECK:         %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
+// CHECK:         %[[V:.*]] = sparse_tensor.values %[[S]]
+// CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[A0:.*]] = %[[B]])
+// CHECK:           %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-DAG:       %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
+// CHECK-DAG:       %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
+// CHECK:           %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
+// CHECK:           %[[RET_1:.*]] = scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] iter_args(%[[A1:.*]] = %[[A0]])
+// CHECK:             %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
+// CHECK:             %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
+// CHECK:             %[[T:.*]] = arith.muli %[[SI0]], %[[C25]] : index
+// CHECK:             %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
+// CHECK:             %[[D:.*]] = arith.divui %[[DI]], %[[C10]] : index
+// CHECK:             %[[R:.*]] = arith.remui %[[DI]], %[[C10]] : index
+// CHECK:             %[[R1:.*]] = sparse_tensor.insert %[[SV]] into %[[A1]]{{\[}}%[[D]], %[[R]]]
+// CHECK:              scf.yield %[[R1]]
+// CHECK:            }
+// CHECK:            scf.yield %[[RET_1]]
+// CHECK:         }
+// CHECK:        %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK:        return %[[NT1]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+func.func @sparse_reshape(%arg0: tensor<4x25xf64, #SparseMatrix>) -> tensor<10x10xf64, #SparseMatrix> {
+  %shape = arith.constant dense <[ 10, 10 ]> : tensor<2xi32>
+  %0 = tensor.reshape %arg0(%shape) :
+    (tensor<4x25xf64, #SparseMatrix>, tensor<2xi32>) -> tensor<10x10xf64, #SparseMatrix>
+  return %0 : tensor<10x10xf64, #SparseMatrix>
\ No newline at end of file

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
new file mode 100644
index 0000000000000..4945294f727e7
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
@@ -0,0 +1,87 @@
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+// RUN: %{compile} | %{run}
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{compile} | %{run}
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{compile} | %{run}
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed"]
+#Sparse3dTensor = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed", "compressed"]
+module {
+  func.func @reshape0(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix> {
+    %shape = arith.constant dense <[ 2, 6 ]> : tensor<2xi32>
+    %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<2xi32>) -> tensor<2x6xf64, #SparseMatrix>
+    return %0 : tensor<2x6xf64, #SparseMatrix>
+  }
+  func.func @reshape1(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
+    %shape = arith.constant dense <[ 12 ]> : tensor<1xi32>
+    %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<1xi32>) -> tensor<12xf64, #SparseVector>
+    return %0 : tensor<12xf64, #SparseVector>
+  }
+  func.func @reshape2(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor> {
+    %shape = arith.constant dense <[ 2, 3, 2 ]> : tensor<3xi32>
+    %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<3xi32>) -> tensor<2x3x2xf64, #Sparse3dTensor>
+    return %0 : tensor<2x3x2xf64, #Sparse3dTensor>
+  }
+  func.func @entry() {
+    %m = arith.constant dense <[ [ 1.1,  0.0,  1.3,  0.0 ],
+                                 [ 2.1,  0.0,  2.3,  0.0 ],
+                                 [ 3.1,  0.0,  3.3,  0.0 ]]> : tensor<3x4xf64>
+    %sm = sparse_tensor.convert %m : tensor<3x4xf64> to tensor<3x4xf64, #SparseMatrix>
+    %reshaped0 = call @reshape0(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix>
+    %reshaped1 = call @reshape1(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector>
+    %reshaped2 = call @reshape2(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor>
+    %c0 = arith.constant 0 : index
+    %df = arith.constant -1.0 : f64
+    // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
+    %b0 = sparse_tensor.values %reshaped0: tensor<2x6xf64, #SparseMatrix> to memref<?xf64>
+    %v0 = vector.transfer_read %b0[%c0], %df: memref<?xf64>, vector<12xf64>
+    vector.print %v0 : vector<12xf64>
+    // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
+    %b1 = sparse_tensor.values %reshaped1: tensor<12xf64, #SparseVector> to memref<?xf64>
+    %v1 = vector.transfer_read %b1[%c0], %df: memref<?xf64>, vector<12xf64>
+    vector.print %v1 : vector<12xf64>
+    // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
+    %b2 = sparse_tensor.values %reshaped2: tensor<2x3x2xf64, #Sparse3dTensor> to memref<?xf64>
+    %v2 = vector.transfer_read %b2[%c0], %df: memref<?xf64>, vector<12xf64>
+    vector.print %v2: vector<12xf64>
+    bufferization.dealloc_tensor %sm : tensor<3x4xf64, #SparseMatrix>
+    bufferization.dealloc_tensor %reshaped0 : tensor<2x6xf64, #SparseMatrix>
+    bufferization.dealloc_tensor %reshaped1 : tensor<12xf64, #SparseVector>
+    bufferization.dealloc_tensor %reshaped2 : tensor<2x3x2xf64, #Sparse3dTensor>
+    return
+  }
\ No newline at end of file


More information about the Mlir-commits mailing list