[Mlir-commits] [mlir] 3986c86 - [mlir][sparse] partially implement codegen for sparse_tensor.compress

Aart Bik llvmlistbot at llvm.org
Thu Sep 15 10:32:47 PDT 2022


Author: Aart Bik
Date: 2022-09-15T10:32:33-07:00
New Revision: 3986c8698622c447fd814378292555cb5316cc10

URL: https://github.com/llvm/llvm-project/commit/3986c8698622c447fd814378292555cb5316cc10
DIFF: https://github.com/llvm/llvm-project/commit/3986c8698622c447fd814378292555cb5316cc10.diff

LOG: [mlir][sparse] partially implement codegen for sparse_tensor.compress

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D133912

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d5c6d8a276728..0ad44b56e6878 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -302,6 +302,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
   assert(fields.size() == lastField);
 }
 
+/// Creates a straightforward counting for-loop.
+static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
+  Type indexType = builder.getIndexType();
+  Value zero = constantZero(builder, loc, indexType);
+  Value one = constantOne(builder, loc, indexType);
+  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one);
+  builder.setInsertionPointToStart(forOp.getBody());
+  return forOp;
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -518,12 +528,12 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
       auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
       return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{*sz});
     };
-    // Allocate temporary buffers for values, filled-switch, and indices.
+    // Allocate temporary buffers for values/filled-switch and added.
     // We do not use stack buffers for this, since the expanded size may
     // be rather large (as it envelops a single expanded dense dimension).
     Value values = genAlloc(eltType);
     Value filled = genAlloc(boolType);
-    Value indices = genAlloc(idxType);
+    Value added = genAlloc(idxType);
     Value zero = constantZero(rewriter, loc, idxType);
     // Reset the values/filled-switch to all-zero/false. Note that this
     // introduces an O(N) operation into the computation, but this reset
@@ -538,7 +548,67 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
         ValueRange{filled});
     // Replace expansion op with these buffers and initial index.
     assert(op.getNumResults() == 4);
-    rewriter.replaceOp(op, {values, filled, indices, zero});
+    rewriter.replaceOp(op, {values, filled, added, zero});
+    return success();
+  }
+};
+
+/// Sparse codegen rule for the compress operator.
+class SparseCompressConverter : public OpConversionPattern<CompressOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+    Type eltType = srcType.getElementType();
+    Value values = adaptor.getValues();
+    Value filled = adaptor.getFilled();
+    Value added = adaptor.getAdded();
+    Value count = adaptor.getCount();
+
+    //
+    // TODO: need to implement "std::sort(added, added + count);" for ordered
+    //
+
+    // While performing the insertions, we also need to reset the elements
+    // of the values/filled-switch by only iterating over the set elements,
+    // to ensure that the runtime complexity remains proportional to the
+    // sparsity of the expanded access pattern.
+    //
+    // Generate
+    //    for (i = 0; i < count; i++) {
+    //      index = added[i];
+    //      value = values[index];
+    //
+    //      TODO: insert prev_indices, index, value
+    //
+    //      values[index] = 0;
+    //      filled[index] = false;
+    //    }
+    Value i = createFor(rewriter, loc, count).getInductionVar();
+    Value index = rewriter.create<memref::LoadOp>(loc, added, i);
+    rewriter.create<memref::LoadOp>(loc, values, index);
+    // TODO: insert
+    rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
+                                     values, index);
+    rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
+                                     filled, index);
+
+    // Deallocate the buffers on exit of the full loop nest.
+    Operation *parent = op;
+    for (; isa<scf::ForOp>(parent->getParentOp()) ||
+           isa<scf::WhileOp>(parent->getParentOp()) ||
+           isa<scf::ParallelOp>(parent->getParentOp()) ||
+           isa<scf::IfOp>(parent->getParentOp());
+         parent = parent->getParentOp())
+      ;
+    rewriter.setInsertionPointAfter(parent);
+    rewriter.create<memref::DeallocOp>(loc, values);
+    rewriter.create<memref::DeallocOp>(loc, filled);
+    rewriter.create<memref::DeallocOp>(loc, added);
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -626,7 +696,7 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
   patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
                SparseCastConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorLoadConverter,
-               SparseExpandConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter>(
-      typeConverter, patterns.getContext());
+               SparseExpandConverter, SparseCompressConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter>(typeConverter, patterns.getContext());
 }

diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 6a8a3ca7a56a5..7e241de598069 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -347,3 +347,40 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
     : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return %added : memref<?xindex>
 }
+
+// CHECK-LABEL: func @sparse_compression(
+//  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xindex>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
+//  CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
+//  CHECK-SAME: %[[A9:.*9]]: index)
+//   CHECK-DAG: %[[B0:.*]] = arith.constant false
+//   CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//        TODO: sort
+//  CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] {
+//  CHECK-NEXT:   %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
+//        TODO:   insert
+//   CHECK-DAG:   memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref<?xf64>
+//   CHECK-DAG:   memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref<?xi1>
+//  CHECK-NEXT: }
+//   CHECK-DAG: memref.dealloc %[[A6]] : memref<?xf64>
+//   CHECK-DAG: memref.dealloc %[[A7]] : memref<?xi1>
+//   CHECK-DAG: memref.dealloc %[[A8]] : memref<?xindex>
+//       CHECK: return
+func.func @sparse_compression(%arg0: tensor<8x8xf64, #CSR>,
+                              %arg1: memref<?xindex>,
+                              %arg2: memref<?xf64>,
+                              %arg3: memref<?xi1>,
+                              %arg4: memref<?xindex>,
+                              %arg5: index) {
+  sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
+    : tensor<8x8xf64, #CSR>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return
+}


        


More information about the Mlir-commits mailing list