[Mlir-commits] [mlir] 0f3e4d1 - [mlir][sparse] lower number of entries op to actual code
Aart Bik
llvmlistbot at llvm.org
Fri Oct 21 10:48:50 PDT 2022
Author: Aart Bik
Date: 2022-10-21T10:48:37-07:00
New Revision: 0f3e4d1afaa1dc330c374b729269f2ff8422e8dd
URL: https://github.com/llvm/llvm-project/commit/0f3e4d1afaa1dc330c374b729269f2ff8422e8dd
DIFF: https://github.com/llvm/llvm-project/commit/0f3e4d1afaa1dc330c374b729269f2ff8422e8dd.diff
LOG: [mlir][sparse] lower number of entries op to actual code
Works along both the runtime path and the pure codegen path.
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D136389
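The new op queries the number of stored entries of a sparse tensor. A minimal usage sketch, as exercised by the tests added in this commit:

  %noe = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>

The codegen path lowers this to a load from the memSizes array of the expanded storage; the runtime path takes the dynamic size of the values buffer returned by the support library. Short sketches of both lowered forms follow the respective transform diffs below.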
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/conversion.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 1beb1271103b4..bf2f77d95e665 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -277,6 +277,12 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
return forOp;
}
+/// Translates field index to memSizes index.
+static unsigned getMemSizesIndex(unsigned field) {
+ assert(2 <= field);
+ return field - 2;
+}
+
/// Creates a pushback op for given field and updates the fields array
/// accordingly.
static void createPushback(OpBuilder &builder, Location loc,
@@ -286,9 +292,9 @@ static void createPushback(OpBuilder &builder, Location loc,
Type etp = fields[field].getType().cast<ShapedType>().getElementType();
if (value.getType() != etp)
value = builder.create<arith::IndexCastOp>(loc, etp, value);
- fields[field] =
- builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
- fields[field], value, APInt(64, field - 2));
+ fields[field] = builder.create<PushBackOp>(
+ loc, fields[field].getType(), fields[1], fields[field], value,
+ APInt(64, getMemSizesIndex(field)));
}
/// Generates insertion code.
@@ -739,6 +745,25 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
}
};
+/// Sparse codegen rule for number of entries operator.
+class SparseNumberOfEntriesConverter
+ : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Query memSizes for the actually stored values size.
+ auto tuple = getTuple(adaptor.getTensor());
+ auto fields = tuple.getInputs();
+ unsigned lastField = fields.size() - 1;
+ Value field =
+ constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -775,5 +800,6 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
SparseExpandConverter, SparseCompressConverter,
SparseInsertConverter, SparseToPointersConverter,
SparseToIndicesConverter, SparseToValuesConverter,
- SparseConvertConverter>(typeConverter, patterns.getContext());
+ SparseConvertConverter, SparseNumberOfEntriesConverter>(
+ typeConverter, patterns.getContext());
}
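For reference, the storage layout that getMemSizesIndex() assumes can be reconstructed from the @sparse_noe codegen test below; a sketch (the field annotations are descriptive, not part of the IR):

  // #SparseVector storage as a tuple of memrefs:
  //   fields[0] : memref<1xindex>  dimension sizes
  //   fields[1] : memref<3xindex>  memSizes
  //   fields[2] : memref<?xi32>    pointers
  //   fields[3] : memref<?xi64>    indices
  //   fields[4] : memref<?xf64>    values
  // The first two fields hold metadata, hence getMemSizesIndex(f) == f - 2.
  // With lastField == 4, the number of entries is the values size memSizes[2]:
  %c2  = arith.constant 2 : index
  %noe = memref.load %memSizes[%c2] : memref<3xindex>

Here %memSizes is a placeholder for fields[1]; in the actual pattern the load is emitted via replaceOpWithNewOp<memref::LoadOp>.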
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 40112078572bb..c7c81767a4041 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -205,6 +205,15 @@ static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
params.push_back(ptr);
}
+/// Generates a call to obtain the values array.
+static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
+ ValueRange ptr) {
+ SmallString<15> name{"sparseValues",
+ primaryTypeFunctionSuffix(tp.getElementType())};
+ return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+ .getResult(0);
+}
+
/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
Value coo) {
@@ -903,11 +912,28 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type resType = op.getType();
- Type eltType = resType.cast<ShapedType>().getElementType();
- SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
- replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
- EmitCInterface::On);
+ auto resType = op.getType().cast<ShapedType>();
+ rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
+ adaptor.getOperands()));
+ return success();
+ }
+};
+
+/// Sparse conversion rule for number of entries operator.
+class SparseNumberOfEntriesConverter
+ : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ // Query values array size for the actually stored values size.
+ Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
+ auto resTp = MemRefType::get({ShapedType::kDynamicSize}, eltType);
+ Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
+ rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
+ constantIndex(rewriter, loc, 0));
return success();
}
};
@@ -1250,9 +1276,10 @@ void mlir::populateSparseTensorConversionPatterns(
SparseTensorConcatConverter, SparseTensorAllocConverter,
SparseTensorDeallocConverter, SparseTensorToPointersConverter,
SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
- SparseTensorLoadConverter, SparseTensorInsertConverter,
- SparseTensorExpandConverter, SparseTensorCompressConverter,
- SparseTensorOutConverter>(typeConverter, patterns.getContext());
+ SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
+ SparseTensorInsertConverter, SparseTensorExpandConverter,
+ SparseTensorCompressConverter, SparseTensorOutConverter>(
+ typeConverter, patterns.getContext());
patterns.add<SparseTensorConvertConverter>(typeConverter,
patterns.getContext(), options);
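On the runtime path, the pattern does not ask the support library for an entry count directly; it reuses the new genValuesCall() helper and takes the dynamic dimension of the returned values buffer. A sketch for an f64 tensor, matching the @sparse_noe conversion test below (%ptr is a placeholder for the opaque tensor handle):

  %values = call @sparseValuesF64(%ptr) : (!llvm.ptr<i8>) -> memref<?xf64>
  %c0  = arith.constant 0 : index
  %noe = memref.dim %values, %c0 : memref<?xf64>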
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 6b5c6c4ce3808..71f736d7263de 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -239,6 +239,20 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
return %0 : memref<?xf64>
}
+// CHECK-LABEL: func @sparse_noe(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+ %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+ return %0 : index
+}
+
// CHECK-LABEL: func @sparse_dealloc_csr(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 44fcd4219ec08..33b7d133fe849 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -268,6 +268,17 @@ func.func @sparse_valuesi8(%arg0: tensor<128xi8, #SparseVector>) -> memref<?xi8>
return %0 : memref<?xi8>
}
+// CHECK-LABEL: func @sparse_noe(
+// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+// CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[T:.*]] = call @sparseValuesF64(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+// CHECK: %[[NOE:.*]] = memref.dim %[[T]], %[[C]] : memref<?xf64>
+// CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+ %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+ return %0 : index
+}
+
// CHECK-LABEL: func @sparse_reconstruct(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
// CHECK: return %[[A]] : !llvm.ptr<i8>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
index 6fc68ed700fdb..b3a5bbdb8f54a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir
@@ -46,6 +46,16 @@ module {
%1 = tensor.extract %0[] : tensor<f32>
vector.print %1 : f32
+ // Print number of entries in the sparse vectors.
+ //
+ // CHECK: 5
+ // CHECK: 3
+ //
+ %noe1 = sparse_tensor.number_of_entries %s1 : tensor<1024xf32, #SparseVector>
+ %noe2 = sparse_tensor.number_of_entries %s2 : tensor<1024xf32, #SparseVector>
+ vector.print %noe1 : index
+ vector.print %noe2 : index
+
// Release the resources.
bufferization.dealloc_tensor %s1 : tensor<1024xf32, #SparseVector>
bufferization.dealloc_tensor %s2 : tensor<1024xf32, #SparseVector>