[Mlir-commits] [mlir] a41672e - [mlir][sparse] implement lowering rules for sparse_tensor.pack operation

Peiming Liu llvmlistbot at llvm.org
Fri Feb 3 15:51:42 PST 2023


Author: Peiming Liu
Date: 2023-02-03T23:51:36Z
New Revision: a41672e16a25a8180a9b9edcb3f505ab2d5cace8

URL: https://github.com/llvm/llvm-project/commit/a41672e16a25a8180a9b9edcb3f505ab2d5cace8
DIFF: https://github.com/llvm/llvm-project/commit/a41672e16a25a8180a9b9edcb3f505ab2d5cace8.diff

LOG: [mlir][sparse] implement lowering rules for sparse_tensor.pack operation

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D143230

Added: 
    mlir/test/Dialect/SparseTensor/sparse_pack.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 1fba7d8d74f73..9f3388d61f3f2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1021,6 +1021,98 @@ class SparseNumberOfEntriesConverter
   }
 };
 
+/// Lowers `sparse_tensor.pack` for a unique COO result tensor by assembling
+/// the result's storage fields directly from the operands:
+///   - the pointer buffer is a constant `[0, nnz]`, where `nnz` is the
+///     (static) leading extent of the `data` operand;
+///   - the indices buffer is the bufferized `indices` operand collapsed to
+///     rank 1;
+///   - the values buffer is the bufferized `data` operand;
+///   - the storage specifier receives the static dimension sizes and the
+///     memory sizes of the buffers.
+/// Each buffer is memref.cast to the field type expected by the converted
+/// tensor type when the two differ.
+struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PackOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto rtp = op.getResult().getType().cast<RankedTensorType>();
+    assert(isUniqueCOOType(rtp));
+    // The encoding is identical for every field; look it up once instead of
+    // once per visited field.
+    auto enc = getSparseTensorEncoding(rtp);
+
+    SmallVector<Value> fields;
+    Location loc = op.getLoc();
+
+    foreachFieldAndTypeInSparseTensor(
+        rtp,
+        [&rewriter, &fields, &op, rtp, enc,
+         loc](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
+              unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
+          assert(fields.size() == fIdx);
+          Value field;
+          switch (fKind) {
+          case SparseTensorFieldKind::StorageSpec:
+            field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp);
+            break;
+          case SparseTensorFieldKind::PtrMemRef: {
+            // TACO-style COO starts with a pointer buffer holding [0, nnz].
+            // By creating a constant value for it, we avoid the complexity of
+            // memory management.
+            Type ptrTp = enc.getPointerType();
+            int64_t nnz = op.getData().getType().getShape()[0];
+            // The entry count is baked into a constant, so it must be static.
+            assert(!ShapedType::isDynamic(nnz));
+            auto tensorType = RankedTensorType::get({2}, ptrTp);
+            auto memrefType = MemRefType::get(tensorType.getShape(),
+                                              tensorType.getElementType());
+            // Build the elements as attributes of the pointer type itself so
+            // their bit width always matches it (index or any pointer bit
+            // width); hard-coded 64-bit APInts mismatch e.g. i32 pointers.
+            Attribute ptrVals[] = {IntegerAttr::get(ptrTp, 0),
+                                   IntegerAttr::get(ptrTp, nnz)};
+            auto cstPtr = rewriter.create<arith::ConstantOp>(
+                loc, tensorType, DenseElementsAttr::get(tensorType, ptrVals));
+            field = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType,
+                                                               cstPtr);
+            break;
+          }
+          case SparseTensorFieldKind::IdxMemRef: {
+            auto tensorType = op.getIndices().getType();
+            auto memrefType = MemRefType::get(tensorType.getShape(),
+                                              tensorType.getElementType());
+            auto idxMemRef = rewriter.create<bufferization::ToMemrefOp>(
+                loc, memrefType, op.getIndices());
+            // Collapse every dimension into a single reassociation group to
+            // flatten the indices buffer to rank 1.
+            ReassociationIndices reassociation;
+            for (int64_t i = 0, e = tensorType.getRank(); i < e; i++)
+              reassociation.push_back(i);
+            field = rewriter.create<memref::CollapseShapeOp>(
+                loc, idxMemRef, ArrayRef<ReassociationIndices>(reassociation));
+            break;
+          }
+          case SparseTensorFieldKind::ValMemRef: {
+            auto tensorType = op.getData().getType();
+            auto memrefType = MemRefType::get(tensorType.getShape(),
+                                              tensorType.getElementType());
+            field = rewriter.create<bufferization::ToMemrefOp>(
+                loc, memrefType, op.getData());
+            break;
+          }
+          }
+
+          assert(field);
+          // Buffers have static shapes here; cast them to the field type the
+          // converted tensor type expects when the two differ.
+          if (fType != field.getType())
+            field = rewriter.create<memref::CastOp>(loc, fType, field);
+          fields.push_back(field);
+          // Returns true to continue the iteration.
+          return true;
+        });
+
+    MutSparseTensorDescriptor desc(rtp, fields);
+    // Number of stored entries: the leading extent of the values operand.
+    auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getData(), 0);
+    for (unsigned i = 0, e = rtp.getRank(); i < e; i++) {
+      // Keep the full int64_t: narrowing to `int` would truncate large sizes
+      // and turn ShapedType::kDynamic into a value the assert cannot catch.
+      int64_t dim = rtp.getShape()[i];
+      assert(!ShapedType::isDynamic(dim));
+      desc.setDimSize(rewriter, loc, i, constantIndex(rewriter, loc, dim));
+      // Only level 0 records a pointer-memory size; it matches the
+      // two-element constant pointer buffer built above.
+      if (i == 0)
+        desc.setPtrMemSize(rewriter, loc, i, constantIndex(rewriter, loc, 2));
+
+      // Record the entry count as the index-memory size of every level.
+      desc.setIdxMemSize(rewriter, loc, i, noe);
+    }
+    desc.setValMemSize(rewriter, loc, noe);
+
+    rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1032,14 +1124,15 @@ class SparseNumberOfEntriesConverter
 void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     bool enableBufferInitialization) {
-  patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-               SparseCastConverter, SparseTensorDeallocConverter,
-               SparseTensorLoadConverter, SparseExpandConverter,
-               SparseCompressConverter, SparseInsertConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToIndicesBufferConverter, SparseToValuesConverter,
-               SparseConvertConverter, SparseNumberOfEntriesConverter>(
-      typeConverter, patterns.getContext());
+  // Patterns constructed from only the type converter and the MLIR context;
+  // SparsePackOpConverter (the lowering for sparse_tensor.pack) is new here.
+  patterns.add<SparsePackOpConverter, SparseReturnConverter,
+               SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
+               SparseTensorDeallocConverter, SparseTensorLoadConverter,
+               SparseExpandConverter, SparseCompressConverter,
+               SparseInsertConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToIndicesBufferConverter,
+               SparseToValuesConverter, SparseConvertConverter,
+               SparseNumberOfEntriesConverter>(typeConverter,
+                                               patterns.getContext());
+  // The allocation converter additionally receives the
+  // enableBufferInitialization flag.
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 996485642c961..193c227171368 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -396,6 +396,19 @@ class MutSparseTensorDescriptor
     fields.back() = md;
   }
 
+  /// Sets the ValMemSize field of the storage specifier to `v` via
+  /// setSpecifierField. No dimension is passed (std::nullopt).
+  void setValMemSize(OpBuilder &builder, Location loc, Value v) {
+    setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+                      std::nullopt, v);
+  }
+
+  /// Sets the IdxMemSize field of the storage specifier for dimension `dim`
+  /// to `v` via setSpecifierField.
+  void setIdxMemSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
+    setSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, dim, v);
+  }
+
+  /// Sets the PtrMemSize field of the storage specifier for dimension `dim`
+  /// to `v` via setSpecifierField.
+  void setPtrMemSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
+    setSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, dim, v);
+  }
+
   void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
     setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
   }

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
new file mode 100644
index 0000000000000..6dbf258c7e7bd
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed-nu", "singleton"],
+  indexBitWidth=32
+}>
+
+// Packing six entries into a 100x100 COO tensor. The expected lowering:
+//   - the pointer buffer becomes the constant [0, 6];
+//   - the 6x2 indices tensor is collapsed into a rank-1 buffer;
+//   - the values tensor is used directly as the values buffer;
+//   - the storage specifier records the dimension sizes (100) and the
+//     pointer, index, and value memory sizes.
+// CHECK-LABEL:   func.func @sparse_pack(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<6xf64>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<6x2xi32>) -> (memref<?xindex>, memref<?xi32>, memref<?xf64>,
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<[0, 6]> : tensor<2xindex>
+// CHECK:           %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2xindex>
+// CHECK:           %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
+// CHECK:           %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
+// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK:           %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
+// CHECK:           %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK:           %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init :
+// CHECK:           %[[VAL_11:.*]] = arith.constant 6 : index
+// CHECK:           %[[VAL_12:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i32
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]]  dim_sz at 0 with %[[VAL_13]] : i32,
+// CHECK:           %[[VAL_15:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_16:.*]] = arith.index_cast %[[VAL_15]] : index to i32
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  ptr_mem_sz at 0 with %[[VAL_16]] : i32,
+// CHECK:           %[[VAL_18:.*]] = arith.index_cast %[[VAL_11]] : index to i32
+// CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]]  idx_mem_sz at 0 with %[[VAL_18]] : i32,
+// CHECK:           %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]]  dim_sz at 1 with %[[VAL_13]] : i32,
+// CHECK:           %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]]  idx_mem_sz at 1 with %[[VAL_18]] : i32,
+// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]]  val_mem_sz with %[[VAL_18]] : i32,
+// CHECK:           return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xi32>, memref<?xf64>,
+// CHECK:         }
+func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
+                    -> tensor<100x100xf64, #COO> {
+  %0 = sparse_tensor.pack %data, %index : tensor<6xf64>, tensor<6x2xi32>
+                                       to tensor<100x100xf64, #COO>
+  return %0 : tensor<100x100xf64, #COO>
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
new file mode 100644
index 0000000000000..a7e0f210dfd11
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -0,0 +1,56 @@
+// DEFINE: %{option} = enable-runtime-library=false
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+
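+// TODO: sparse_tensor.pack is currently only supported on the codegen path
+// (enable-runtime-library=false); add runtime-library coverage.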
+
+#SortedCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+module {
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    //
+    // Values (1-D) and their 2-D coordinates (3x2) for three entries.
+    //
+    %data = arith.constant dense<
+       [  1.0,  2.0,  3.0]
+    > : tensor<3xf64>
+
+    %index = arith.constant dense<
+       [[  1,  2],
+        [  5,  6],
+        [  7,  8]]
+    > : tensor<3x2xindex>
+
+    // Pack the values and coordinates into a 10x10 sorted COO tensor.
+    %s4 = sparse_tensor.pack %data, %index : tensor<3xf64>, tensor<3x2xindex>
+                                          to tensor<10x10xf64, #SortedCOO>
+    // Each entry is printed as its two coordinates followed by its value.
+    // CHECK:1
+    // CHECK-NEXT:2
+    // CHECK-NEXT:1
+    //
+    // CHECK-NEXT:5
+    // CHECK-NEXT:6
+    // CHECK-NEXT:2
+    //
+    // CHECK-NEXT:7
+    // CHECK-NEXT:8
+    // CHECK-NEXT:3
+    sparse_tensor.foreach in %s4 : tensor<10x10xf64, #SortedCOO> do {
+      ^bb0(%1: index, %2: index, %v: f64) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f64
+     }
+    return
+  }
+}


        


More information about the Mlir-commits mailing list