[Mlir-commits] [mlir] dc6427d - [mlir][sparse] implement lowering rules for sparse_tensor::unpack

Peiming Liu llvmlistbot at llvm.org
Fri Feb 10 17:05:51 PST 2023


Author: Peiming Liu
Date: 2023-02-11T01:05:46Z
New Revision: dc6427d687c4640e33c9bdab0c888b6f7d0569be

URL: https://github.com/llvm/llvm-project/commit/dc6427d687c4640e33c9bdab0c888b6f7d0569be
DIFF: https://github.com/llvm/llvm-project/commit/dc6427d687c4640e33c9bdab0c888b6f7d0569be.diff

LOG: [mlir][sparse] implement lowering rules for sparse_tensor::unpack

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D143672

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 6a5a85fb8271..641b429f300a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -143,6 +143,37 @@ struct PackOpInterface
     // ones when packing into a COO format.
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }
+
+  BufferRelation bufferRelation(Operation *oo, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Unknown;
+  }
+};
+
+struct UnpackOpInterface
+    : public BufferizableOpInterface::ExternalModel<UnpackOpInterface,
+                                                    sparse_tensor::UnpackOp> {
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    // Similar to InsertOp, reallocation is not considered to allocate a new
+    // piece of memory.
+    return false;
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    // Conceptually, UnpackOp equals to a list of toIndices/toValueOp
+    return {};
+  }
 };
 
 struct InsertOpInterface
@@ -285,6 +316,8 @@ void mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(
     sparse_tensor::InsertOp::attachInterface<InsertOpInterface>(*ctx);
     sparse_tensor::NumberOfEntriesOp::attachInterface<
         NumberOfEntriesOpInterface>(*ctx);
+    sparse_tensor::PackOp::attachInterface<PackOpInterface>(*ctx);
+    sparse_tensor::UnpackOp::attachInterface<UnpackOpInterface>(*ctx);
     sparse_tensor::ToIndicesBufferOp::attachInterface<
         ToIndicesBufferOpInterface>(*ctx);
     sparse_tensor::ToIndicesOp::attachInterface<ToIndicesOpInterface>(*ctx);

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index b2541c38a30b..17319b7ffa2a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1154,6 +1154,53 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
   }
 };
 
+struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    Location loc = op.getLoc();
+    int64_t rank = op.getTensor().getType().getRank();
+
+    assert(isUniqueCOOType(op.getTensor().getType()) &&
+           desc.getFields().size() == 4);
+
+    Value flatBuf = rank == 1 ? desc.getIdxMemRefOrView(rewriter, loc, 0)
+                              : desc.getAOSMemRef();
+    Value dataBuf = desc.getValMemRef();
+
+    // If frontend requests a static buffer, we reallocate the data/indices
+    // to ensure that we meet their need.
+    TensorType dataTp = op.getData().getType();
+    if (dataTp.hasStaticShape()) {
+      dataBuf = rewriter.create<memref::ReallocOp>(
+          loc, MemRefType::get(dataTp.getShape(), dataTp.getElementType()),
+          dataBuf);
+    }
+
+    TensorType indicesTp = op.getIndices().getType();
+    if (indicesTp.hasStaticShape()) {
+      auto len = indicesTp.getShape()[0] * indicesTp.getShape()[1];
+      flatBuf = rewriter.create<memref::ReallocOp>(
+          loc, MemRefType::get({len}, indicesTp.getElementType()), flatBuf);
+    }
+
+    Value idxBuf = rewriter.create<memref::ExpandShapeOp>(
+        loc, MemRefType::get(indicesTp.getShape(), indicesTp.getElementType()),
+        flatBuf, ArrayRef{ReassociationIndices{0, 1}});
+
+    // Converts MemRefs back to Tensors.
+    Value data = rewriter.create<bufferization::ToTensorOp>(loc, dataBuf);
+    Value indices = rewriter.create<bufferization::ToTensorOp>(loc, idxBuf);
+    Value nnz = toType(rewriter, loc, desc.getValMemSize(rewriter, loc),
+                       op.getNnz().getType());
+
+    rewriter.replaceOp(op, {data, indices, nnz});
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1165,15 +1212,16 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
 void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     bool enableBufferInitialization) {
-  patterns.add<SparsePackOpConverter, SparseReturnConverter,
-               SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseTensorDeallocConverter, SparseExtractSliceCoverter,
-               SparseTensorLoadConverter, SparseExpandConverter,
-               SparseCompressConverter, SparseInsertConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToIndicesBufferConverter, SparseToValuesConverter,
-               SparseConvertConverter, SparseNumberOfEntriesConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
+               SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+               SparseCastConverter, SparseTensorDeallocConverter,
+               SparseExtractSliceCoverter, SparseTensorLoadConverter,
+               SparseExpandConverter, SparseCompressConverter,
+               SparseInsertConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToIndicesBufferConverter,
+               SparseToValuesConverter, SparseConvertConverter,
+               SparseNumberOfEntriesConverter>(typeConverter,
+                                               patterns.getContext());
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);
 }

diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 6dbf258c7e7b..eeb41fe7d00a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -37,3 +37,23 @@ func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
                                        to tensor<100x100xf64, #COO>
   return %0 : tensor<100x100xf64, #COO>
 }
+
+// CHECK-LABEL:   func.func @sparse_unpack(
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<?xindex>,
+// CHECK-SAME:      %[[VAL_1:.*]]: memref<?xi32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: memref<?xf64>,
+// CHECK-SAME:      %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_4:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
+// CHECK:           %[[VAL_5:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
+// CHECK:           %[[VAL_6:.*]] = memref.expand_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
+// CHECK:           %[[VAL_7:.*]] = bufferization.to_tensor %[[VAL_4]] : memref<6xf64>
+// CHECK:           %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<6x2xi32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  val_mem_sz
+// CHECK:           %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index
+// CHECK:           return %[[VAL_7]], %[[VAL_8]], %[[VAL_10]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK:         }
+func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
+  %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
+                                         to tensor<6xf64>, tensor<6x2xi32>, index
+  return %d, %i, %nnz : tensor<6xf64>, tensor<6x2xi32>, index
+}

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index 35d9cc6409d6..b20c5330397e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -36,6 +36,9 @@ module {
   // Main driver.
   //
   func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %f0 = arith.constant 0.0 : f64
+    %i0 = arith.constant 0 : i32
     //
     // Initialize a 3-dim dense tensor.
     //
@@ -95,6 +98,23 @@ module {
         vector.print %v: f64
      }
 
+
+    %d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
+                                         to tensor<3xf64>, tensor<3x2xi32>, i32
+
+
+
+    // CHECK-NEXT: ( 1, 2, 3 )
+    %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
+    vector.print %vd : vector<3xf64>
+
+    // CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ) )
+    %vi = vector.transfer_read %i[%c0, %c0], %i0 : tensor<3x2xi32>, vector<3x2xi32>
+    vector.print %vi : vector<3x2xi32>
+
+    // CHECK-NEXT: 3
+    vector.print %n : i32
+
     return
   }
 }


        


More information about the Mlir-commits mailing list