[Mlir-commits] [mlir] 8d615a2 - [mlir][sparse] fix crash on sparse_tensor.foreach operation on tensors with complex<T> elements.
Peiming Liu
llvmlistbot at llvm.org
Thu Nov 17 11:36:21 PST 2022
Author: Peiming Liu
Date: 2022-11-17T19:36:15Z
New Revision: 8d615a23ef781c91da06e45337229c19f2c777d7
URL: https://github.com/llvm/llvm-project/commit/8d615a23ef781c91da06e45337229c19f2c777d7
DIFF: https://github.com/llvm/llvm-project/commit/8d615a23ef781c91da06e45337229c19f2c777d7.diff
LOG: [mlir][sparse] fix crash on sparse_tensor.foreach operation on tensors with complex<T> elements.
Reviewed By: aartbik, bixia
Differential Revision: https://reviews.llvm.org/D138223
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
Removed:
################################################################################
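For reference, the kind of IR that used to crash the rewriter is a foreach over a sparse constant whose element type is complex<T>. A minimal illustrative sketch (the shape, coordinates, values, and SSA names below are invented for this note and are not taken from the patch):

    %cst = arith.constant sparse<[[0], [28]], [(1.0, 2.0), (3.0, 4.0)]> : tensor<32xcomplex<f64>>
    sparse_tensor.foreach in %cst : tensor<32xcomplex<f64>> do {
    ^bb0(%i: index, %v: complex<f64>):
      ...
    }

The previous lowering materialized every stored value with arith::ConstantOp from a TypedAttr, which does not cover the attribute form used for complex elements; the new foreachInSparseConstant helper below switches to complex::ConstantOp when the element type is a ComplexType.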
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 0db86dd2e8c16..fa36d083e285e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -1024,3 +1024,36 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
;
return op;
}
+
+void sparse_tensor::foreachInSparseConstant(
+ Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
+ function_ref<void(ArrayRef<Value>, Value)> callback) {
+ int64_t rank = attr.getType().getRank();
+ // Foreach on constant.
+ DenseElementsAttr indicesAttr = attr.getIndices();
+ DenseElementsAttr valuesAttr = attr.getValues();
+
+ SmallVector<Value> coords;
+ for (int i = 0, e = valuesAttr.size(); i < e; i++) {
+ coords.clear();
+ for (int j = 0; j < rank; j++) {
+ auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
+ auto coord =
+ rewriter.create<arith::ConstantIndexOp>(loc, coordAttr.getInt());
+ // Remaps coordinates.
+ coords.push_back(coord);
+ }
+ Value val;
+ if (attr.getElementType().isa<ComplexType>()) {
+ auto valAttr = valuesAttr.getValues<ArrayAttr>()[i];
+ val = rewriter.create<complex::ConstantOp>(loc, attr.getElementType(),
+ valAttr);
+ } else {
+ auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
+ // Remaps value.
+ val = rewriter.create<arith::ConstantOp>(loc, valAttr);
+ }
+ assert(val);
+ callback(coords, val);
+ }
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index ecd4135be8e67..373817eb13a47 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -183,6 +183,26 @@ void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
/// Scans to top of generated loop.
Operation *getTop(Operation *op);
+/// Iterates over a sparse constant, generating constant ops for the value and
+/// indices of each stored element.
+/// E.g.,
+/// sparse<[ [0], [28], [31] ],
+/// [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] >
+/// =>
+/// %c1 = arith.constant 0
+/// %v1 = complex.constant (-5.13, 2.0)
+/// callback({%c1}, %v1)
+///
+/// %c2 = arith.constant 28
+/// %v2 = complex.constant (3.0, 4.0)
+/// callback({%c2}, %v2)
+///
+/// %c3 = arith.constant 31
+/// %v3 = complex.constant (5.0, 6.0)
+/// callback({%c3}, %v3)
+void foreachInSparseConstant(
+ Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
+ function_ref<void(ArrayRef<Value>, Value)> callback);
+
//===----------------------------------------------------------------------===//
// Inlined constant generators.
//
@@ -197,9 +217,9 @@ Operation *getTop(Operation *op);
//===----------------------------------------------------------------------===//
/// Generates a 0-valued constant of the given type. In addition to
-/// the scalar types (`ComplexType`, ``FloatType`, `IndexType`, `IntegerType`),
-/// this also works for `RankedTensorType` and `VectorType` (for which it
-/// generates a constant `DenseElementsAttr` of zeros).
+/// the scalar types (`ComplexType`, `FloatType`, `IndexType`,
+/// `IntegerType`), this also works for `RankedTensorType` and `VectorType`
+/// (for which it generates a constant `DenseElementsAttr` of zeros).
inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
if (auto ctp = tp.dyn_cast<ComplexType>()) {
auto zeroe = builder.getZeroAttr(ctp.getElementType());
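For the ComplexType branch shown above, the helper builds a zero attribute for the complex element type and (in the context elided here) materializes it as a complex constant; schematically, for complex<f64> the generated IR would look like (illustrative only):

    %zero = complex.constant [0.0, 0.0] : complex<f64>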
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5fa2b4ebbfd2c..8a3cc0ae5fea3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -170,6 +170,35 @@ static void getDynamicSizes(RankedTensorType tp,
}
}
+static LogicalResult genForeachOnSparseConstant(ForeachOp op,
+ RewriterBase &rewriter,
+ SparseElementsAttr attr) {
+ auto loc = op.getLoc();
+ SmallVector<Value> reduc = op.getInitArgs();
+
+ // Foreach on constant.
+ foreachInSparseConstant(
+ loc, rewriter, attr,
+ [&reduc, &rewriter, op](ArrayRef<Value> coords, Value v) mutable {
+ SmallVector<Value> args;
+ args.append(coords.begin(), coords.end());
+ args.push_back(v);
+ args.append(reduc);
+ // Clones the foreach op to get a copy of the loop body.
+ auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
+ assert(args.size() == cloned.getBody()->getNumArguments());
+ Operation *yield = cloned.getBody()->getTerminator();
+ rewriter.mergeBlockBefore(cloned.getBody(), op, args);
+ // clean up
+ rewriter.eraseOp(cloned);
+ reduc = yield->getOperands();
+ rewriter.eraseOp(yield);
+ });
+
+ rewriter.replaceOp(op, reduc);
+ return success();
+}
+
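Schematically, for a constant with two stored complex elements the helper above unrolls the foreach body once per element, e.g. (SSA names, coordinates, and values are illustrative only):

    %i0 = arith.constant 0 : index
    %v0 = complex.constant [1.0, 2.0] : complex<f64>
    // ... cloned foreach body inlined with (%i0, %v0, %init) ...
    %i1 = arith.constant 28 : index
    %v1 = complex.constant [3.0, 4.0] : complex<f64>
    // ... cloned foreach body inlined with (%i1, %v1, previous yield) ...

The operands of the last inlined yield then replace the results of the original foreach op.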
//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//
@@ -752,36 +781,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
// rule.
if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
- // Foreach on constant.
- DenseElementsAttr indicesAttr = attr.getIndices();
- DenseElementsAttr valuesAttr = attr.getValues();
-
- SmallVector<Value> args;
- for (int i = 0, e = valuesAttr.size(); i < e; i++) {
- auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
- for (int j = 0; j < rank; j++) {
- auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
- auto coord = rewriter.create<arith::ConstantIndexOp>(
- loc, coordAttr.getInt());
- // Remaps coordinates.
- args.push_back(coord);
- }
- // Remaps value.
- auto val = rewriter.create<arith::ConstantOp>(loc, valAttr);
- args.push_back(val);
- // Remaps iteration args.
- args.append(reduc);
- auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
- Operation *yield = cloned.getBody()->getTerminator();
- rewriter.mergeBlockBefore(cloned.getBody(), op, args);
- // clean up
- args.clear();
- rewriter.eraseOp(cloned);
- reduc = yield->getOperands();
- rewriter.eraseOp(yield);
- }
- rewriter.replaceOp(op, reduc);
- return success();
+ return genForeachOnSparseConstant(op, rewriter, attr);
}
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
index aab48131605ed..97ecf509ea552 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -1,8 +1,15 @@
-// RUN: mlir-opt %s --sparse-compiler | \
-// RUN: mlir-cpu-runner \
-// RUN: -e entry -entry-point-result=void \
-// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{command}
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>