[Mlir-commits] [mlir] ea4be70 - [mlir][sparse] Fix problems in creating complex zero for initialization.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 8 07:49:33 PST 2022
Author: bixia1
Date: 2022-12-08T07:49:27-08:00
New Revision: ea4be70cea8509520db8638bb17bcd7b5d8d60ac
URL: https://github.com/llvm/llvm-project/commit/ea4be70cea8509520db8638bb17bcd7b5d8d60ac
DIFF: https://github.com/llvm/llvm-project/commit/ea4be70cea8509520db8638bb17bcd7b5d8d60ac.diff
LOG: [mlir][sparse] Fix problems in creating complex zero for initialization.
Reviewed By: aartbik, wrengr
Differential Revision: https://reviews.llvm.org/D139591
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 0592009844c19..ded1e653fb5ff 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -823,8 +823,7 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
if (enableBufferInitialization) {
Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
- Value fillValue = rewriter.create<arith::ConstantOp>(
- loc, value.getType(), rewriter.getZeroAttr(value.getType()));
+ Value fillValue = constantZero(rewriter, loc, value.getType());
Value subBuffer = rewriter.create<memref::SubViewOp>(
loc, newBuffer, /*offset=*/ValueRange{newSize},
/*size=*/ValueRange{fillSize},
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e059bd36dc02c..4c190bc1e92ed 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -235,8 +235,7 @@ static Value createAllocation(OpBuilder &builder, Location loc,
Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
Type elemType = memRefType.getElementType();
if (enableInit) {
- Value fillValue = builder.create<arith::ConstantOp>(
- loc, elemType, builder.getZeroAttr(elemType));
+ Value fillValue = constantZero(builder, loc, elemType);
builder.create<linalg::FillOp>(loc, fillValue, buffer);
}
return buffer;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 06f57eb6edc50..6845aa0d81877 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -216,9 +216,9 @@ struct SparseTensorCodegenPass
// The following operations and dialects may be introduced by the
// codegen rules, and are therefore marked as legal.
target.addLegalOp<linalg::FillOp>();
- target.addLegalDialect<arith::ArithDialect,
- bufferization::BufferizationDialect,
- memref::MemRefDialect, scf::SCFDialect>();
+ target.addLegalDialect<
+ arith::ArithDialect, bufferization::BufferizationDialect,
+ complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
index 80947a27f1e94..6ee26b923514b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
@@ -8,7 +8,7 @@
// RUN: %{command}
//
// Do the same run, but now with direct IR generation.
-// REDEFINE: %{option} = enable-runtime-library=false
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
// RUN: %{command}
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
More information about the Mlir-commits mailing list