[Mlir-commits] [mlir] fbd5821 - Implement the conversion from sparse constant to sparse tensors.

Bixia Zheng llvmlistbot at llvm.org
Mon Sep 27 09:47:34 PDT 2021


Author: Bixia Zheng
Date: 2021-09-27T09:47:29-07:00
New Revision: fbd5821c6f2c516a64602839745ddc6f9566f710

URL: https://github.com/llvm/llvm-project/commit/fbd5821c6f2c516a64602839745ddc6f9566f710
DIFF: https://github.com/llvm/llvm-project/commit/fbd5821c6f2c516a64602839745ddc6f9566f710.diff

LOG: Implement the conversion from sparse constant to sparse tensors.

The sparse constant provides a constant tensor in coordinate format. We first split the sparse constant into a constant tensor for indices and a constant tensor for values. We then generate a loop to fill a sparse tensor in coordinate format using the tensors for the indices and the values. Finally, we convert the sparse tensor in coordinate format to the destination sparse tensor format.

Add tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D110373

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse-constant_to_sparse_tensor.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 328bf8e403092..681c8160aaeaa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -195,16 +195,39 @@ static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
   llvm_unreachable("Unknown element type");
 }
 
+/// Generates code that reads the value at tensor[ivs] and, only if that value
+/// is nonzero, stores the indices ivs into the memory in ind. Returns the
+/// value read. On return, the insertion point is left inside the if-then
+/// branch, after the stores to ind, so that the addEltX call generated next
+/// is also inside the if-then branch. The generated code looks like:
+///    if (tensor[ivs]!=0) {
+///      ind = ivs
+static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
+                                      Operation *op, Type eltType, Value tensor,
+                                      Value ind, ValueRange ivs) {
+  Location loc = op->getLoc();
+  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
+  Value cond = genIsNonzero(rewriter, loc, eltType, val);
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
+  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
+  unsigned i = 0;
+  for (auto iv : ivs) {
+    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
+    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
+  }
+  return val;
+}
+
 /// Generates a call that adds one element to a coordinate scheme.
 /// In particular, this generates code like the following:
 ///   val = a[i1,..,ik];
 ///   if val != 0
 ///     t->add(val, [i1,..,ik], [p1,..,pk]);
 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
-                          Value ptr, Value tensor, Value ind, Value perm,
-                          ValueRange ivs) {
+                          Type eltType, Value ptr, Value val, Value ind,
+                          Value perm) {
+  Location loc = op->getLoc();
   StringRef name;
-  Type eltType = tensor.getType().cast<ShapedType>().getElementType();
   if (eltType.isF64())
     name = "addEltF64";
   else if (eltType.isF32())
@@ -219,16 +242,6 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
     name = "addEltI8";
   else
     llvm_unreachable("Unknown element type");
-  Location loc = op->getLoc();
-  Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
-  Value cond = genIsNonzero(rewriter, loc, eltType, val);
-  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
-  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
-  unsigned i = 0;
-  for (auto iv : ivs) {
-    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
-    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
-  }
   SmallVector<Value, 8> params;
   params.push_back(ptr);
   params.push_back(val);
@@ -240,6 +253,41 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
       params);
 }
 
+/// If the tensor is defined by a ConstantOp holding a SparseElementsAttr,
+/// creates and returns the (indices, values) constants; else returns empty.
+static Optional<std::pair<Value, Value>>
+genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op,
+                       Value tensor) {
+  if (auto constOp = tensor.getDefiningOp<ConstantOp>()) {
+    if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
+      Location loc = op->getLoc();
+      DenseElementsAttr indicesAttr = attr.getIndices();
+      Value indices = rewriter.create<ConstantOp>(loc, indicesAttr);
+      DenseElementsAttr valuesAttr = attr.getValues();
+      Value values = rewriter.create<ConstantOp>(loc, valuesAttr);
+      return std::make_pair(indices, values);
+    }
+  }
+  return {};
+}
+
+/// Generates the code to copy the rank indices at indices[ivs[0], 0..rank-1]
+/// to ind (cast to index type), and returns the value at values[ivs[0]].
+static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
+                                       Operation *op, Value indices,
+                                       Value values, Value ind, ValueRange ivs,
+                                       unsigned rank) {
+  Location loc = op->getLoc();
+  for (unsigned i = 0; i < rank; i++) {
+    Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i));
+    Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
+                                                   ValueRange{ivs[0], idx});
+    val = rewriter.create<IndexCastOp>(loc, val, rewriter.getIndexType());
+    rewriter.create<memref::StoreOp>(loc, val, ind, idx);
+  }
+  return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
+}
+
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -330,15 +378,26 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       // TODO: sparse => dense
       return failure();
     }
-    // This is a dense => sparse conversion, which is handled as follows:
+    // This is a dense => sparse conversion or a sparse constant in COO =>
+    // sparse conversion, which is handled as follows:
     //   t = newSparseCOO()
+    //   ...code to fill the COO tensor t...
+    //   s = newSparseTensor(t)
+    //
+    // To fill the COO tensor from a dense tensor:
     //   for i1 in dim1
     //    ..
     //     for ik in dimk
     //       val = a[i1,..,ik]
     //       if val != 0
     //         t->add(val, [i1,..,ik], [p1,..,pk])
-    //   s = newSparseTensor(t)
+    //
+    // To fill the COO tensor from a sparse constant in COO format:
+    //   for i in range(NNZ)
+    //     val = values[i]
+    //     [i1,..,ik] = indices[i]
+    //     t->add(val, [i1,..,ik], [p1,..,pk])
+    //
     // Note that the dense tensor traversal code is actually implemented
     // using MLIR IR to avoid having to expose too much low-level
     // memref traversal details to the runtime support library.
@@ -351,7 +410,6 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
         MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
     Value perm;
     Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
-    Value tensor = adaptor.getOperands()[0];
     Value arg = rewriter.create<ConstantOp>(
         loc, rewriter.getIndexAttr(shape.getRank()));
     Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
@@ -360,16 +418,38 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     SmallVector<Value> st;
     Value zero = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
     Value one = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
-    for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
+    Value tensor = adaptor.getOperands()[0];
+    auto indicesValues = genSplitSparseConstant(rewriter, op, tensor);
+    bool isCOOConstant = indicesValues.hasValue();
+    Value indices;
+    Value values;
+    if (isCOOConstant) {
+      indices = indicesValues->first;
+      values = indicesValues->second;
       lo.push_back(zero);
-      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
+      hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
       st.push_back(one);
+    } else {
+      for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
+        lo.push_back(zero);
+        hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i));
+        st.push_back(one);
+      }
     }
+    Type eltType = shape.getElementType();
+    unsigned rank = shape.getRank();
     scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {},
                        [&](OpBuilder &builder, Location loc, ValueRange ivs,
                            ValueRange args) -> scf::ValueVector {
-                         genAddEltCall(rewriter, op, ptr, tensor, ind, perm,
-                                       ivs);
+                         Value val;
+                         if (isCOOConstant)
+                           val = genIndexAndValueForSparse(
+                               rewriter, op, indices, values, ind, ivs, rank);
+                         else
+                           val = genIndexAndValueForDense(rewriter, op, eltType,
+                                                          tensor, ind, ivs);
+                         genAddEltCall(rewriter, op, eltType, ptr, val, ind,
+                                       perm);
                          return {};
                        });
     rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f452a25912d77..7597fc3dbdc28 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -114,8 +114,8 @@ struct SparseTensorConversionPass
     });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
-    target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp, CmpFOp,
-                      CmpIOp>();
+    target.addLegalOp<ConstantOp, IndexCastOp, tensor::CastOp,
+                      tensor::ExtractOp, CmpFOp, CmpIOp>();
     target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect,
                            memref::MemRefDialect>();
     // Populate with rules and apply rewriting rules.

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index deca4f12c0f3f..f2831f88b88ae 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -182,6 +182,45 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix
   return %0 : tensor<2x4xf64, #SparseMatrix>
 }
 
+#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
+
+// CHECK-LABEL:   func @entry() -> !llvm.ptr<i8> {
+// CHECK:           %[[C1:.*]] = constant 1 : i32
+// CHECK:           %[[Offset:.*]] = constant dense<[0, 1]> : tensor<2xi64>
+// CHECK:           %[[Dims:.*]] = constant dense<[8, 7]> : tensor<2xi64>
+// CHECK:           %[[Base:.*]] = constant dense<[0, 1]> : tensor<2xi8>
+// CHECK:           %[[I2:.*]] = constant 2 : index
+// CHECK:           %[[SparseV:.*]] = constant dense<[1.000000e+00, 5.000000e+00]> : tensor<2xf32>
+// CHECK:           %[[SparseI:.*]] = constant dense<{{\[\[}}0, 0], [1, 6]]> : tensor<2x2xi64>
+// CHECK:           %[[I1:.*]] = constant 1 : index
+// CHECK:           %[[I0:.*]] = constant 0 : index
+// CHECK:           %[[C2:.*]] = constant 2 : i32
+// CHECK:           %[[BaseD:.*]] = tensor.cast %[[Base]] : tensor<2xi8> to tensor<?xi8>
+// CHECK:           %[[DimsD:.*]] = tensor.cast %[[Dims]] : tensor<2xi64> to tensor<?xi64>
+// CHECK:           %[[OffsetD:.*]] = tensor.cast %[[Offset]] : tensor<2xi64> to tensor<?xi64>
+// CHECK:           %[[TCOO:.*]] = call @newSparseTensor(%[[BaseD]], %[[DimsD]], %[[OffsetD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[C2]], %{{.}})
+// CHECK:           %[[Index:.*]] = memref.alloca() : memref<2xindex>
+// CHECK:           %[[IndexD:.*]] = memref.cast %[[Index]] : memref<2xindex> to memref<?xindex>
+// CHECK:           scf.for %[[IV:.*]] = %[[I0]] to %[[I2]] step %[[I1]] {
+// CHECK:             %[[VAL0:.*]] = tensor.extract %[[SparseI]]{{\[}}%[[IV]], %[[I0]]] : tensor<2x2xi64>
+// CHECK:             %[[VAL1:.*]] = index_cast %[[VAL0]] : i64 to index
+// CHECK:             memref.store %[[VAL1]], %[[Index]]{{\[}}%[[I0]]] : memref<2xindex>
+// CHECK:             %[[VAL2:.*]] = tensor.extract %[[SparseI]]{{\[}}%[[IV]], %[[I1]]] : tensor<2x2xi64>
+// CHECK:             %[[VAL3:.*]] = index_cast %[[VAL2]] : i64 to index
+// CHECK:             memref.store %[[VAL3]], %[[Index]]{{\[}}%[[I1]]] : memref<2xindex>
+// CHECK:             %[[VAL4:.*]] = tensor.extract %[[SparseV]]{{\[}}%[[IV]]] : tensor<2xf32>
+// CHECK:             call @addEltF32(%[[TCOO]], %[[VAL4]], %[[IndexD]], %[[OffsetD]])
+// CHECK:           }
+// CHECK:           %[[T:.*]] = call @newSparseTensor(%[[BaseD]], %[[DimsD]], %[[OffsetD]], %{{.*}}, %{{.*}}, %[[C1]], %{{.*}})
+// CHECK:           return %[[T]] : !llvm.ptr<i8>
+func @entry() -> tensor<8x7xf32, #CSR>{
+  // Define an 8x7 sparse constant with nonzeros at [0, 0] and [1, 6].
+  %0 = constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
+  // Convert the sparse constant to a sparse tensor with the #CSR encoding.
+  %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #CSR>
+  return %1 : tensor<8x7xf32, #CSR>
+}
+
 // CHECK-LABEL: func @sparse_convert_3d(
 //  CHECK-SAME: %[[A:.*]]: tensor<?x?x?xf64>) -> !llvm.ptr<i8>
 //   CHECK-DAG: %[[C0:.*]] = constant 0 : index

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse-constant_to_sparse_tensor.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse-constant_to_sparse_tensor.mlir
new file mode 100644
index 0000000000000..9154b402635e0
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse-constant_to_sparse_tensor.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize  \
+// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#Tensor1  = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed"]
+}>
+
+//
+// Integration tests for conversions from sparse constants to sparse tensors.
+//
+module {
+  func @entry() {
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %c2 = constant 2 : index
+    %d0 = constant 0.0 : f64
+
+    // A 10x8 sparse constant with eight nonzeros, given in COO format.
+    %ti = constant sparse<[[0, 0], [0, 7], [1, 2], [4, 2], [5, 3], [6, 4], [6, 6], [9, 7]],
+                          [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : tensor<10x8xf64>
+
+    // Convert the tensor in COO format to a sparse tensor with annotation #Tensor1.
+    %ts = sparse_tensor.convert %ti : tensor<10x8xf64> to tensor<10x8xf64, #Tensor1>
+
+    // CHECK: ( 0, 1, 4, 5, 6, 9 )
+    %i0 = sparse_tensor.indices %ts, %c0 : tensor<10x8xf64, #Tensor1> to memref<?xindex>
+    %i0r = vector.transfer_read %i0[%c0], %c0: memref<?xindex>, vector<6xindex>
+    vector.print %i0r : vector<6xindex>
+
+    // CHECK: ( 0, 7, 2, 2, 3, 4, 6, 7 )
+    %i1 = sparse_tensor.indices %ts, %c1 : tensor<10x8xf64, #Tensor1> to memref<?xindex>
+    %i1r = vector.transfer_read %i1[%c0], %c0: memref<?xindex>, vector<8xindex>
+    vector.print %i1r : vector<8xindex>
+
+    // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 )
+    %v = sparse_tensor.values %ts : tensor<10x8xf64, #Tensor1> to memref<?xf64>
+    %vr = vector.transfer_read %v[%c0], %d0: memref<?xf64>, vector<8xf64>
+    vector.print %vr : vector<8xf64>
+
+    return
+  }
+}
+


        


More information about the Mlir-commits mailing list