[Mlir-commits] [mlir] 4d06861 - [mlir][sparse] add "sort" to the compress op codegen
Aart Bik
llvmlistbot at llvm.org
Wed Sep 28 10:41:58 PDT 2022
Author: Aart Bik
Date: 2022-09-28T10:41:40-07:00
New Revision: 4d06861950978b223f6ebdacee071e8203d0911b
URL: https://github.com/llvm/llvm-project/commit/4d06861950978b223f6ebdacee071e8203d0911b
DIFF: https://github.com/llvm/llvm-project/commit/4d06861950978b223f6ebdacee071e8203d0911b.diff
LOG: [mlir][sparse] add "sort" to the compress op codegen
This revision also adds convenience methods to test the
dimension level type/property (with the codegen being the first client);
a small usage sketch follows below.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D134776
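
Rough sketch, not part of the commit: one way a codegen client could use the
new convenience methods, mirroring the per-dimension dispatch now used in
SparseTensorCodegen.cpp. The helper countStorageFields is hypothetical and
introduced only for illustration.

  #include <cassert>

  #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

  using namespace mlir;
  using namespace mlir::sparse_tensor;

  // Counts the per-dimension storage fields implied by the dim level types,
  // following the same dispatch as getFieldIndex() in the codegen.
  static unsigned countStorageFields(RankedTensorType rType) {
    unsigned fields = 0;
    for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
      if (isCompressedDim(rType, r))
        fields += 2; // pointer array + index array
      else if (isSingletonDim(rType, r))
        fields += 1; // index array only
      else
        assert(isDenseDim(rType, r)); // dense dimensions add no fields
    }
    return fields + 1; // plus the values array
  }

In the same vein, the new property tests let the compress codegen emit the
sparse_tensor.sort of the "added" index array only when isOrderedDim holds
for the innermost dimension, as shown in SparseTensorCodegen.cpp below.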
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index ec40a6c628883..575e5c5bcc8f2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -27,9 +27,43 @@
namespace mlir {
namespace sparse_tensor {
+
/// Convenience method to get a sparse encoding attribute from a type.
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
+
+//
+// Dimension level types.
+//
+
+bool isDenseDim(SparseTensorEncodingAttr::DimLevelType dltp);
+bool isCompressedDim(SparseTensorEncodingAttr::DimLevelType dltp);
+bool isSingletonDim(SparseTensorEncodingAttr::DimLevelType dltp);
+
+/// Convenience method to test for dense dimension (0 <= d < rank).
+bool isDenseDim(RankedTensorType type, uint64_t d);
+
+/// Convenience method to test for compressed dimension (0 <= d < rank).
+bool isCompressedDim(RankedTensorType type, uint64_t d);
+
+/// Convenience method to test for singleton dimension (0 <= d < rank).
+bool isSingletonDim(RankedTensorType type, uint64_t d);
+
+//
+// Dimension level properties.
+//
+
+bool isOrderedDim(SparseTensorEncodingAttr::DimLevelType dltp);
+bool isUniqueDim(SparseTensorEncodingAttr::DimLevelType dltp);
+
+/// Convenience method to test for ordered property in the
+/// given dimension (0 <= d < rank).
+bool isOrderedDim(RankedTensorType type, uint64_t d);
+
+/// Convenience method to test for unique property in the
+/// given dimension (0 <= d < rank).
+bool isUniqueDim(RankedTensorType type, uint64_t d);
+
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 9b625483d4daa..3997f4f2cf9a9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -216,6 +216,10 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
return success();
}
+//===----------------------------------------------------------------------===//
+// Convenience Methods.
+//===----------------------------------------------------------------------===//
+
SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
if (auto ttp = type.dyn_cast<RankedTensorType>())
@@ -223,6 +227,98 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
return nullptr;
}
+bool mlir::sparse_tensor::isDenseDim(
+ SparseTensorEncodingAttr::DimLevelType dltp) {
+ return dltp == SparseTensorEncodingAttr::DimLevelType::Dense;
+}
+
+bool mlir::sparse_tensor::isCompressedDim(
+ SparseTensorEncodingAttr::DimLevelType dltp) {
+ switch (dltp) {
+ case SparseTensorEncodingAttr::DimLevelType::Compressed:
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool mlir::sparse_tensor::isSingletonDim(
+ SparseTensorEncodingAttr::DimLevelType dltp) {
+ switch (dltp) {
+ case SparseTensorEncodingAttr::DimLevelType::Singleton:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool mlir::sparse_tensor::isDenseDim(RankedTensorType type, uint64_t d) {
+ assert(d < static_cast<uint64_t>(type.getRank()));
+ if (auto enc = getSparseTensorEncoding(type))
+ return isDenseDim(enc.getDimLevelType()[d]);
+ return true; // unannotated tensor is dense
+}
+
+bool mlir::sparse_tensor::isCompressedDim(RankedTensorType type, uint64_t d) {
+ assert(d < static_cast<uint64_t>(type.getRank()));
+ if (auto enc = getSparseTensorEncoding(type))
+ return isCompressedDim(enc.getDimLevelType()[d]);
+ return false; // unannotated tensor is dense
+}
+
+bool mlir::sparse_tensor::isSingletonDim(RankedTensorType type, uint64_t d) {
+ assert(d < static_cast<uint64_t>(type.getRank()));
+ if (auto enc = getSparseTensorEncoding(type))
+ return isSingletonDim(enc.getDimLevelType()[d]);
+ return false; // unannotated tensor is dense
+}
+
+bool mlir::sparse_tensor::isOrderedDim(
+ SparseTensorEncodingAttr::DimLevelType dltp) {
+ switch (dltp) {
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+ return false;
+ default:
+ return true;
+ }
+}
+
+bool mlir::sparse_tensor::isUniqueDim(
+ SparseTensorEncodingAttr::DimLevelType dltp) {
+ switch (dltp) {
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+ return false;
+ default:
+ return true;
+ }
+}
+
+bool mlir::sparse_tensor::isOrderedDim(RankedTensorType type, uint64_t d) {
+ assert(d < static_cast<uint64_t>(type.getRank()));
+ if (auto enc = getSparseTensorEncoding(type))
+ return isOrderedDim(enc.getDimLevelType()[d]);
+ return true; // unannotated tensor is dense (and thus ordered)
+}
+
+bool mlir::sparse_tensor::isUniqueDim(RankedTensorType type, uint64_t d) {
+ assert(d < static_cast<uint64_t>(type.getRank()));
+ if (auto enc = getSparseTensorEncoding(type))
+ return isUniqueDim(enc.getDimLevelType()[d]);
+ return true; // unannotated tensor is dense (and thus unique)
+}
+
//===----------------------------------------------------------------------===//
// TensorDialect Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 7dcc81c3c3ee4..0dc3e9ed3f3dd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -103,37 +103,28 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
/// Returns field index of sparse tensor type for pointers/indices, when set.
static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
- auto enc = getSparseTensorEncoding(type);
- assert(enc);
+ assert(getSparseTensorEncoding(type));
RankedTensorType rType = type.cast<RankedTensorType>();
unsigned field = 2; // start past sizes
unsigned ptr = 0;
unsigned idx = 0;
for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
- switch (enc.getDimLevelType()[r]) {
- case SparseTensorEncodingAttr::DimLevelType::Dense:
- break; // no fields
- case SparseTensorEncodingAttr::DimLevelType::Compressed:
- case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
- case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
- case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+ if (isCompressedDim(rType, r)) {
if (ptr++ == ptrDim)
return field;
field++;
if (idx++ == idxDim)
return field;
field++;
- break;
- case SparseTensorEncodingAttr::DimLevelType::Singleton:
- case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
- case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
- case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+ } else if (isSingletonDim(rType, r)) {
if (idx++ == idxDim)
return field;
field++;
- break;
+ } else {
+ assert(isDenseDim(rType, r)); // no fields
}
}
+ assert(ptrDim == -1u && idxDim == -1u);
return field + 1; // return values field index
}
@@ -176,7 +167,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
// The dimSizes array.
fields.push_back(MemRefType::get({rank}, indexType));
// The memSizes array.
- unsigned lastField = getFieldIndex(type, -1, -1);
+ unsigned lastField = getFieldIndex(type, -1u, -1u);
fields.push_back(MemRefType::get({lastField - 2}, indexType));
// Per-dimension storage.
for (unsigned r = 0; r < rank; r++) {
@@ -184,22 +175,13 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
// As a result, the compound type can be constructed directly in the given
// order. Clients of this type know what field is what from the sparse
// tensor type.
- switch (enc.getDimLevelType()[r]) {
- case SparseTensorEncodingAttr::DimLevelType::Dense:
- break; // no fields
- case SparseTensorEncodingAttr::DimLevelType::Compressed:
- case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
- case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
- case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+ if (isCompressedDim(rType, r)) {
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
- break;
- case SparseTensorEncodingAttr::DimLevelType::Singleton:
- case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
- case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
- case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+ } else if (isSingletonDim(rType, r)) {
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
- break;
+ } else {
+ assert(isDenseDim(rType, r)); // no fields
}
}
// The values array.
@@ -254,7 +236,7 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
fields.push_back(dimSizes);
// The sizes array.
- unsigned lastField = getFieldIndex(type, -1, -1);
+ unsigned lastField = getFieldIndex(type, -1u, -1u);
Value memSizes = builder.create<memref::AllocOp>(
loc, MemRefType::get({lastField - 2}, indexType));
fields.push_back(memSizes);
@@ -265,25 +247,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
builder.create<memref::StoreOp>(loc, sizes[ro], dimSizes,
constantIndex(builder, loc, r));
linear = builder.create<arith::MulIOp>(loc, linear, sizes[ro]);
- // Allocate fiels.
- switch (enc.getDimLevelType()[r]) {
- case SparseTensorEncodingAttr::DimLevelType::Dense:
- break; // no fields
- case SparseTensorEncodingAttr::DimLevelType::Compressed:
- case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
- case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
- case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+ // Allocate fields.
+ if (isCompressedDim(rType, r)) {
fields.push_back(createAllocation(builder, loc, ptrType, heuristic));
fields.push_back(createAllocation(builder, loc, idxType, heuristic));
allDense = false;
- break;
- case SparseTensorEncodingAttr::DimLevelType::Singleton:
- case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
- case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
- case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+ } else if (isSingletonDim(rType, r)) {
fields.push_back(createAllocation(builder, loc, idxType, heuristic));
allDense = false;
- break;
+ } else {
+ assert(isDenseDim(rType, r)); // no fields
}
}
// The values array. For all-dense, the full length is required.
@@ -507,7 +480,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
- ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+ RankedTensorType srcType =
+ op.getTensor().getType().cast<RankedTensorType>();
Type eltType = srcType.getElementType();
Type boolType = rewriter.getIntegerType(1);
Type idxType = rewriter.getIndexType();
@@ -561,17 +535,18 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
- ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
- Type eltType = srcType.getElementType();
+ RankedTensorType dstType =
+ op.getTensor().getType().cast<RankedTensorType>();
+ Type eltType = dstType.getElementType();
Value values = adaptor.getValues();
Value filled = adaptor.getFilled();
Value added = adaptor.getAdded();
Value count = adaptor.getCount();
-
- //
- // TODO: need to implement "std::sort(added, added + count);" for ordered
- //
-
+ // If the innermost dimension is ordered, we need to sort the indices
+ // in the "added" array prior to applying the compression.
+ unsigned rank = dstType.getShape().size();
+ if (isOrderedDim(dstType, rank - 1))
+ rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{});
// While performing the insertions, we also need to reset the elements
// of the values/filled-switch by only iterating over the set elements,
// to ensure that the runtime complexity remains proportional to the
@@ -699,7 +674,7 @@ class SparseToPointersConverter
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
ToPointersOp op) {
uint64_t dim = op.getDimension().getZExtValue();
- return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1);
+ return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u);
}
};
@@ -712,7 +687,7 @@ class SparseToIndicesConverter
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
ToIndicesOp op) {
uint64_t dim = op.getDimension().getZExtValue();
- return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/dim);
+ return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim);
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 6f48e134d4770..a008b38834fb9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -156,8 +156,9 @@ struct SparseTensorCodegenPass
RewritePatternSet patterns(ctx);
SparseTensorTypeToBufferConverter converter;
ConversionTarget target(*ctx);
- // Everything in the sparse dialect must go!
+ // Most ops in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
+ target.addLegalOp<SortOp>();
// All dynamic rules below accept new function, call, return, and various
// tensor and bufferization operations as legal output of the rewriting
// provided that all sparse tensor types have been fully rewritten.
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index baef485b82384..1263d3efc3c95 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -24,6 +24,10 @@
pointerBitWidth = 32
}>
+#UCSR = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed-no" ]
+}>
+
#CSC = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed" ],
dimOrdering = affine_map<(i, j) -> (j, i)>
@@ -363,7 +367,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// TODO: sort
+// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
// TODO: insert
@@ -385,6 +389,43 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
return
}
+// CHECK-LABEL: func @sparse_compression_unordered(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[A8:.*8]]: index,
+// CHECK-SAME: %[[A9:.*9]]: index)
+// CHECK-DAG: %[[B0:.*]] = arith.constant false
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NOT: sparse_tensor.sort
+// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
+// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+// TODO: insert
+// CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
+// CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
+// CHECK-NEXT: }
+// CHECK-DAG: memref.dealloc %[[A5]] : memref<?xf64>
+// CHECK-DAG: memref.dealloc %[[A6]] : memref<?xi1>
+// CHECK-DAG: memref.dealloc %[[A7]] : memref<?xindex>
+// CHECK: return
+func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
+ %values: memref<?xf64>,
+ %filled: memref<?xi1>,
+ %added: memref<?xindex>,
+ %count: index,
+ %i: index) {
+ sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
+ return
+}
+
// CHECK-LABEL: func @sparse_push_back(
// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,