[Mlir-commits] [mlir] ea4be70 - [mlir][sparse] Fix problems in creating complex zero for initialization.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 8 07:49:33 PST 2022


Author: bixia1
Date: 2022-12-08T07:49:27-08:00
New Revision: ea4be70cea8509520db8638bb17bcd7b5d8d60ac

URL: https://github.com/llvm/llvm-project/commit/ea4be70cea8509520db8638bb17bcd7b5d8d60ac
DIFF: https://github.com/llvm/llvm-project/commit/ea4be70cea8509520db8638bb17bcd7b5d8d60ac.diff

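LOG: [mlir][sparse] Fix problems in creating complex zero for initialization.

When buffer initialization is enabled, the fill value was built with
`arith.constant` and `getZeroAttr(elemType)`, which does not produce a usable
zero for complex element types. Both the `PushBackOp` rewrite and
`createAllocation` now use the `constantZero` helper instead. The complex
dialect is marked legal in the sparse tensor codegen pass, since the helper may
emit complex ops. The sparse_complex32 integration test now also runs with
`enable-buffer-initialization=true` to cover this path.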

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D139591

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 0592009844c19..ded1e653fb5ff 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -823,8 +823,7 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
           rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
       if (enableBufferInitialization) {
         Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
-        Value fillValue = rewriter.create<arith::ConstantOp>(
-            loc, value.getType(), rewriter.getZeroAttr(value.getType()));
+        Value fillValue = constantZero(rewriter, loc, value.getType());
         Value subBuffer = rewriter.create<memref::SubViewOp>(
             loc, newBuffer, /*offset=*/ValueRange{newSize},
             /*size=*/ValueRange{fillSize},

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e059bd36dc02c..4c190bc1e92ed 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -235,8 +235,7 @@ static Value createAllocation(OpBuilder &builder, Location loc,
   Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
   Type elemType = memRefType.getElementType();
   if (enableInit) {
-    Value fillValue = builder.create<arith::ConstantOp>(
-        loc, elemType, builder.getZeroAttr(elemType));
+    Value fillValue = constantZero(builder, loc, elemType);
     builder.create<linalg::FillOp>(loc, fillValue, buffer);
   }
   return buffer;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 06f57eb6edc50..6845aa0d81877 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -216,9 +216,9 @@ struct SparseTensorCodegenPass
     // The following operations and dialects may be introduced by the
     // codegen rules, and are therefore marked as legal.
     target.addLegalOp<linalg::FillOp>();
-    target.addLegalDialect<arith::ArithDialect,
-                           bufferization::BufferizationDialect,
-                           memref::MemRefDialect, scf::SCFDialect>();
+    target.addLegalDialect<
+        arith::ArithDialect, bufferization::BufferizationDialect,
+        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
     target.addLegalOp<UnrealizedConversionCastOp>();
     // Populate with rules and apply rewriting rules.
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
index 80947a27f1e94..6ee26b923514b 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
@@ -8,7 +8,7 @@
 // RUN: %{command}
 //
 // Do the same run, but now with direct IR generation.
-// REDEFINE: %{option} = enable-runtime-library=false
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
 // RUN: %{command}
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>


        


More information about the Mlir-commits mailing list