[Mlir-commits] [mlir] 3ae98fd - [mlir][sparse] added codegen for dimop, pointers, indices, values
Aart Bik
llvmlistbot at llvm.org
Thu Sep 1 16:36:24 PDT 2022
Author: Aart Bik
Date: 2022-09-01T16:36:10-07:00
New Revision: 3ae98fd259e56fcc22f9387eca56c178f17eaf89
URL: https://github.com/llvm/llvm-project/commit/3ae98fd259e56fcc22f9387eca56c178f17eaf89
DIFF: https://github.com/llvm/llvm-project/commit/3ae98fd259e56fcc22f9387eca56c178f17eaf89.diff
LOG: [mlir][sparse] added codegen for dimop, pointers, indices, values
Demonstrates how the sparse tensor type -> tuple -> getter lowering will
eventually yield actual code that operates on the memrefs directly.
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D133143
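To illustrate the direction (a sketch distilled from the new test cases in this
commit, not an exhaustive description of the lowering; the value names here are
only illustrative): a pointers query on the annotated tensor type becomes a
plain getter into the tuple of memrefs that now represents the sparse tensor.

  // Before codegen: query the pointers of dimension 1 of a DCSR tensor.
  %p = sparse_tensor.pointers %t, %c1 : tensor<?x?xf64, #DCSR> to memref<?xi32>

  // After codegen: the tensor is a tuple of memrefs, and the query is a
  // direct access of field 3 (pointers of the second stored dimension).
  %p = sparse_tensor.storage_get %a[3]
         : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>,
                 memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>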
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index d765a10701acb..8f96280efe83a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -142,6 +142,10 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
}];
let constructor = "mlir::createSparseTensorCodegenPass()";
let dependentDialects = [
+ "arith::ArithmeticDialect",
+ "bufferization::BufferizationDialect",
+ "memref::MemRefDialect",
+ "scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
];
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d82ebea74d205..ac710623ed96c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -33,14 +33,24 @@ namespace {
// Helper methods.
//===----------------------------------------------------------------------===//
-/// Reorders stored dimension to logical dimension.
-static unsigned reorder(const SparseTensorEncodingAttr &enc, unsigned d) {
+/// Reorders stored dimension to original dimension.
+static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
auto order = enc.getDimOrdering();
if (order) {
assert(order.isPermutation());
- return order.getDimPosition(d);
+ return order.getDimPosition(i);
}
- return d;
+ return i;
+}
+
+/// Reorders original dimension to stored dimension.
+static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
+ auto order = enc.getDimOrdering();
+ if (order) {
+ assert(order.isPermutation());
+ return order.getPermutedPosition(i);
+ }
+ return i;
}
/// Maps a sparse tensor type to the appropriate compounded buffers.
@@ -63,14 +73,13 @@ static Optional<Type> convertSparseTensorType(Type type) {
// single compound type with the following fields:
//
// struct {
- // ; if dynamic shape:
- // memref<rank x index> dimSize ; size in each dimension
+ // memref<rank x index> dimSizes ; size in each dimension
// ; per-dimension d:
// ; if dense:
// <nothing>
// ; if compressed:
- // memref<? x idx> indices-d ; indices for sparse dim d
// memref<? x ptr> pointers-d ; pointers for sparse dim d
+ // memref<? x idx> indices-d ; indices for sparse dim d
// ; if singleton:
// memref<? x idx> indices-d ; indices for singleton dim d
// memref<? x eltType> values ; values
@@ -81,12 +90,11 @@ static Optional<Type> convertSparseTensorType(Type type) {
unsigned rank = rType.getShape().size();
SmallVector<Type, 8> fields;
// The dimSizes array.
- if (!rType.hasStaticShape())
- fields.push_back(MemRefType::get({rank}, indexType));
+ fields.push_back(MemRefType::get({rank}, indexType));
// Per-dimension storage.
for (unsigned r = 0; r < rank; r++) {
// Get the original dimension (ro) for the current stored dimension (r).
- unsigned ro = reorder(enc, r);
+ unsigned ro = toOrig(enc, r);
// Dimension level types apply in order to the reordered dimension.
// As a result, the compound type can be constructed directly in the given
// order. Clients of this type know what field is what from the sparse
@@ -103,8 +111,8 @@ static Optional<Type> convertSparseTensorType(Type type) {
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
- fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+ fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
allDense = false;
linear = 1;
break;
@@ -128,6 +136,63 @@ static Optional<Type> convertSparseTensorType(Type type) {
return TupleType::get(context, fields);
}
+/// Returns the field index of the requested pointers (ptrDim) or indices (idxDim) buffer.
+static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
+ auto enc = getSparseTensorEncoding(type);
+ assert(enc);
+ RankedTensorType rType = type.cast<RankedTensorType>();
+ unsigned field = 1; // skip the dimSizes field at index 0
+ unsigned ptr = 0;
+ unsigned idx = 0;
+ for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
+ switch (enc.getDimLevelType()[r]) {
+ case SparseTensorEncodingAttr::DimLevelType::Dense:
+ break; // no fields
+ case SparseTensorEncodingAttr::DimLevelType::Compressed:
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+ case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+ if (ptr++ == ptrDim)
+ return field;
+ field++;
+ if (idx++ == idxDim)
+ return field;
+ field++;
+ break;
+ case SparseTensorEncodingAttr::DimLevelType::Singleton:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+ case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+ if (idx++ == idxDim)
+ return field;
+ field++;
+ break;
+ }
+ }
+ llvm_unreachable("failed to find ptr/idx field index");
+ return -1;
+}
+
+/// Returns field type in tuple at given index.
+static Type getFieldType(Value tuple, unsigned field) {
+ return tuple.getType().cast<TupleType>().getType(field);
+}
+
+/// Creates tuple get operation at given index.
+static Value createTupleGet(OpBuilder &builder, Location loc, Value tuple,
+ unsigned field) {
+ Type indexType = builder.getIndexType();
+ return builder.create<StorageGetOp>(loc, getFieldType(tuple, field), tuple,
+ builder.getIntegerAttr(indexType, field));
+}
+
+/// Returns integral constant, if defined.
+static Optional<int64_t> getConstantInt(Value val) {
+ if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
+ return constantOp.getValue().cast<IntegerAttr>().getInt();
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
@@ -151,26 +216,82 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
LogicalResult
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- Type type = op.getSource().getType();
// Only rewrite annotated DimOp with constant index.
- auto enc = getSparseTensorEncoding(type);
+ auto enc = getSparseTensorEncoding(op.getSource().getType());
if (!enc)
return failure();
- Optional<int64_t> index = op.getConstantIndex();
+ Optional<int64_t> index = getConstantInt(adaptor.getIndex());
if (!index)
return failure();
- // Access into static shape can query original type directly.
+ // Access into static dimension can query original type directly.
// Note that this is typically already done by DimOp's folding.
- RankedTensorType rType = type.cast<RankedTensorType>();
- if (rType.hasStaticShape()) {
- rewriter.replaceOp(
- op, constantIndex(rewriter, loc, rType.getShape()[*index]));
+ Location loc = op->getLoc();
+ auto shape = op.getSource().getType().cast<RankedTensorType>().getShape();
+ if (!ShapedType::isDynamic(shape[*index])) {
+ rewriter.replaceOp(op, constantIndex(rewriter, loc, shape[*index]));
return success();
}
- // Any other query can consult the dimSize array.
- // TODO: this needs tuple access
- return failure();
+ // Any other query can consult the dimSizes array at field 0,
+ // accounting for the reordering applied to the sparse storage.
+ Value tuple = adaptor.getSource();
+ Value dimSizes = createTupleGet(rewriter, loc, tuple, 0);
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(
+ op, dimSizes, constantIndex(rewriter, loc, toStored(enc, *index)));
+ return success();
+ }
+};
+
+/// Sparse conversion rule for pointer accesses.
+class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
+ if (!index)
+ return failure();
+ // Replace the requested pointer access with corresponding field.
+ Location loc = op->getLoc();
+ Value tuple = adaptor.getTensor();
+ unsigned i = getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*index, -1);
+ rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+ return success();
+ }
+};
+
+/// Sparse conversion rule for index accesses.
+class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
+ if (!index)
+ return failure();
+ // Replace the requested indices access with corresponding field.
+ Location loc = op->getLoc();
+ Value tuple = adaptor.getTensor();
+ unsigned i = getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*index);
+ rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+ return success();
+ }
+};
+
+/// Sparse conversion rule for value accesses.
+class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Replace the requested values access with corresponding field.
+ Location loc = op->getLoc();
+ Value tuple = adaptor.getTensor();
+ unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last
+ rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+ return success();
}
};
@@ -193,6 +314,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<SparseReturnConverter, SparseDimOpConverter>(
- typeConverter, patterns.getContext());
+ patterns.add<SparseReturnConverter, SparseDimOpConverter,
+ SparseToPointersConverter, SparseToIndicesConverter,
+ SparseToValuesConverter>(typeConverter, patterns.getContext());
}
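As a sketch of the field numbering that getFieldIndex computes (inferred from
the DCSR cases in the updated codegen.mlir test below; dimSizes is always field
0, each compressed dimension contributes a pointers/indices pair in that order,
and the values buffer comes last):

  tuple<memref<2xindex>,  // field 0: dimSizes
        memref<?xi32>,    // field 1: pointers, stored dimension 0
        memref<?xi64>,    // field 2: indices, stored dimension 0
        memref<?xi32>,    // field 3: pointers, stored dimension 1
        memref<?xi64>,    // field 4: indices, stored dimension 1
        memref<?xf64>>    // field 5: values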
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index d5e2b96089d5b..c1a6b7a6e45cc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -157,8 +157,9 @@ struct SparseTensorCodegenPass
RewritePatternSet patterns(ctx);
SparseTensorTypeToBufferConverter converter;
ConversionTarget target(*ctx);
- // Everything in the sparse dialect must go!
+ // Almost everything in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
+ target.addLegalOp<StorageGetOp, StorageSetOp>();
// All dynamic rules below accept new function, call, return.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 66626163ff97c..6c93092a937e0 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -32,14 +32,12 @@
#Dense3D = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "dense", "dense" ],
- indexBitWidth = 64,
- pointerBitWidth = 32,
- dimOrdering = affine_map<(i,j,k) -> (k, i,j)>
+ dimOrdering = affine_map<(i, j, k) -> (k, i, j)>
}>
// CHECK-LABEL: func @sparse_nop(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
-// CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
+// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
+// CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
return %arg0 : tensor<?xf64, #SparseVector>
}
@@ -51,28 +49,29 @@ func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
}
// CHECK-LABEL: func @sparse_row(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>)
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
return
}
// CHECK-LABEL: func @sparse_csr(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>)
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
return
}
// CHECK-LABEL: func @sparse_dcsr(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xf64>>)
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
return
}
//
-// Just a linearized array in the end. Dim op is statically known.
+// Querying for dimension 1 in the tensor type can immediately
+// fold using the original static dimension sizes.
//
// CHECK-LABEL: func @sparse_dense_3d(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<6000xf64>>) -> index
+// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<6000xf64>>) -> index {
// CHECK: %[[C:.*]] = arith.constant 20 : index
// CHECK: return %[[C]] : index
func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -80,3 +79,49 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
%0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #Dense3D>
return %0 : index
}
+
+//
+// Querying for dimension 1 in the tensor type needs to be permuted
+// into querying for dimension 2 in the stored sparse tensor scheme,
+// since the latter honors the dimOrdering.
+//
+// CHECK-LABEL: func @sparse_dense_3d_dyn(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>) -> index {
+// CHECK: %[[C:.*]] = arith.constant 2 : index
+// CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<3xindex>, memref<?xf64>> to memref<3xindex>
+// CHECK: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex>
+// CHECK: return %[[L]] : index
+func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
+ %c = arith.constant 1 : index
+ %0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #Dense3D>
+ return %0 : index
+}
+
+// CHECK-LABEL: func @sparse_pointers_dcsr(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
+// CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
+// CHECK: return %[[F]] : memref<?xi32>
+func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
+ %c = arith.constant 1 : index
+ %0 = sparse_tensor.pointers %arg0, %c : tensor<?x?xf64, #DCSR> to memref<?xi32>
+ return %0 : memref<?xi32>
+}
+
+// CHECK-LABEL: func @sparse_indices_dcsr(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
+// CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][4] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
+// CHECK: return %[[F]] : memref<?xi64>
+func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
+ %c = arith.constant 1 : index
+ %0 = sparse_tensor.indices %arg0, %c : tensor<?x?xf64, #DCSR> to memref<?xi64>
+ return %0 : memref<?xi64>
+}
+
+// CHECK-LABEL: func @sparse_values_dcsr(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
+// CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][5] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
+// CHECK: return %[[F]] : memref<?xf64>
+func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
+ %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
+ return %0 : memref<?xf64>
+}
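A typical way to exercise this lowering on such a test file (the actual RUN
line sits at the top of codegen.mlir and is not part of this diff, so treat
this as an assumed invocation using the pass flag declared in Passes.td):

  // RUN: mlir-opt %s --sparse-tensor-codegen | FileCheck %s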