[Mlir-commits] [mlir] e9fb4dc - [mlir][linalg][bufferize] Remove buffer equivalence from bufferize

Matthias Springer llvmlistbot at llvm.org
Mon Dec 6 00:51:07 PST 2021


Author: Matthias Springer
Date: 2021-12-06T17:48:31+09:00
New Revision: e9fb4dc9e918e23384550df9b66c2fd87cb1ffdd

URL: https://github.com/llvm/llvm-project/commit/e9fb4dc9e918e23384550df9b66c2fd87cb1ffdd
DIFF: https://github.com/llvm/llvm-project/commit/e9fb4dc9e918e23384550df9b66c2fd87cb1ffdd.diff

LOG: [mlir][linalg][bufferize] Remove buffer equivalence from bufferize

Remove all function calls related to buffer equivalence from bufferize implementations.

Add a new PostAnalysisStep for scf.for that ensures that yielded values are equivalent to the corresponding iter_args (BBArgs). (This was previously checked in `bufferize`.) This restriction will be relaxed in a subsequent commit.

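Note: This commit changes two test cases. They were invalid by construction
(their scf.for loops yielded values that are not equivalent to the
corresponding iter_args) and should not have passed. The new scf.for
PostAnalysisStep now rejects such IR, so the test cases were corrected.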

Differential Revision: https://reviews.llvm.org/D114927

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
index 0a4b140a1f96..3ab5cc3525fc 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -19,6 +19,13 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace scf_ext {
 
+/// Post-analysis check for scf.for: emit an error if a yielded tensor value is
+/// not equivalent to the iter_arg (bbArg) that it is yielded to.
+struct AssertDestinationPassingStyle : public PostAnalysisStep {
+  LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    SmallVector<Operation *> &newOps) override;
+};
+
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 
 } // namespace scf_ext

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index fa0c96275daf..0298a492360a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -37,7 +37,6 @@ struct ConstantOpInterface
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
     Value memref = b.create<memref::GetGlobalOp>(
         constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
-    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
     state.mapBuffer(constantOp, memref);
 
     return success();

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 03ea6bdd63b9..af5362bfe1f5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -141,22 +141,7 @@ void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
 
 /// Return `true` if a value was marked as in-place bufferized.
 bool BufferizationAliasInfo::isInPlace(OpResult opResult) const {
-  bool inplace = inplaceBufferized.contains(opResult);
-#ifndef NDEBUG
-  if (inplace) {
-    auto bufferizableOp =
-        dyn_cast<BufferizableOpInterface>(opResult.getDefiningOp());
-    assert(bufferizableOp &&
-           "expected that in-place bufferized op is bufferizable");
-    SmallVector<OpOperand *> operands =
-        bufferizableOp.getAliasingOpOperand(opResult);
-    for (OpOperand *operand : operands)
-      assert(areAliasingBufferizedValues(operand->get(), opResult) &&
-             "expected that in-place bufferized OpResult aliases with "
-             "aliasing OpOperand");
-  }
-#endif // NDEBUG
-  return inplace;
+  return inplaceBufferized.contains(opResult);
 }
 
 /// Set the inPlace bufferization spec to true.
@@ -593,7 +578,6 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
   Value casted = allocated.getValue();
   if (memRefType && memRefType != allocMemRefType) {
     casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
-    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
   }
 
   // 2. Create memory deallocation.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 6a3a25c93582..decfb1d41f8a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -253,8 +253,6 @@ struct TiledLoopOpInterface
         return failure();
 
       // Insert mapping and aliasing info.
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
-      state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
       state.mapBuffer(opResult, resultBuffer);
 
       // Insert new operand and bbArg.
@@ -263,9 +261,6 @@ struct TiledLoopOpInterface
           body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
       BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
       // Insert mapping and aliasing info.
-      state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
-      state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
-                                                 newBufferBBArg);
       state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
       // Set operand of `linalg.yield` to the bbArg so it just canonicalizes
@@ -303,9 +298,6 @@ struct TiledLoopOpInterface
       BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
 
       // Insert mapping and aliasing info.
-      state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
-      state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
-                                                 newBufferBBArg);
       state.mapBuffer(oldTensorBBArg, newBufferBBArg);
 
       // Increment indices.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index e65b1eb441a3..5dc37968c95b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -223,7 +223,6 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
                                              BufferizationState &state) {
   LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
-  BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // If nothing to do then we are done.
   if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
@@ -321,15 +320,12 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
         auto castOp = b.create<memref::CastOp>(
             funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
         toMemrefOp.memref().replaceAllUsesWith(castOp);
-        aliasInfo.insertNewBufferEquivalence(castOp.dest(),
-                                             toMemrefOp.memref());
       }
     }
     // Replace all remaining uses by a to_tensor.
     if (!bbArg.use_empty()) {
       auto toTensorOp =
           b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
-      aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
       bbArg.replaceAllUsesWith(toTensorOp);
     }
     frontBlock.eraseArgument(0);
@@ -562,7 +558,6 @@ struct CallOpInterface
           Value buffer = state.lookupBuffer(callOp->getOperand(idx));
           // Add CallOp operand/result equivalence: this is interprocedural
           // info.
-          state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
           state.mapBuffer(oldRes, buffer);
           // Add a ToTensorOp to kill all uses of the CallOp return.
           // Replace all uses of the CallOp results so we can erase the CallOp.
@@ -572,7 +567,6 @@ struct CallOpInterface
               b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
           oldRes.replaceAllUsesWith(toTensorOp);
           // Add new op equivalence info.
-          state.aliasInfo.insertNewBufferEquivalence(toTensorOp, buffer);
           state.mapBuffer(toTensorOp, buffer);
           continue;
         }
@@ -615,7 +609,6 @@ struct CallOpInterface
         Value castBuffer =
             b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
         // Add new op equivalence info.
-        state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
         state.mapBuffer(tensorOperand, castBuffer);
         buffer = castBuffer;
       }
@@ -663,7 +656,6 @@ struct ReturnOpInterface
       Value returnTensor = b.create<bufferization::ToTensorOp>(
           returnOp.getLoc(), v);
       operand.set(returnTensor);
-      state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
       state.mapBuffer(returnTensor, v);
     }
     return success();
@@ -690,7 +682,6 @@ struct FuncOpInterface
                             : getContiguousOrUnrankedMemRefType(tensorType);
       Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
                                                              memRefType, bbArg);
-      state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
       state.mapBuffer(bbArg, bufferCast);
     }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index c41457d7da76..d9416347fab5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -147,7 +147,6 @@ struct IfOpInterface
       if (!resultBuffer)
         return failure();
 
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
       state.mapBuffer(opResult, resultBuffer);
     }
 
@@ -237,8 +236,6 @@ struct ForOpInterface
 
       OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
       BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
-      state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
       state.mapBuffer(bbArg, resultBuffer);
       state.mapBuffer(opResult, resultBuffer);
     }
@@ -257,15 +254,6 @@ struct ForOpInterface
       OpOperand &forOperand = forOp.getOpOperandForResult(
           forOp->getResult(operand.getOperandNumber()));
       auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
-                                                         bbArg)) {
-        // TODO: this could get resolved with copies but it can also turn into
-        // swaps so we need to be careful about order of copies.
-        return yieldOp->emitError()
-               << "Yield operand #" << operand.getOperandNumber()
-               << " does not bufferize to an equivalent buffer to the matching"
-               << " enclosing scf::for operand";
-      }
 
       // Buffers are equivalent so the work is already done and we just yield
       // the bbArg so that it later canonicalizes away.
@@ -275,6 +263,41 @@ struct ForOpInterface
   }
 };
 
+LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
+    AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
+                                       SmallVector<Operation *> &newOps) {
+  LogicalResult status = success(); // Overwritten by the error emitted below.
+  funcOp->walk([&](scf::YieldOp yieldOp) { // Visit scf.yield ops in funcOp.
+    auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+    if (!forOp) // Ignore yields whose parent is not an scf.for.
+      return WalkResult::advance();
+    // Check each tensor yield operand against the iter_arg it flows back to.
+    for (OpOperand &operand : yieldOp->getOpOperands()) {
+      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+      if (!tensorType) // Only tensor operands are checked.
+        continue;
+      // Yield operand #i maps to loop result #i; fetch that result's bbArg.
+      OpOperand &forOperand = forOp.getOpOperandForResult(
+          forOp->getResult(operand.getOperandNumber()));
+      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+      if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
+                                                         bbArg)) {
+        // TODO: Non-equivalent yields could be resolved with copies, but the
+        // copies may form swaps, so their order must be chosen carefully.
+        status =
+            yieldOp->emitError()
+            << "Yield operand #" << operand.getOperandNumber()
+            << " does not bufferize to an equivalent buffer to the matching"
+            << " enclosing scf::for operand";
+        return WalkResult::interrupt(); // Report only the first mismatch.
+      }
+    }
+
+    return WalkResult::advance();
+  });
+  return status;
+}
+
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                     scf::YieldOp> {

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index f595de42b7e7..7f1bdb703d18 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -80,7 +80,6 @@ struct CastOpInterface
         castOp.getResult().getType(), layout, memorySpace);
     Value res =
         b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
-    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
     state.mapBuffer(castOp.getResult(), res);
     return success();
   }
@@ -233,7 +232,6 @@ struct InsertOpInterface
     b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
                               insertOp.indices());
     state.mapBuffer(insertOp, destMemref);
-    state.aliasInfo.insertNewBufferAlias(insertOp, destMemref);
     return success();
   }
 
@@ -421,8 +419,6 @@ struct InsertSliceOpInterface
       Value subView = b.create<memref::SubViewOp>(
           loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
           insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
-      // Insert new alias.
-      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
       // Copy tensor.
       Value srcMemref = state.lookupBuffer(insertSliceOp.source());
       state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 9da3ed883cfb..f90b5e5e20ec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -96,6 +96,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   // TODO: Find a way to enable this step automatically when bufferizing tensor
   // dialect ops.
   options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 6e82f65cc905..0e1e7602b2af 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1113,7 +1113,7 @@ func @reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
 
     // Read from %t1 via alias %e.
     %v2 = vector.transfer_read %e[%s], %cst : tensor<?xf32>, vector<5xf32>
-    scf.yield %e, %v2 : tensor<?xf32>, vector<5xf32>
+    scf.yield %t2, %v2 : tensor<?xf32>, vector<5xf32>
   }
   // CHECK: __inplace_results_attr__ = ["true", "false"]
 
@@ -1154,14 +1154,10 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
   // This loop does not read from %t1. It only writes to it.
   // CHECK:      scf.for
   %r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
-    // CHECK:      tensor.extract_slice
-    // CHECK-SAME: __inplace_results_attr__ = ["true"]
-    %e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
-
-    // Write to %t1 via alias. (Overwrite %t3.)
+    // Write to %t1 via %t2. (Overwrite %t3.)
     // CHECK:      linalg.generic
     // CHECK-SAME: __inplace_results_attr__ = ["true"]
-    %o2 = linalg.generic #trait outs (%e : tensor<?xf32>) {
+    %o2 = linalg.generic #trait outs (%t2 : tensor<?xf32>) {
         ^bb(%0: f32) :
           linalg.yield %cst : f32
       } -> (tensor<?xf32>)
@@ -1172,8 +1168,8 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
   }
 
   // Use %t3 in some way without reading it, so that it does not get DCE'd.
-    // CHECK:      linalg.generic
-    // CHECK-SAME: __inplace_results_attr__ = ["true"]
+  // CHECK:      linalg.generic
+  // CHECK-SAME: __inplace_results_attr__ = ["true"]
   %o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
       ^bb(%0: f32) :
         linalg.yield %cst : f32


        


More information about the Mlir-commits mailing list