[Mlir-commits] [mlir] 111c919 - [mlir][bufferization] Generalize getBufferType

Matthias Springer llvmlistbot at llvm.org
Tue Aug 30 07:30:16 PDT 2022


Author: Matthias Springer
Date: 2022-08-30T16:26:44+02:00
New Revision: 111c91966582654dec19b5f86334d4783c31a63b

URL: https://github.com/llvm/llvm-project/commit/111c91966582654dec19b5f86334d4783c31a63b
DIFF: https://github.com/llvm/llvm-project/commit/111c91966582654dec19b5f86334d4783c31a63b.diff

LOG: [mlir][bufferization] Generalize getBufferType

This change generalizes getBufferType so that it can predict the buffer type of any tensor value (not just BlockArguments) without modifying any IR. It also subsumes the getMemorySpace interface method, which is removed. This is useful for loop bufferization, where the precise buffer type of an iter_arg cannot be known without examining the loop body.

Differential Revision: https://reviews.llvm.org/D132859

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index e39fb696c8917..f82bf26c19ef1 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -484,10 +484,14 @@ bool allocationDoesNotEscape(OpResult opResult);
 FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
                            const BufferizationOptions &options);
 
-/// Return the buffer type for a given Value (tensor) after bufferization.
+/// Return the buffer type for a given Value (tensor) after bufferization
+/// without bufferizing any IR.
 ///
-/// Note: Op implementations should preferrably call `getBuffer()->getType()`.
-/// This function should only be used if `getBuffer` cannot be used.
+/// Note: It should be sufficient to call `getBuffer()->getType()` in most
+/// cases. However, when a buffer type should be predicted without modifying any
+/// IR, this function can be used.
+///
+/// This function is a wrapper around BufferizableOpInterface::getBufferType.
 FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options);
 
@@ -538,6 +542,18 @@ BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
 BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                      unsigned memorySpace = 0);
 
+/// Return the owner of the given value. In case of a BlockArgument that is the
+/// owner of the block. In case of an OpResult that is the defining op.
+Operation *getOwnerOfValue(Value value);
+
+namespace detail {
+/// This is the default implementation of
+/// BufferizableOpInterface::getBufferType. Should not be called from other
+/// places.
+FailureOr<BaseMemRefType>
+defaultGetBufferType(Value value, const BufferizationOptions &options);
+} // namespace detail
+
 } // namespace bufferization
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 9f6d118ae6eb1..ce88e01bff0a9 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -340,39 +340,22 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
       >,
       InterfaceMethod<
         /*desc=*/[{
-          Return the bufferized type of the given tensor block argument. The
-          block argument is guaranteed to belong to a block of this op.
+          Return the bufferized type of the given tensor value (without
+          bufferizing the IR). The value is either a BlockArgument of a block
+          that belongs to this op or an OpResult of the given op.
+
+          This method is useful when the bufferized type of value must be
+          predicted before modifying any IR.
         }],
         /*retType=*/"FailureOr<BaseMemRefType>",
         /*methodName=*/"getBufferType",
-        /*args=*/(ins "BlockArgument":$bbArg,
+        /*args=*/(ins "Value":$value,
                       "const BufferizationOptions &":$options),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          assert(bbArg.getOwner()->getParentOp() == $_op &&
-                 "bbArg must belong to this op");
-          assert(bbArg.getType().isa<TensorType>() &&
-                 "expected tensor type");
-          return bufferization::getMemRefType(bbArg, options);
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
-          Return the memory space of the given tensor OpResult if specified on
-          this op. If not specified, return `failure`.
-
-          This method will never be called with OpResults that do not bufferize
-          to a memory allocation.
-        }],
-        /*retType=*/"FailureOr<unsigned>",
-        /*methodName=*/"getMemorySpace",
-        /*args=*/(ins "OpResult":$opResult),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
-          assert(cast<BufferizableOpInterface>($_op.getOperation())
-                     .bufferizesToAllocation(opResult)
-                 && "expected allocation");
-          return failure();
+          assert(getOwnerOfValue(value) == $_op.getOperation() &&
+                 "expected that value belongs to this op");
+          return bufferization::detail::defaultGetBufferType(value, options);
         }]
       >,
   ];

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4d7e5462a5f6d..22d5ef27bdbe8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -82,12 +82,6 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
     bool bufferizesToAllocation(OpResult opResult) { return true; }
 
-    FailureOr<unsigned> getMemorySpace(OpResult opResult) {
-      if (getMemorySpace().has_value())
-        return static_cast<unsigned>(*getMemorySpace());
-      return failure();
-    }
-
     bool bufferizesToMemoryRead(OpOperand &opOperand,
                                 const AnalysisState &state);
 
@@ -97,6 +91,9 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     SmallVector<OpResult> getAliasingOpResult(
         OpOperand &opOperand, const AnalysisState &state);
 
+    FailureOr<BaseMemRefType> getBufferType(
+        Value value, const BufferizationOptions &options);
+
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
     }
@@ -324,6 +321,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
       // It is unknown whether the memref operand is writable or not.
       return false;
     }
+
+    FailureOr<BaseMemRefType> getBufferType(
+        Value value, const BufferizationOptions &options) {
+      return getMemref().getType().cast<BaseMemRefType>();
+    }
   }];
 
   let assemblyFormat = "$memref attr-dict `:` type($memref)";

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 056f912f86590..3d34ee4d3e2ee 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -38,8 +38,7 @@ namespace bufferization {
 using namespace mlir;
 using namespace bufferization;
 
-/// Return the owner of the given value.
-static Operation *getOwnerOfValue(Value value) {
+Operation *bufferization::getOwnerOfValue(Value value) {
   if (auto opResult = value.dyn_cast<OpResult>())
     return opResult.getDefiningOp();
   return value.cast<BlockArgument>().getOwner()->getParentOp();
@@ -568,47 +567,53 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
       .getResult();
 }
 
+FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+    Value value, const BufferizationOptions &options) {
+  assert(value.getType().isa<TensorType>() && "expected tensor type");
+
+  // No further analysis is possible for a block argument.
+  if (value.isa<BlockArgument>())
+    return bufferization::getMemRefType(value, options);
+
+  // Value is an OpResult.
+  Operation *op = getOwnerOfValue(value);
+  auto opResult = value.cast<OpResult>();
+  auto bufferizableOp = cast<BufferizableOpInterface>(op);
+  AnalysisState state(options);
+  auto aliasingOperands = bufferizableOp.getAliasingOpOperand(opResult, state);
+  if (!aliasingOperands.empty() &&
+      bufferizableOp.bufferRelation(opResult, state) ==
+          BufferRelation::Equivalent) {
+    // If the OpResult has an equivalent OpOperand, both OpResult and
+    // OpOperand bufferize to the exact same buffer type.
+    Value equivalentOperand = aliasingOperands.front()->get();
+    return getBufferType(equivalentOperand, options);
+  }
+
+  // If we do not know the memory space and there is no default memory space,
+  // report a failure.
+  if (!options.defaultMemorySpace.has_value())
+    return op->emitError("could not infer memory space");
+
+  return getMemRefType(value, options, /*layout=*/{},
+                       *options.defaultMemorySpace);
+}
+
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
   assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
   Operation *op = getOwnerOfValue(value);
+  auto bufferizableOp = options.dynCastBufferizableOp(op);
+  if (bufferizableOp)
+    return bufferizableOp.getBufferType(value, options);
 
-  // ToTensorOp: Take buffer type directly from the op.
-  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
-    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();
-
-  // If value is a bbArg of a bufferizable op: query op interface.
-  if (auto bbArg = value.dyn_cast<BlockArgument>())
-    if (auto bufferizableOp =
-            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
-      return bufferizableOp.getBufferType(bbArg, options);
-
-  // Check value is a new buffer allocation with a memory space attribute. In
-  // that case we can at least infer the memory space.
-  Optional<unsigned> memorySpace;
-  if (auto opResult = value.dyn_cast<OpResult>()) {
-    if (auto bufferizableOp =
-            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
-      if (bufferizableOp.bufferizesToAllocation(opResult)) {
-        FailureOr<unsigned> queriedMemorySpace =
-            bufferizableOp.getMemorySpace(opResult);
-        if (!failed(queriedMemorySpace))
-          memorySpace = *queriedMemorySpace;
-      }
-    }
-  }
-
-  // If we still do not know the memory space, use the default memory space (if
-  // any).
-  if (!memorySpace.has_value())
-    memorySpace = options.defaultMemorySpace;
-
-  // If we still do not know the memory space, report a failure.
-  if (!memorySpace.has_value())
+  // Op is not bufferizable.
+  if (!options.defaultMemorySpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
+  return getMemRefType(value, options, /*layout=*/{},
+                       *options.defaultMemorySpace);
 }
 
 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 5fc8e6abb0a25..ee9591b2a01b8 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -152,7 +152,6 @@ void mlir::bufferization::populateDynamicDimSizes(
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                        const BufferizationOptions &options) {
   OpBuilder::InsertionGuard g(rewriter);
-  Operation *op = this->getOperation();
   Location loc = getLoc();
 
   // Nothing to do for dead AllocTensorOps.
@@ -170,30 +169,17 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
     copyBuffer = *maybeCopyBuffer;
   }
 
-  // Compute memory space of this allocation.
-  unsigned memorySpace;
-  if (getMemorySpace().has_value()) {
-    memorySpace = *getMemorySpace();
-  } else if (getCopy()) {
-    memorySpace =
-        copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
-  } else if (options.defaultMemorySpace.has_value()) {
-    memorySpace = *options.defaultMemorySpace;
-  } else {
-    return op->emitError("could not infer memory space");
-  }
-
   // Create memory allocation.
-  auto allocType =
-      MemRefType::get(getType().getShape(), getType().getElementType(),
-                      AffineMap(), memorySpace);
+  auto allocType = getBufferType(getResult(), options);
+  if (failed(allocType))
+    return failure();
   SmallVector<Value> dynamicDims = getDynamicSizes();
   if (getCopy()) {
     assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
     populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
   }
-  FailureOr<Value> alloc =
-      options.createAlloc(rewriter, loc, allocType, dynamicDims);
+  FailureOr<Value> alloc = options.createAlloc(
+      rewriter, loc, allocType->cast<MemRefType>(), dynamicDims);
   if (failed(alloc))
     return failure();
 
@@ -247,6 +233,28 @@ AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
   return {};
 }
 
+FailureOr<BaseMemRefType>
+AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options) {
+  assert(value == getResult() && "invalid value");
+
+  // Compute memory space of this allocation.
+  unsigned memorySpace;
+  if (getMemorySpace().has_value()) {
+    memorySpace = *getMemorySpace();
+  } else if (getCopy()) {
+    auto copyBufferType = bufferization::getBufferType(getCopy(), options);
+    if (failed(copyBufferType))
+      return failure();
+    memorySpace = copyBufferType->getMemorySpaceAsInt();
+  } else if (options.defaultMemorySpace.has_value()) {
+    memorySpace = *options.defaultMemorySpace;
+  } else {
+    return getOperation()->emitError("could not infer memory space");
+  }
+
+  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+}
+
 LogicalResult AllocTensorOp::verify() {
   if (getCopy() && !getDynamicSizes().empty())
     return emitError("dynamic sizes not needed when copying a tensor");

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 9459640dad02b..0b5939e60c1bd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -472,9 +472,13 @@ struct ForOpInterface
   }
 
   FailureOr<BaseMemRefType>
-  getBufferType(Operation *op, BlockArgument bbArg,
+  getBufferType(Operation *op, Value value,
                 const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
+    // TODO: Only block arguments supported at the moment.
+    if (value.isa<OpResult>())
+      return failure();
+    auto bbArg = value.cast<BlockArgument>();
     return bufferization::getBufferType(
         forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
   }

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index cf9e5c89dfaa5..2e23a167cd0b2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -290,21 +290,38 @@ struct ExtractSliceOpInterface
         getBuffer(rewriter, extractSliceOp.getSource(), options);
     if (failed(srcMemref))
       return failure();
-    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
 
     // Take a subview of the source buffer.
-    auto subviewMemRefType =
-        memref::SubViewOp::inferRankReducedResultType(
-            extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
-            mixedSizes, mixedStrides)
-            .cast<MemRefType>();
+    auto resultMemrefType =
+        getBufferType(op, extractSliceOp.getResult(), options);
+    if (failed(resultMemrefType))
+      return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
-        mixedStrides);
+        loc, resultMemrefType->cast<MemRefType>(), *srcMemref, mixedOffsets,
+        mixedSizes, mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
   }
+
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value,
+                const BufferizationOptions &options) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    assert(value == extractSliceOp.getResult() && "invalid value");
+    auto srcMemrefType =
+        bufferization::getBufferType(extractSliceOp.getSource(), options);
+    if (failed(srcMemrefType))
+      return failure();
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    return memref::SubViewOp::inferRankReducedResultType(
+               extractSliceOp.getType().getShape(),
+               srcMemrefType->cast<MemRefType>(), mixedOffsets, mixedSizes,
+               mixedStrides)
+        .cast<BaseMemRefType>();
+  }
 };
 
 /// Bufferization of tensor.extract. Replace with memref.load.


        


More information about the Mlir-commits mailing list