[Mlir-commits] [mlir] 99ad207 - [mlir][linalg][bufferize] Fix buffer equivalence around scf.if ops

Matthias Springer llvmlistbot at llvm.org
Wed Nov 10 01:33:15 PST 2021


Author: Matthias Springer
Date: 2021-11-10T18:33:08+09:00
New Revision: 99ad2079d452f587be050b3867e0ed4856335fb2

URL: https://github.com/llvm/llvm-project/commit/99ad2079d452f587be050b3867e0ed4856335fb2
DIFF: https://github.com/llvm/llvm-project/commit/99ad2079d452f587be050b3867e0ed4856335fb2.diff

LOG: [mlir][linalg][bufferize] Fix buffer equivalence around scf.if ops

Also extend the comments for aliasInfo and equivalenceInfo.

Differential Revision: https://reviews.llvm.org/D113340

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index 8976f69757268..0e29821dec5bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -109,12 +109,19 @@ class BufferizationAliasInfo {
   /// Set of tensors that are known to bufferize to writable memory.
   llvm::DenseSet<Value> bufferizeToWritableMemory;
 
-  /// Auxiliary structure to store all the values a given value aliases with.
-  /// These are the conservative cases that can further decompose into
-  /// "equivalent" buffer relationships.
+  /// Auxiliary structure to store all the values a given value may alias with.
+  /// Alias information is "may be" conservative: In the presence of branches, a
+  /// value may alias with one of multiple other values. The concrete aliasing
+  /// value may not even be known at compile time. All such values are
+  /// considered to be aliases.
   llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
 
-  /// Auxiliary structure to store all the equivalent buffer classes.
+  /// Auxiliary structure to store all the equivalent buffer classes. Equivalent
+  /// buffer information is "must be" conservative: Only if two values are
+  /// guaranteed to be equivalent at runtime are they said to be equivalent. It is
+  /// possible that, in the presence of branches, it cannot be determined
+  /// statically if two values are equivalent. In that case, the values are
+  /// considered to be not equivalent.
   llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
 };
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 8dda1b2f40b35..27db0e9afcf52 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -545,8 +545,6 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
                                ifOp.elseYield().results(), ifOp.results())) {
         aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
         aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
-        equivalentInfo.unionSets(std::get<0>(it), std::get<1>(it));
-        equivalentInfo.unionSets(std::get<0>(it), std::get<2>(it));
       }
     }
   });
@@ -1344,6 +1342,9 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
   assert(operandBuffer && "operand buffer not found");
   // Make sure that all OpOperands are the same buffer. If this is not the case,
   // we would have to materialize a memref value.
+  // TODO: Should be checking for "equivalent buffers" instead of
+  // operator== here, but equivalent buffers for scf.if yield values are not
+  // set up yet.
   if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
         return lookup(bvm, o->get()) == operandBuffer;
       })) {

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 0584ebde985cb..0706657d14b57 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -35,6 +35,22 @@ func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
 
 // -----
 
+func @scf_if_not_equivalent(
+    %cond: i1, %t1: tensor<?xf32> {linalg.inplaceable = true},
+    %idx: index) -> tensor<?xf32> {
+  // expected-error @+1 {{result buffer is ambiguous}}
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    // This buffer aliases, but is not equivalent.
+    %t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %t2 : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}
+
+// -----
+
 // expected-error @-3 {{expected callgraph to be free of circular dependencies}}
 
 func @foo() {

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index bf46c1120bfaa..184e0f36e271d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -887,4 +887,3 @@ func @scf_if_inplace(%cond: i1,
   }
   return %r : tensor<?xf32>
 }
-


        


More information about the Mlir-commits mailing list