[Mlir-commits] [mlir] 1be0949 - [mlir][sparse] improved tensor type lowering
Aart Bik
llvmlistbot at llvm.org
Thu Sep 1 09:24:33 PDT 2022
Author: Aart Bik
Date: 2022-09-01T09:24:20-07:00
New Revision: 1be09496bfd5fb764b1b2b3e62ca1c16e3180223
URL: https://github.com/llvm/llvm-project/commit/1be09496bfd5fb764b1b2b3e62ca1c16e3180223
DIFF: https://github.com/llvm/llvm-project/commit/1be09496bfd5fb764b1b2b3e62ca1c16e3180223.diff
LOG: [mlir][sparse] improved tensor type lowering
Also includes a first codegen example (although full support needs tuple access)
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D133080
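
For context, here is a minimal sketch of what the new type conversion produces for the one-dimensional sparse vector exercised by the sparse_nop test below. The encoding body is reconstructed from the CHECK lines in the test diff (only pointerBitWidth = 32 is visible there; the dimLevelType and indexBitWidth values are inferred assumptions):

  #SparseVector = #sparse_tensor.encoding<{
    dimLevelType = [ "compressed" ],
    indexBitWidth = 64,
    pointerBitWidth = 32
  }>

  // tensor<?xf64, #SparseVector> lowers to a single compound tuple:
  //   tuple<memref<1xindex>,  // dimSizes (kept: the shape is dynamic)
  //         memref<?xi64>,    // indices for the compressed dimension
  //         memref<?xi32>,    // pointers for the compressed dimension
  //         memref<?xf64>>    // values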
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index b905b442f0975..d82ebea74d205 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -33,6 +33,16 @@ namespace {
// Helper methods.
//===----------------------------------------------------------------------===//
+/// Reorders stored dimension to logical dimension.
+static unsigned reorder(const SparseTensorEncodingAttr &enc, unsigned d) {
+ auto order = enc.getDimOrdering();
+ if (order) {
+ assert(order.isPermutation());
+ return order.getDimPosition(d);
+ }
+ return d;
+}
+
/// Maps a sparse tensor type to the appropriate compounded buffers.
static Optional<Type> convertSparseTensorType(Type type) {
auto enc = getSparseTensorEncoding(type);
@@ -47,12 +57,14 @@ static Optional<Type> convertSparseTensorType(Type type) {
Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
Type eltType = rType.getElementType();
+ ArrayRef<int64_t> shape = rType.getShape();
//
// Sparse tensor storage for rank-dimensional tensor is organized as a
// single compound type with the following fields:
//
// struct {
- // memref<rank x index> dimSize ; size in each dimension
+ // ; if dynamic shape:
+ // memref<rank x index> dimSize ; size in each dimension
// ; per-dimension d:
// ; if dense:
// <nothing>
@@ -61,23 +73,31 @@ static Optional<Type> convertSparseTensorType(Type type) {
// memref<? x ptr> pointers-d ; pointers for sparse dim d
// ; if singleton:
// memref<? x idx> indices-d ; indices for singleton dim d
- // memref<? x eltType> values ; values
+ // memref<? x eltType> values ; values
// };
//
- // TODO: fill in the ? when statically known
- //
- // TODO: emit dimSizes when not needed (e.g. all-dense)
- //
+ int64_t linear = 1;
+ bool allDense = true;
unsigned rank = rType.getShape().size();
SmallVector<Type, 8> fields;
- fields.push_back(MemRefType::get({rank}, indexType));
+ // The dimSizes array.
+ if (!rType.hasStaticShape())
+ fields.push_back(MemRefType::get({rank}, indexType));
+ // Per-dimension storage.
for (unsigned r = 0; r < rank; r++) {
+ // Get the original dimension (ro) for the current stored dimension (r).
+ unsigned ro = reorder(enc, r);
// Dimension level types apply in order to the reordered dimension.
// As a result, the compound type can be constructed directly in the given
// order. Clients of this type know what field is what from the sparse
// tensor type.
switch (enc.getDimLevelType()[r]) {
case SparseTensorEncodingAttr::DimLevelType::Dense:
+ // Linearize the size of consecutive dense dimensions.
+ if (ShapedType::isDynamic(shape[ro]) || ShapedType::isDynamic(linear))
+ linear = ShapedType::kDynamicSize;
+ else
+ linear *= shape[ro];
break;
case SparseTensorEncodingAttr::DimLevelType::Compressed:
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
@@ -85,16 +105,23 @@ static Optional<Type> convertSparseTensorType(Type type) {
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+ allDense = false;
+ linear = 1;
break;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
+ allDense = false;
+ linear = 1;
break;
}
}
- fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
+ // The values array.
+ int64_t nnz =
+ (rType.hasStaticShape() && allDense) ? linear : ShapedType::kDynamicSize;
+ fields.push_back(MemRefType::get({nnz}, eltType));
// Sparse tensor storage (temporarily) lives in a tuple. This allows a
// simple 1:1 type conversion during codegen. A subsequent pass uses
// a 1:N type conversion to expand the tuple into its fields.
@@ -102,10 +129,10 @@ static Optional<Type> convertSparseTensorType(Type type) {
}
//===----------------------------------------------------------------------===//
-// Conversion rules.
+// Codegen rules.
//===----------------------------------------------------------------------===//
-/// Sparse conversion rule for returns.
+/// Sparse codegen rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -117,6 +144,36 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
}
};
+/// Sparse codegen rule for dimension accesses.
+class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ Type type = op.getSource().getType();
+ // Only rewrite annotated DimOp with constant index.
+ auto enc = getSparseTensorEncoding(type);
+ if (!enc)
+ return failure();
+ Optional<int64_t> index = op.getConstantIndex();
+ if (!index)
+ return failure();
+ // Access into static shape can query original type directly.
+ // Note that this is typically already done by DimOp's folding.
+ RankedTensorType rType = type.cast<RankedTensorType>();
+ if (rType.hasStaticShape()) {
+ rewriter.replaceOp(
+ op, constantIndex(rewriter, loc, rType.getShape()[*index]));
+ return success();
+ }
+ // Any other query can consult the dimSize array.
+ // TODO: this needs tuple access
+ return failure();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -136,5 +193,6 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<SparseReturnConverter>(typeConverter, patterns.getContext());
+ patterns.add<SparseReturnConverter, SparseDimOpConverter>(
+ typeConverter, patterns.getContext());
}
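
To see the new dense-dimension linearization at work, consider the Dense3D test added to codegen.mlir below; a worked sketch of the arithmetic performed by the loop above:

  // dimOrdering affine_map<(i,j,k) -> (k,i,j)> stores the dimensions in
  // the order k, i, j, so reorder() maps stored dims 0, 1, 2 to sizes
  // 30, 10, 20 of the 10x20x30 shape. All dimensions are dense and the
  // shape is static, so the sizes linearize:
  //   linear = 30 * 10 * 20 = 6000
  // The dimSizes field is elided (static shape) and the whole tensor
  // collapses into one values field:
  //   tensor<10x20x30xf64, #Dense3D>  -->  tuple<memref<6000xf64>>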
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b380c50f98bf4..b30d0d2b927f0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -158,8 +158,7 @@ struct SparseTensorCodegenPass
ConversionTarget target(*ctx);
// Everything in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
- // All dynamic rules below accept new function, call, return, and various
- // tensor and bufferization operations as legal output of the rewriting.
+ // All dynamic rules below accept new function, call, return.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
@@ -169,6 +168,10 @@ struct SparseTensorCodegenPass
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
return converter.isLegal(op.getOperandTypes());
});
+ // Legal dialects may occur in generated code.
+ target.addLegalDialect<arith::ArithmeticDialect,
+ bufferization::BufferizationDialect,
+ memref::MemRefDialect, scf::SCFDialect>();
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index f9a979e38d42b..66626163ff97c 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -6,7 +6,7 @@
pointerBitWidth = 32
}>
-#Dense = #sparse_tensor.encoding<{
+#Dense2D = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "dense" ],
indexBitWidth = 64,
pointerBitWidth = 32
@@ -30,6 +30,13 @@
pointerBitWidth = 32
}>
+#Dense3D = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "dense", "dense" ],
+ indexBitWidth = 64,
+ pointerBitWidth = 32,
+ dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
+}>
+
// CHECK-LABEL: func @sparse_nop(
// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
// CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
@@ -37,9 +44,9 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
return %arg0 : tensor<?xf64, #SparseVector>
}
-// CHECK-LABEL: func @sparse_dense(
+// CHECK-LABEL: func @sparse_dense_2d(
// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
-func.func @sparse_dense(%arg0: tensor<?x?xf64, #Dense>) {
+func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
return
}
@@ -60,3 +67,16 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
return
}
+
+//
+// Just a linearized array in the end. Dim op is statically known.
+//
+// CHECK-LABEL: func @sparse_dense_3d(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<6000xf64>>) -> index
+// CHECK: %[[C:.*]] = arith.constant 20 : index
+// CHECK: return %[[C]] : index
+func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
+ %c = arith.constant 1 : index
+ %0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #Dense3D>
+ return %0 : index
+}
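
Note that the dynamic-shape path of SparseDimOpConverter still returns failure(). Once tuple access lands, a query like the following should lower to a read of the dimSizes field; this is a hypothetical sketch of the intended lowering, not what this patch emits:

  // %c1 is a hypothetical arith.constant 1 : index
  //   %0 = tensor.dim %arg0, %c1 : tensor<?x?xf64, #CSR>
  // intended lowering: load element 1 of the memref<2xindex> dimSizes
  // field carried in the lowered tuple.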