[Mlir-commits] [mlir] 3986c86 - [mlir][sparse] partially implement codegen for sparse_tensor.compress
Aart Bik
llvmlistbot at llvm.org
Thu Sep 15 10:32:47 PDT 2022
Author: Aart Bik
Date: 2022-09-15T10:32:33-07:00
New Revision: 3986c8698622c447fd814378292555cb5316cc10
URL: https://github.com/llvm/llvm-project/commit/3986c8698622c447fd814378292555cb5316cc10
DIFF: https://github.com/llvm/llvm-project/commit/3986c8698622c447fd814378292555cb5316cc10.diff
LOG: [mlir][sparse] partially implement codegen for sparse_tensor.compress
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D133912
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d5c6d8a276728..0ad44b56e6878 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -302,6 +302,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
assert(fields.size() == lastField);
}
+/// Creates a straightforward counting for-loop "for (i = 0; i < count; i++)" (index-typed bounds/step, no loop-carried values) and leaves the builder's insertion point at the start of the loop body.
+static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
+ Type indexType = builder.getIndexType(); // lower bound and step are index constants; `count` is presumably index-typed too (TODO confirm callers)
+ Value zero = constantZero(builder, loc, indexType);
+ Value one = constantOne(builder, loc, indexType);
+ scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one);
+ builder.setInsertionPointToStart(forOp.getBody()); // side effect: callers emit the body next and must reposition the builder to emit code after the loop
+ return forOp;
+}
+
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
@@ -518,12 +528,12 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{*sz});
};
- // Allocate temporary buffers for values, filled-switch, and indices.
+ // Allocate temporary buffers for values/filled-switch and added.
// We do not use stack buffers for this, since the expanded size may
// be rather large (as it envelops a single expanded dense dimension).
Value values = genAlloc(eltType);
Value filled = genAlloc(boolType);
- Value indices = genAlloc(idxType);
+ Value added = genAlloc(idxType);
Value zero = constantZero(rewriter, loc, idxType);
// Reset the values/filled-switch to all-zero/false. Note that this
// introduces an O(N) operation into the computation, but this reset
@@ -538,7 +548,67 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
ValueRange{filled});
// Replace expansion op with these buffers and initial index.
assert(op.getNumResults() == 4);
- rewriter.replaceOp(op, {values, filled, indices, zero});
+ rewriter.replaceOp(op, {values, filled, added, zero});
+ return success();
+ }
+};
+
+/// Sparse codegen rule for the compress operator (partial: emits the reset loop over `added` and the buffer deallocations; sorting and insertion are still TODO).
+class SparseCompressConverter : public OpConversionPattern<CompressOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(CompressOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ ShapedType srcType = op.getTensor().getType().cast<ShapedType>(); // only needed for the element type of the zero stored below
+ Type eltType = srcType.getElementType();
+ Value values = adaptor.getValues();
+ Value filled = adaptor.getFilled();
+ Value added = adaptor.getAdded();
+ Value count = adaptor.getCount();
+
+ //
+ // TODO: need to implement "std::sort(added, added + count);" for ordered storage (presumably the innermost dimension), so that insertions happen in increasing index order
+ //
+
+ // While performing the insertions, we also need to reset the elements
+ // of the values/filled-switch by only iterating over the set elements,
+ // to ensure that the runtime complexity remains proportional to the
+ // sparsity of the expanded access pattern.
+ //
+ // Generate
+ // for (i = 0; i < count; i++) {
+ // index = added[i];
+ // value = values[index];
+ //
+ // TODO: insert prev_indices, index, value
+ //
+ // values[index] = 0;
+ // filled[index] = false;
+ // }
+ Value i = createFor(rewriter, loc, count).getInductionVar(); // createFor leaves the insertion point inside the loop body, so the ops below form the body
+ Value index = rewriter.create<memref::LoadOp>(loc, added, i);
+ rewriter.create<memref::LoadOp>(loc, values, index); // value = values[index]; result deliberately unused until the insertion TODO is implemented
+ // TODO: insert
+ rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
+ values, index);
+ rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
+ filled, index);
+
+ // Deallocate the buffers on exit of the full loop nest: climb through directly enclosing scf.for/while/parallel/if ops and insert after the outermost one.
+ Operation *parent = op;
+ for (; isa<scf::ForOp>(parent->getParentOp()) ||
+ isa<scf::WhileOp>(parent->getParentOp()) ||
+ isa<scf::ParallelOp>(parent->getParentOp()) ||
+ isa<scf::IfOp>(parent->getParentOp());
+ parent = parent->getParentOp())
+ ; // empty body: the loop header alone walks `parent` upward
+ rewriter.setInsertionPointAfter(parent); // NOTE(review): assumes the buffers were allocated outside this loop nest; otherwise they would be allocated per iteration but freed once — confirm
+ rewriter.create<memref::DeallocOp>(loc, values);
+ rewriter.create<memref::DeallocOp>(loc, filled);
+ rewriter.create<memref::DeallocOp>(loc, added);
+ rewriter.eraseOp(op); // the compress op is replaced by the code emitted above
return success();
}
};
@@ -626,7 +696,7 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
SparseCastConverter, SparseTensorAllocConverter,
SparseTensorDeallocConverter, SparseTensorLoadConverter,
- SparseExpandConverter, SparseToPointersConverter,
- SparseToIndicesConverter, SparseToValuesConverter>(
- typeConverter, patterns.getContext());
+ SparseExpandConverter, SparseCompressConverter,
+ SparseToPointersConverter, SparseToIndicesConverter,
+ SparseToValuesConverter>(typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 6a8a3ca7a56a5..7e241de598069 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -347,3 +347,40 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
: tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
return %added : memref<?xindex>
}
+
+// CHECK-LABEL: func @sparse_compression(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: memref<?xindex>,
// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
// CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
// CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
// CHECK-SAME: %[[A9:.*9]]: index)
// CHECK-DAG: %[[B0:.*]] = arith.constant false
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// TODO: sort (no sort of the added indices is generated yet, so nothing is verified here)
// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] {
// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
// TODO: insert (insertion into the sparse tensor is not generated yet, so nothing is verified here)
// CHECK-DAG: memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref<?xf64>
// CHECK-DAG: memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref<?xi1>
// CHECK-NEXT: }
// CHECK-DAG: memref.dealloc %[[A6]] : memref<?xf64>
// CHECK-DAG: memref.dealloc %[[A7]] : memref<?xi1>
// CHECK-DAG: memref.dealloc %[[A8]] : memref<?xindex>
// CHECK: return
func.func @sparse_compression(%arg0: tensor<8x8xf64, #CSR>,
+ %arg1: memref<?xindex>, // A5; presumably the indices of the outer dimensions (not consumed by the lowering yet)
+ %arg2: memref<?xf64>, // A6: expanded values, reset to zero per added index
+ %arg3: memref<?xi1>, // A7: filled switch, reset to false per added index
+ %arg4: memref<?xindex>, // A8: added indices, iterated by the loop
+ %arg5: index) { // A9: number of added indices, the loop upper bound
+ sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
+ : tensor<8x8xf64, #CSR>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+ return
+}
More information about the Mlir-commits
mailing list