[Mlir-commits] [mlir] fa087b4 - [mlir][scf][bufferize][NFC] Lookup buffer using helper function

Matthias Springer llvmlistbot at llvm.org
Tue Apr 12 02:15:41 PDT 2022


Author: Matthias Springer
Date: 2022-04-12T18:09:30+09:00
New Revision: fa087b43529cfac223432cc5545923fb9c6544af

URL: https://github.com/llvm/llvm-project/commit/fa087b43529cfac223432cc5545923fb9c6544af
DIFF: https://github.com/llvm/llvm-project/commit/fa087b43529cfac223432cc5545923fb9c6544af.diff

LOG: [mlir][scf][bufferize][NFC] Lookup buffer using helper function

Look up iter_arg buffers using `lookupBuffer` instead of always creating a new `ToMemrefOp`. Also cast all yielded buffers when needed, regardless of whether they are an equivalent buffer or a new allocation.

Note: This should have been part of D123369.

Differential Revision: https://reviews.llvm.org/D123383

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 7c6bbcd414d71..33f1af28fdf1b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -314,6 +314,23 @@ struct ForOpInterface
     auto bufferizableOp = cast<BufferizableOpInterface>(op);
     Block *oldLoopBody = &forOp.getLoopBody().front();
 
+    // Helper function for casting MemRef buffers.
+    auto castBuffer = [&](Value buffer, Type type) {
+      assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+      assert(buffer.getType().isa<BaseMemRefType>() &&
+             "expected BaseMemRefType");
+      // If the buffer already has the correct type, no cast is needed.
+      if (buffer.getType() == type)
+        return buffer;
+      // TODO: In case `type` has a layout map that is not the fully dynamic
+      // one, we may not be able to cast the buffer. In that case, the loop
+      // iter_arg's layout map must be changed (see uses of `castBuffer`).
+      assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+             "scf.for op bufferization: cast incompatible");
+      return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
+          .getResult();
+    };
+
     // Indices of all iter_args that have tensor type. These are the ones that
     // are bufferized.
     DenseSet<int64_t> indices;
@@ -382,9 +399,10 @@ struct ForOpInterface
     rewriter.setInsertionPoint(yieldOp);
     SmallVector<Value> yieldValues =
         convert(yieldOp.getResults(), [&](Value val, int64_t index) {
-          ensureToMemrefOpIsValid(val, initArgs[index].getType());
-          Value yieldedVal = rewriter.create<bufferization::ToMemrefOp>(
-              val.getLoc(), initArgs[index].getType(), val);
+          Type initArgType = initArgs[index].getType();
+          ensureToMemrefOpIsValid(val, initArgType);
+          Value yieldedVal =
+              bufferization::lookupBuffer(rewriter, val, state.getOptions());
 
           if (equivalentYields[index])
             // Yielded value is equivalent to the corresponding iter_arg bbArg.
@@ -392,7 +410,7 @@ struct ForOpInterface
             // else must be resolved with copies and is potentially inefficient.
             // By default, such problematic IR would already have been rejected
             // during `verifyAnalysis`, unless `allow-return-allocs`.
-            return yieldedVal;
+            return castBuffer(yieldedVal, initArgType);
 
           // It is not certain that the yielded value and the iter_arg bbArg
           // have the same buffer. Allocate a new buffer and copy. The yielded
@@ -412,21 +430,9 @@ struct ForOpInterface
           (void)copyStatus;
           assert(succeeded(copyStatus) && "could not create memcpy");
 
-          if (yieldedVal.getType() == yieldedAlloc->getType())
-            return *yieldedAlloc;
-
-          // The iter_arg memref type has a layout map. Cast the new buffer to
-          // the same type.
-          // TODO: In case the iter_arg has a layout map that is not the fully
-          // dynamic one, we cannot cast the new buffer. In that case, the
-          // iter_arg must be changed to the fully dynamic layout map. (And then
-          // the new buffer can be casted.)
-          assert(memref::CastOp::areCastCompatible(yieldedAlloc->getType(),
-                                                   yieldedVal.getType()) &&
-                 "scf.for op bufferization: cast incompatible");
-          Value casted = rewriter.create<memref::CastOp>(
-              val.getLoc(), yieldedVal.getType(), *yieldedAlloc);
-          return casted;
+          // The iter_arg memref type may have a layout map. Cast the new buffer
+          // to the same type if needed.
+          return castBuffer(*yieldedAlloc, initArgType);
         });
     yieldOp.getResultsMutable().assign(yieldValues);
 


        


More information about the Mlir-commits mailing list