[Mlir-commits] [mlir] a41672e - [mlir][sparse] implement lowering rules for sparse_tensor.pack operation
Peiming Liu
llvmlistbot at llvm.org
Fri Feb 3 15:51:42 PST 2023
Author: Peiming Liu
Date: 2023-02-03T23:51:36Z
New Revision: a41672e16a25a8180a9b9edcb3f505ab2d5cace8
URL: https://github.com/llvm/llvm-project/commit/a41672e16a25a8180a9b9edcb3f505ab2d5cace8
DIFF: https://github.com/llvm/llvm-project/commit/a41672e16a25a8180a9b9edcb3f505ab2d5cace8.diff
LOG: [mlir][sparse] implement lowering rules for sparse_tensor.pack operation
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D143230
Added:
mlir/test/Dialect/SparseTensor/sparse_pack.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 1fba7d8d74f73..9f3388d61f3f2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1021,6 +1021,98 @@ class SparseNumberOfEntriesConverter
}
};
+struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(PackOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto rtp = op.getResult().getType().cast<RankedTensorType>();
+ assert(isUniqueCOOType(rtp));
+
+ SmallVector<Value> fields;
+ Location loc = op.getLoc();
+
+ foreachFieldAndTypeInSparseTensor(
+ rtp,
+ [&rewriter, &fields, &op, rtp,
+ loc](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
+ unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
+ assert(fields.size() == fIdx);
+ auto enc = getSparseTensorEncoding(rtp);
+ Value field;
+ switch (fKind) {
+ case SparseTensorFieldKind::StorageSpec:
+ field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp);
+ break;
+ case SparseTensorFieldKind::PtrMemRef: {
+ // TACO-style COO starts with a PtrBuffer.
+ // By creating a constant value for it, we avoid the complexity of
+ // memory management.
+ auto tensorType = RankedTensorType::get({2}, enc.getPointerType());
+ auto memrefType = MemRefType::get(tensorType.getShape(),
+ tensorType.getElementType());
+ auto cstPtr = rewriter.create<arith::ConstantOp>(
+ loc, tensorType,
+ DenseElementsAttr::get(
+ tensorType,
+ {APInt(64, 0),
+ APInt(64, op.getData().getType().getShape()[0])}));
+ field = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType,
+ cstPtr);
+ break;
+ }
+ case SparseTensorFieldKind::IdxMemRef: {
+ auto tensorType = op.getIndices().getType();
+ auto memrefType = MemRefType::get(tensorType.getShape(),
+ tensorType.getElementType());
+ auto idxMemRef = rewriter.create<bufferization::ToMemrefOp>(
+ op->getLoc(), memrefType, op.getIndices());
+ ReassociationIndices reassociation;
+ for (int i = 0, e = tensorType.getRank(); i < e; i++)
+ reassociation.push_back(i);
+
+ // Flatten the indices buffer to rank 1.
+ field = rewriter.create<memref::CollapseShapeOp>(
+ loc, idxMemRef, ArrayRef<ReassociationIndices>(reassociation));
+ break;
+ }
+ case SparseTensorFieldKind::ValMemRef: {
+ auto tensorType = op.getData().getType();
+ auto memrefType = MemRefType::get(tensorType.getShape(),
+ tensorType.getElementType());
+ field = rewriter.create<bufferization::ToMemrefOp>(
+ op->getLoc(), memrefType, op.getData());
+ break;
+ }
+ }
+
+ assert(field);
+ if (fType != field.getType())
+ field = rewriter.create<memref::CastOp>(loc, fType, field);
+ fields.push_back(field);
+ // Returns true to continue the iteration.
+ return true;
+ });
+
+ MutSparseTensorDescriptor desc(rtp, fields);
+ auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getData(), 0);
+ for (unsigned i = 0, e = rtp.getRank(); i < e; i++) {
+ int dim = rtp.getShape()[i];
+ assert(!ShapedType::isDynamic(dim));
+ desc.setDimSize(rewriter, loc, i, constantIndex(rewriter, loc, dim));
+ if (i == 0)
+ desc.setPtrMemSize(rewriter, loc, i, constantIndex(rewriter, loc, 2));
+
+ desc.setIdxMemSize(rewriter, loc, i, noe);
+ }
+ desc.setValMemSize(rewriter, loc, noe);
+
+ rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1032,14 +1124,15 @@ class SparseNumberOfEntriesConverter
void mlir::populateSparseTensorCodegenPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
bool enableBufferInitialization) {
- patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
- SparseCastConverter, SparseTensorDeallocConverter,
- SparseTensorLoadConverter, SparseExpandConverter,
- SparseCompressConverter, SparseInsertConverter,
- SparseToPointersConverter, SparseToIndicesConverter,
- SparseToIndicesBufferConverter, SparseToValuesConverter,
- SparseConvertConverter, SparseNumberOfEntriesConverter>(
- typeConverter, patterns.getContext());
+ patterns.add<SparsePackOpConverter, SparseReturnConverter,
+ SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
+ SparseTensorDeallocConverter, SparseTensorLoadConverter,
+ SparseExpandConverter, SparseCompressConverter,
+ SparseInsertConverter, SparseToPointersConverter,
+ SparseToIndicesConverter, SparseToIndicesBufferConverter,
+ SparseToValuesConverter, SparseConvertConverter,
+ SparseNumberOfEntriesConverter>(typeConverter,
+ patterns.getContext());
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 996485642c961..193c227171368 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -396,6 +396,19 @@ class MutSparseTensorDescriptor
fields.back() = md;
}
+ void setValMemSize(OpBuilder &builder, Location loc, Value v) {
+ setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+ std::nullopt, v);
+ }
+
+ void setIdxMemSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
+ setSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, dim, v);
+ }
+
+ void setPtrMemSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
+ setSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, dim, v);
+ }
+
void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
new file mode 100644
index 0000000000000..6dbf258c7e7bd
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed-nu", "singleton"],
+ indexBitWidth=32
+}>
+
+// CHECK-LABEL: func.func @sparse_pack(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xf64>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<6x2xi32>) -> (memref<?xindex>, memref<?xi32>, memref<?xf64>,
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<[0, 6]> : tensor<2xindex>
+// CHECK: %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2xindex>
+// CHECK: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
+// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
+// CHECK: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init :
+// CHECK: %[[VAL_11:.*]] = arith.constant 6 : index
+// CHECK: %[[VAL_12:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i32
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_13]] : i32,
+// CHECK: %[[VAL_15:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[VAL_15]] : index to i32
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] ptr_mem_sz at 0 with %[[VAL_16]] : i32,
+// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_11]] : index to i32
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] idx_mem_sz at 0 with %[[VAL_18]] : i32,
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] dim_sz at 1 with %[[VAL_13]] : i32,
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] idx_mem_sz at 1 with %[[VAL_18]] : i32,
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] val_mem_sz with %[[VAL_18]] : i32,
+// CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xi32>, memref<?xf64>,
+// CHECK: }
+func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
+ -> tensor<100x100xf64, #COO> {
+ %0 = sparse_tensor.pack %data, %index : tensor<6xf64>, tensor<6x2xi32>
+ to tensor<100x100xf64, #COO>
+ return %0 : tensor<100x100xf64, #COO>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
new file mode 100644
index 0000000000000..a7e0f210dfd11
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -0,0 +1,56 @@
+// DEFINE: %{option} = enable-runtime-library=false
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+
+// TODO: Pack only supports the CodeGen path.
+
+#SortedCOO = #sparse_tensor.encoding<{
+ dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+module {
+ //
+ // Main driver.
+ //
+ func.func @entry() {
+ //
+ // Initialize the dense data and index tensors.
+ //
+ %data = arith.constant dense<
+ [ 1.0, 2.0, 3.0]
+ > : tensor<3xf64>
+
+ %index = arith.constant dense<
+ [[ 1, 2],
+ [ 5, 6],
+ [ 7, 8]]
+ > : tensor<3x2xindex>
+
+ %s4 = sparse_tensor.pack %data, %index : tensor<3xf64>, tensor<3x2xindex>
+ to tensor<10x10xf64, #SortedCOO>
+ // CHECK:1
+ // CHECK-NEXT:2
+ // CHECK-NEXT:1
+ //
+ // CHECK-NEXT:5
+ // CHECK-NEXT:6
+ // CHECK-NEXT:2
+ //
+ // CHECK-NEXT:7
+ // CHECK-NEXT:8
+ // CHECK-NEXT:3
+ sparse_tensor.foreach in %s4 : tensor<10x10xf64, #SortedCOO> do {
+ ^bb0(%1: index, %2: index, %v: f64) :
+ vector.print %1: index
+ vector.print %2: index
+ vector.print %v: f64
+ }
+ return
+ }
+}
More information about the Mlir-commits
mailing list