[Mlir-commits] [mlir] 123c4b0 - [mlir][SCF][bufferize] Support different iter_arg/init_arg types (scf.for)

Matthias Springer llvmlistbot at llvm.org
Tue Aug 30 07:35:44 PDT 2022


Author: Matthias Springer
Date: 2022-08-30T16:35:32+02:00
New Revision: 123c4b02517865b11af1079d206bc838edad79a6

URL: https://github.com/llvm/llvm-project/commit/123c4b02517865b11af1079d206bc838edad79a6
DIFF: https://github.com/llvm/llvm-project/commit/123c4b02517865b11af1079d206bc838edad79a6.diff

LOG: [mlir][SCF][bufferize] Support different iter_arg/init_arg types (scf.for)

Even though the iter_arg and init_arg of an scf.for loop have the same tensor type, their bufferized memref types are not necessarily equal: the layout maps may differ. In that case, a cast must be inserted.
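
For illustration, a minimal hand-written sketch of such a cast (the layout maps, value names, and the wrapper function are illustrative only and are not taken from the actual bufferization output):

    #init_layout = affine_map<(d0) -> (d0 + 1)>
    #dynamic_layout = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

    func.func @cast_init_arg(%init: memref<?xf32, #init_layout>)
        -> memref<?xf32, #dynamic_layout> {
      // The init_arg buffer has a static offset and stride (e.g., because it
      // is the result of a memref.subview), while the iter_arg bufferizes to
      // a fully dynamic layout map. A memref.cast reconciles the two types
      // before the bufferized scf.for is created. The same mechanism casts
      // yielded buffers to the iter_arg buffer type where needed.
      %casted = memref.cast %init
          : memref<?xf32, #init_layout> to memref<?xf32, #dynamic_layout>
      return %casted : memref<?xf32, #dynamic_layout>
    }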

Differential Revision: https://reviews.llvm.org/D132860

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f82bf26c19ef1..f22fe002ec86e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -495,6 +495,19 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options);
 
+/// Return the buffer type for a given Value (tensor) after bufferization
+/// without bufferizing any IR. If at any point during the type computation, the
+/// type of a value in `fixedTypes` is required, the mapped type is used.
+///
+/// Note: It should be sufficient to call `getBuffer()->getType()` in most
+/// cases. However, when a buffer type should be predicted without modifying any
+/// IR, this function can be used.
+///
+/// This function is a wrapper around BufferizableOpInterface::getBufferType.
+FailureOr<BaseMemRefType>
+getBufferType(Value value, const BufferizationOptions &options,
+              const DenseMap<Value, BaseMemRefType> &fixedTypes);
+
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.
 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
@@ -551,7 +564,8 @@ namespace detail {
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
 FailureOr<BaseMemRefType>
-defaultGetBufferType(Value value, const BufferizationOptions &options);
+defaultGetBufferType(Value value, const BufferizationOptions &options,
+                     const DenseMap<Value, BaseMemRefType> &fixedTypes);
 } // namespace detail
 
 } // namespace bufferization

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index ce88e01bff0a9..c28e3f6745539 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -350,12 +350,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"FailureOr<BaseMemRefType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "Value":$value,
-                      "const BufferizationOptions &":$options),
+                      "const BufferizationOptions &":$options,
+                      "const DenseMap<Value, BaseMemRefType>":$fixedTypes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           assert(getOwnerOfValue(value) == $_op.getOperation() &&
                  "expected that value belongs to this op");
-          return bufferization::detail::defaultGetBufferType(value, options);
+          return bufferization::detail::defaultGetBufferType(
+              value, options, fixedTypes);
         }]
       >,
   ];

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 22d5ef27bdbe8..07e1f53ab6a79 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -92,7 +92,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
         OpOperand &opOperand, const AnalysisState &state);
 
     FailureOr<BaseMemRefType> getBufferType(
-        Value value, const BufferizationOptions &options);
+        Value value, const BufferizationOptions &options,
+        const DenseMap<Value, BaseMemRefType> &fixedTypes);
 
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
@@ -323,7 +324,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     }
 
     FailureOr<BaseMemRefType> getBufferType(
-        Value value, const BufferizationOptions &options) {
+        Value value, const BufferizationOptions &options,
+        const DenseMap<Value, BaseMemRefType> &fixedTypes) {
       return getMemref().getType().cast<BaseMemRefType>();
     }
   }];

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 3d34ee4d3e2ee..265c417b6ce0c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -568,7 +568,8 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
 }
 
 FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
-    Value value, const BufferizationOptions &options) {
+    Value value, const BufferizationOptions &options,
+    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
   assert(value.getType().isa<TensorType>() && "expected tensor type");
 
   // No further analysis is possible for a block argument.
@@ -587,7 +588,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliasingOperands.front()->get();
-    return getBufferType(equivalentOperand, options);
+    return getBufferType(equivalentOperand, options, fixedTypes);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -602,11 +603,26 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+  DenseMap<Value, BaseMemRefType> fixedTypes;
+  return getBufferType(value, options, fixedTypes);
+}
+
+/// Return the buffer type for a given Value (tensor) after bufferization.
+FailureOr<BaseMemRefType> bufferization::getBufferType(
+    Value value, const BufferizationOptions &options,
+    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
   assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
+
+  // If the `value` is in `fixedTypes`, return the mapped type.
+  const auto &it = fixedTypes.find(value);
+  if (it != fixedTypes.end())
+    return it->second;
+
+  // Try querying BufferizableOpInterface.
   Operation *op = getOwnerOfValue(value);
   auto bufferizableOp = options.dynCastBufferizableOp(op);
   if (bufferizableOp)
-    return bufferizableOp.getBufferType(value, options);
+    return bufferizableOp.getBufferType(value, options, fixedTypes);
 
   // Op is not bufferizable.
   if (!options.defaultMemorySpace.has_value())

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ee9591b2a01b8..cf29d721002bd 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -170,7 +170,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
   }
 
   // Create memory allocation.
-  auto allocType = getBufferType(getResult(), options);
+  auto allocType = bufferization::getBufferType(getResult(), options);
   if (failed(allocType))
     return failure();
   SmallVector<Value> dynamicDims = getDynamicSizes();
@@ -233,8 +233,9 @@ AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
-AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options) {
+FailureOr<BaseMemRefType> AllocTensorOp::getBufferType(
+    Value value, const BufferizationOptions &options,
+    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
   assert(value == getResult() && "invalid value");
 
   // Compute memory space of this allocation.
@@ -242,7 +243,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options) {
   if (getMemorySpace().has_value()) {
     memorySpace = *getMemorySpace();
   } else if (getCopy()) {
-    auto copyBufferType = bufferization::getBufferType(getCopy(), options);
+    auto copyBufferType =
+        bufferization::getBufferType(getCopy(), options, fixedTypes);
     if (failed(copyBufferType))
       return failure();
     memorySpace = copyBufferType->getMemorySpaceAsInt();

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 0b5939e60c1bd..b92b131616fbc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -472,15 +472,76 @@ struct ForOpInterface
   }
 
   FailureOr<BaseMemRefType>
-  getBufferType(Operation *op, Value value,
-                const BufferizationOptions &options) const {
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
     auto forOp = cast<scf::ForOp>(op);
-    // TODO: Only block arguments supported at the moment.
-    if (value.isa<OpResult>())
+    assert(getOwnerOfValue(value) == op && "invalid value");
+    assert(value.getType().isa<TensorType>() && "expected tensor type");
+
+    // Get result/argument number.
+    unsigned resultNum;
+    if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+      resultNum =
+          forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
+              .getResultNumber();
+    } else {
+      resultNum = value.cast<OpResult>().getResultNumber();
+    }
+
+    // Determine the buffer type of the init_arg.
+    Value initArg = forOp.getInitArgs()[resultNum];
+    auto initArgBufferType =
+        bufferization::getBufferType(initArg, options, fixedTypes);
+    if (failed(initArgBufferType))
       return failure();
-    auto bbArg = value.cast<BlockArgument>();
-    return bufferization::getBufferType(
-        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
+
+    // Fix the iter_arg type, so that recursive lookups return the buffer type
+    // of the init_arg. This is to avoid infinite loops when calculating the
+    // buffer type of the yielded value.
+    //
+    // Note: For more precise layout map computation, a fixpoint iteration could
+    // be done (i.e., re-computing the yielded buffer type until the bufferized
+    // iter_arg type no longer changes). This current implementation immediately
+    // switches to a fully dynamic layout map when a mismatch between bufferized
+    // init_arg type and bufferized yield value type is detected.
+    DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
+    newFixedTypes[forOp.getRegionIterArgs()[resultNum]] = *initArgBufferType;
+
+    // Compute the buffer type of the yielded value.
+    auto yieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    Value yieldedValue = yieldOp.getOperand(resultNum);
+    BaseMemRefType yieldedValueBufferType;
+    if (yieldedValue.getType().isa<BaseMemRefType>()) {
+      // scf.yield was already bufferized.
+      yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
+    } else {
+      auto maybeBufferType =
+          bufferization::getBufferType(yieldedValue, options, newFixedTypes);
+      if (failed(maybeBufferType))
+        return failure();
+      yieldedValueBufferType = *maybeBufferType;
+    }
+
+    // If yielded type and init_arg type are the same, use that type directly.
+    if (*initArgBufferType == yieldedValueBufferType)
+      return yieldedValueBufferType;
+
+    // If there is a mismatch between the yielded buffer type and the iter_arg
+    // buffer type, the buffer type must be promoted to a fully dynamic layout
+    // map.
+    auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
+#ifndef NDEBUG
+    auto iterRanked = initArgBufferType->cast<MemRefType>();
+    assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
+           "expected same shape");
+    assert(yieldedRanked.getMemorySpaceAsInt() ==
+               iterRanked.getMemorySpaceAsInt() &&
+           "expected same memory space");
+#endif // NDEBUG
+    return getMemRefTypeWithFullyDynamicLayout(
+        forOp.getRegionIterArgs()[resultNum].getType().cast<RankedTensorType>(),
+        yieldedRanked.getMemorySpaceAsInt());
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -499,13 +560,22 @@ struct ForOpInterface
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
 
+    // Cast init_args if necessary.
+    SmallVector<Value> castedInitArgs;
+    for (const auto &it : llvm::enumerate(initArgs)) {
+      Value initArg = it.value();
+      auto targetType =
+          bufferization::getBufferType(forOp->getResult(it.index()), options);
+      if (failed(targetType))
+        return failure();
+      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
+    }
+
     // Construct a new scf.for op with memref instead of tensor values.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), initArgs);
+        forOp.getStep(), castedInitArgs);
     newForOp->setAttrs(forOp->getAttrs());
-    ValueRange initArgsRange(initArgs);
-    TypeRange initArgsTypes(initArgsRange);
     Block *loopBody = &newForOp.getLoopBody().front();
 
     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
@@ -904,10 +974,8 @@ struct YieldOpInterface
           return failure();
         Value buffer = *maybeBuffer;
         if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
-          FailureOr<BaseMemRefType> resultType =
-              cast<BufferizableOpInterface>(forOp.getOperation())
-                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
-                                 options);
+          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+              forOp.getRegionIterArgs()[it.index()], options);
           if (failed(resultType))
             return failure();
           buffer = castBuffer(rewriter, buffer, *resultType);

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 2e23a167cd0b2..35010f520e022 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -293,7 +293,7 @@ struct ExtractSliceOpInterface
 
     // Take a subview of the source buffer.
     auto resultMemrefType =
-        getBufferType(op, extractSliceOp.getResult(), options);
+        bufferization::getBufferType(extractSliceOp.getResult(), options);
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
@@ -305,12 +305,12 @@ struct ExtractSliceOpInterface
   }
 
   FailureOr<BaseMemRefType>
-  getBufferType(Operation *op, Value value,
-                const BufferizationOptions &options) const {
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     assert(value == extractSliceOp.getResult() && "invalid value");
-    auto srcMemrefType =
-        bufferization::getBufferType(extractSliceOp.getSource(), options);
+    auto srcMemrefType = bufferization::getBufferType(
+        extractSliceOp.getSource(), options, fixedTypes);
     if (failed(srcMemrefType))
       return failure();
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();

diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 71ee9922f8f52..c8bf6125e252e 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -742,3 +742,25 @@ func.func @scf_for_yield_alias_of_non_equivalent(%sz: index) -> tensor<?xf32> {
   }
   return %r : tensor<?xf32>
 }
+
+// -----
+
+// We just check that this example bufferizes to valid IR.
+
+// CHECK-LABEL: func @scf_for_buffer_type_mismatch
+func.func @scf_for_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
+  %e2 = tensor.extract_slice %0[1][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+  // init_arg and iter_arg have different buffer types. This must be resolved
+  // with casts.
+  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t = %e2) -> tensor<?xf32> {
+    %s = "test.dummy"() : () -> (index)
+    %e = tensor.extract_slice %t[1][%s][1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %e : tensor<?xf32>
+  }
+  %x = tensor.extract %r[%c1] : tensor<?xf32>
+  return %x : f32
+}
