[Mlir-commits] [mlir] 0f3e4d1 - [mlir][sparse] lower number of entries op to actual code

Aart Bik llvmlistbot at llvm.org
Fri Oct 21 10:48:50 PDT 2022


Author: Aart Bik
Date: 2022-10-21T10:48:37-07:00
New Revision: 0f3e4d1afaa1dc330c374b729269f2ff8422e8dd

URL: https://github.com/llvm/llvm-project/commit/0f3e4d1afaa1dc330c374b729269f2ff8422e8dd
DIFF: https://github.com/llvm/llvm-project/commit/0f3e4d1afaa1dc330c374b729269f2ff8422e8dd.diff

LOG: [mlir][sparse] lower number of entries op to actual code

Works along both the runtime path and the pure codegen path.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136389

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 1beb1271103b4..bf2f77d95e665 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -277,6 +277,12 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
   return forOp;
 }
 
+/// Translates a field index into its memSizes index (fields 0, 1 have none).
+static unsigned getMemSizesIndex(unsigned field) {
+  assert(2 <= field);
+  return field - 2;
+}
+
 /// Creates a pushback op for given field and updates the fields array
 /// accordingly.
 static void createPushback(OpBuilder &builder, Location loc,
@@ -286,9 +292,9 @@ static void createPushback(OpBuilder &builder, Location loc,
   Type etp = fields[field].getType().cast<ShapedType>().getElementType();
   if (value.getType() != etp)
     value = builder.create<arith::IndexCastOp>(loc, etp, value);
-  fields[field] =
-      builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
-                                 fields[field], value, APInt(64, field - 2));
+  fields[field] = builder.create<PushBackOp>(
+      loc, fields[field].getType(), fields[1], fields[field], value,
+      APInt(64, getMemSizesIndex(field)));
 }
 
 /// Generates insertion code.
@@ -739,6 +745,25 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
   }
 };
 
+/// Sparse codegen rule for the number of entries operator.
+class SparseNumberOfEntriesConverter
+    : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Load the used size of the last (values) field from memSizes (fields[1]).
+    auto tuple = getTuple(adaptor.getTensor());
+    auto fields = tuple.getInputs();
+    unsigned lastField = fields.size() - 1;
+    Value field =
+        constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -775,5 +800,6 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                SparseExpandConverter, SparseCompressConverter,
                SparseInsertConverter, SparseToPointersConverter,
                SparseToIndicesConverter, SparseToValuesConverter,
-               SparseConvertConverter>(typeConverter, patterns.getContext());
+               SparseConvertConverter, SparseNumberOfEntriesConverter>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 40112078572bb..c7c81767a4041 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -205,6 +205,15 @@ static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
   params.push_back(ptr);
 }
 
+/// Generates a typed `sparseValues` call that returns the values array.
+static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
+                           ValueRange ptr) {
+  SmallString<15> name{"sparseValues",
+                       primaryTypeFunctionSuffix(tp.getElementType())};
+  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+      .getResult(0);
+}
+
 /// Generates a call to release/delete a `SparseTensorCOO`.
 static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
                           Value coo) {
@@ -903,11 +912,28 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
   LogicalResult
   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
-    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          EmitCInterface::On);
+    auto resType = op.getType().cast<ShapedType>();
+    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
+                                         adaptor.getOperands()));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for the number of entries operator.
+class SparseNumberOfEntriesConverter
+    : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // The number of entries is the dim-0 size of the `sparseValues` result.
+    Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
+    auto resTp = MemRefType::get({ShapedType::kDynamicSize}, eltType);
+    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
+                                               constantIndex(rewriter, loc, 0));
     return success();
   }
 };
@@ -1250,9 +1276,10 @@ void mlir::populateSparseTensorConversionPatterns(
                SparseTensorConcatConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorToPointersConverter,
                SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-               SparseTensorLoadConverter, SparseTensorInsertConverter,
-               SparseTensorExpandConverter, SparseTensorCompressConverter,
-               SparseTensorOutConverter>(typeConverter, patterns.getContext());
+               SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
+               SparseTensorInsertConverter, SparseTensorExpandConverter,
+               SparseTensorCompressConverter, SparseTensorOutConverter>(
+      typeConverter, patterns.getContext());
 
   patterns.add<SparseTensorConvertConverter>(typeConverter,
                                              patterns.getContext(), options);

diff  --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 6b5c6c4ce3808..71f736d7263de 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -239,6 +239,20 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
   return %0 : memref<?xf64>
 }
 
+// CHECK-LABEL: func @sparse_noe(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+//       CHECK: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+  %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @sparse_dealloc_csr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 44fcd4219ec08..33b7d133fe849 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -268,6 +268,17 @@ func.func @sparse_valuesi8(%arg0: tensor<128xi8, #SparseVector>) -> memref<?xi8>
   return %0 : memref<?xi8>
 }
 
+// CHECK-LABEL: func @sparse_noe(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//   CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[T:.*]] = call @sparseValuesF64(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+//       CHECK: %[[NOE:.*]] = memref.dim %[[T]], %[[C]] : memref<?xf64>
+//       CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+  %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @sparse_reconstruct(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
 //       CHECK: return %[[A]] : !llvm.ptr<i8>

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
index 6fc68ed700fdb..b3a5bbdb8f54a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
@@ -46,6 +46,16 @@ module {
     %1 = tensor.extract %0[] : tensor<f32>
     vector.print %1 : f32
 
+    // Print number of entries in the sparse vectors.
+    //
+    // CHECK: 5
+    // CHECK: 3
+    //
+    %noe1 = sparse_tensor.number_of_entries %s1 : tensor<1024xf32, #SparseVector>
+    %noe2 = sparse_tensor.number_of_entries %s2 : tensor<1024xf32, #SparseVector>
+    vector.print %noe1 : index
+    vector.print %noe2 : index
+
     // Release the resources.
     bufferization.dealloc_tensor %s1 : tensor<1024xf32, #SparseVector>
     bufferization.dealloc_tensor %s2 : tensor<1024xf32, #SparseVector>


        


More information about the Mlir-commits mailing list