[Mlir-commits] [mlir] ad0050c - [mlir][Linalg] Add comprehensive bufferization support for TiledLoopOp (14/n)

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jul 2 07:21:13 PDT 2021


Author: Nicolas Vasilache
Date: 2021-07-02T14:21:08Z
New Revision: ad0050c6073d8b9a6cbc9ab94c75fc5ba30051fd

URL: https://github.com/llvm/llvm-project/commit/ad0050c6073d8b9a6cbc9ab94c75fc5ba30051fd
DIFF: https://github.com/llvm/llvm-project/commit/ad0050c6073d8b9a6cbc9ab94c75fc5ba30051fd.diff

LOG: [mlir][Linalg] Add comprehensive bufferization support for TiledLoopOp (14/n)

Differential Revision: https://reviews.llvm.org/D105335

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 178676c5e4b7b..ad296ff8c199e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -296,13 +296,13 @@ static InPlaceSpec getInPlace(BlockArgument bbArg) {
       return InPlaceSpec::None;
     return inplaceAttr.getValue() ? InPlaceSpec::True : InPlaceSpec::False;
   }
-  // Interestingly, scf::ForOp's bbArg can **always** be viewed inplace from the
-  // perspective of ops nested under it:
+  // Interestingly, scf::ForOp's and TiledLoop's bbArg can **always** be viewed
+  // inplace from the perspective of ops nested under:
   //   1. Either the matching iter operand is not bufferized inplace and an
   //      alloc + optional copy makes the bbArg itself inplaceable.
   //   2. Or the matching iter operand is bufferized inplace and bbArg just
   //      bufferizes to that too.
-  if (auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp()))
+  if (isa<scf::ForOp, TiledLoopOp>(bbArg.getOwner()->getParentOp()))
     return InPlaceSpec::True;
   // Unknown cases.
   return InPlaceSpec::None;
@@ -359,19 +359,28 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
       isa<CallOpInterface,
           tensor::CastOp,
           ConstantOp,
+          ExtractSliceOp,
           scf::ForOp,
+          InsertSliceOp,
           InitTensorOp,
           LinalgOp,
           ReturnOp,
-          ExtractSliceOp,
-          InsertSliceOp,
+          TiledLoopOp,
           VectorTransferOpInterface,
+          linalg::YieldOp,
           scf::YieldOp>(op)
       // clang-format on
       || (none_of(op->getResultTypes(), isaTensor) &&
           none_of(op->getOperandTypes(), isaTensor));
 }
 
+/// TiledLoopOp overload: return the OpResult tied to `opOperand`, i.e. the
+/// result that may bufferize into the same buffer when `op` is bufferized
+/// inplace, as reported by TiledLoopOp::getTiedOpResult (null if none).
+static OpResult getInplaceableOpResult(TiledLoopOp op, OpOperand &opOperand) {
+  return op.getTiedOpResult(opOperand);
+}
+
 /// Return the OpResult that may bufferize into the same buffer as `opOperand`
 /// when the op is bufferized inplace.
 /// Return null if no such result exists.
@@ -441,8 +450,9 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
         // result(s).
         .Case<tensor::CastOp,
               scf::ForOp,
-              LinalgOp,
               InsertSliceOp,
+              LinalgOp,
+              TiledLoopOp,
               VectorTransferOpInterface>(
             [&](auto op) { return getInplaceableOpResult(op, opOperand); })
         // ExtractSliceOp is special, when bufferized inplace it just returns an
@@ -469,18 +479,23 @@ static Optional<OpOperand *> getAliasingOpOperand(OpResult result) {
   return TypeSwitch<Operation *, OpOperand *>(result.getDefiningOp())
       .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); })
       .Case([&](ConstantOp op) { return &op->getOpOperand(0); })
-      .Case([&](LinalgOp op) {
-        return op.getOutputTensorOperands()[result.getResultNumber()];
-      })
       .Case([&](ExtractSliceOp op) { return &op->getOpOperand(0); })
-      .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
-      .Case([&](vector::TransferWriteOp op) { return &op->getOpOperand(1); })
       // In the case of scf::ForOp, this currently assumes the iter_args / yield
       // are 1-1. This may fail and is verified at the end.
       // TODO: update this.
       .Case([&](scf::ForOp op) {
         return &op.getIterOpOperands()[result.getResultNumber()];
       })
+      .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
+      .Case([&](LinalgOp op) {
+        return op.getOutputTensorOperands()[result.getResultNumber()];
+      })
+      .Case([&](TiledLoopOp op) {
+        // TODO: TiledLoopOp helper method to avoid leaking impl details.
+        return &op->getOpOperand(op.getNumControlOperands() +
+                                 op.getNumInputs() + result.getResultNumber());
+      })
+      .Case([&](vector::TransferWriteOp op) { return &op->getOpOperand(1); })
       .Default([&](Operation *op) {
         op->dump();
         llvm_unreachable("unexpected defining op");
@@ -528,6 +543,10 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
   // matching bbArg may.
   if (isa<scf::ForOp>(opOperand.getOwner()))
     return false;
+  // TiledLoop alone doesn't bufferize to a memory read, one of the uses of its
+  // matching bbArg may.
+  if (isa<TiledLoopOp>(opOperand.getOwner()))
+    return false;
   // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
   // of the matching bbArg may. It is the responsibility of the caller to
   // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
@@ -1340,11 +1359,10 @@ createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
 /// When allocating a new buffer, analyze whether `op` want to read form that
 /// buffer. In such a case, insert a copy to ensure the newly allocated buffer
 /// is properly initialiazed.
-static LogicalResult
-allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
-                          SmallVectorImpl<Value> &resultBuffers,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo) {
+static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+                                      SmallVectorImpl<Value> &resultBuffers,
+                                      BlockAndValueMapping &bvm,
+                                      BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1360,8 +1378,7 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
     OpResult opResult = getInplaceableOpResult(*opOperand);
     if (getInPlace(opResult) == InPlaceSpec::True) {
       Value v = lookup(bvm, output);
-      if (!v)
-        return failure();
+      assert(v && "missing buffer");
       resultBuffers.push_back(v);
       continue;
     }
@@ -1375,17 +1392,13 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
 
     // Additionally, if the output buffer is used, clone its value for now.
     if (op.payloadUsesValueFromOperand(opOperand)) {
-      if (Value v = lookup(bvm, output))
-        b.create<CopyOp>(loc, v, alloc);
-      else
-        return failure();
+      Value v = lookup(bvm, output);
+      b.create<CopyOp>(loc, v, alloc);
     }
   }
 
   if (op->getNumResults())
     map(bvm, op->getResults(), resultBuffers);
-
-  return success();
 }
 
 /// Generic conversion for any LinalgOp on tensors.
@@ -1398,7 +1411,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
   // basis.
   if (!op.hasTensorSemantics())
-    return failure();
+    return op->emitError() << "op does not have tensor semantics";
 
   b.setInsertionPoint(op);
   Location loc = op.getLoc();
@@ -1410,14 +1423,11 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
       continue;
     }
     newInputBuffers.push_back(lookup(bvm, opOperand->get()));
-    if (!newInputBuffers.back())
-      return failure();
+    assert(newInputBuffers.back() && "missing buffer");
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
-  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
-                                       aliasInfo)))
-    return failure();
+  allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, aliasInfo);
 
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
@@ -1608,8 +1618,8 @@ static LogicalResult bufferize(OpBuilder &b, ConstantOp constantOp,
                                BlockAndValueMapping &bvm,
                                BufferizationAliasInfo &aliasInfo,
                                GlobalCreator &globalCreator) {
-  if (!constantOp.getType().dyn_cast<RankedTensorType>())
-    return failure();
+  assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
+         "not a constant ranked tensor");
 
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1629,11 +1639,15 @@ static LogicalResult bufferize(OpBuilder &b, ConstantOp constantOp,
 static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
                                BlockAndValueMapping &bvm,
                                BufferizationAliasInfo &aliasInfo) {
+  // Guard the builder; any new memref.dim is created right before `dimOp`.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(dimOp);
+  // Only ranked sources are rewritten; unranked ones are left untouched.
   if (dimOp.source().getType().isa<RankedTensorType>()) {
     Value v = lookup(bvm, dimOp.source());
-    if (!v)
-      return failure();
-    dimOp.sourceMutable().assign(v);
+    assert(v && "missing buffer");
+    dimOp.result().replaceAllUsesWith(
+        b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
   }
   return success();
 }
@@ -1649,10 +1663,12 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
   // Otherwise alloc and copy.
   b.setInsertionPoint(forOp);
   for (OpResult opResult : forOp->getResults()) {
+    if (!opResult.getType().isa<TensorType>())
+      continue;
     // TODO: Atm we bail on unranked TensorType because we don't know how to
     // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-    if (!opResult.getType().isa<RankedTensorType>())
-      return failure();
+    assert(opResult.getType().isa<RankedTensorType>() &&
+           "unsupported unranked tensor");
     OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
     Value operand = opOperand.get();
     Value operandBuffer = lookup(bvm, operand);
@@ -1730,8 +1746,7 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
     if (!tensorType)
       continue;
     Value v = lookup(bvm, operand.get());
-    if (!v)
-      return failure();
+    assert(v && "missing buffer for result");
     Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
     operand.set(returnTensor);
     aliasInfo.insertNewBufferEquivalence(returnTensor, v);
@@ -1740,6 +1755,135 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
   return success();
 }
 
+/// Bufferization for TiledLoopOp: add a buffer operand + bbArg per tensor arg.
+static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  // Allocate output buffers if needed; forward output tensor args to yield.
+  Operation *yieldOp = tiledLoopOp.getBody()->getTerminator();
+  Block *body = tiledLoopOp.getBody();
+
+  // Snapshot the old inputs/outputs: new operands are inserted in place below.
+  auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs());
+  auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs());
+
+  int numLoops = tiledLoopOp.getNumLoops();
+  int numControlOperands = tiledLoopOp.getNumControlOperands();
+
+  // Add buffers for outputs and the corresponding block arguments.
+  // Keep separate iterators to increment without further leaking impl. details.
+  // Start with outputs to avoid interference from new input buffers.
+  int numNewOutputBuffers = 0;
+  int resultIndex = 0;
+  int oldOutputBBArgIndex = numLoops + oldInputs.size();
+  int nextOutputBBArgIndex = numLoops + oldInputs.size() + oldOutputs.size();
+  int nextOutputOperandIndex =
+      numControlOperands + oldInputs.size() + oldOutputs.size();
+  for (Value oldOutputTensor : oldOutputs) {
+    if (!oldOutputTensor.getType().isa<TensorType>()) {
+      // Skip and increment the old bbarg index only.
+      ++oldOutputBBArgIndex;
+      // Do not increment resultIndex as only tensors are returned.
+      // TODO: better interface to avoid leaking such impl details.
+      continue;
+    }
+
+    assert(oldOutputTensor.getType().isa<RankedTensorType>() &&
+           "bufferizable output must be a ranked tensor");
+
+    Value outputBuffer = lookup(bvm, oldOutputTensor);
+    OpResult opResult = tiledLoopOp->getResult(resultIndex);
+    OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
+    // If the result is not inplaceable, need to allocate a copy for it.
+    if (getInPlace(opResult) != InPlaceSpec::True) {
+      auto loc = tiledLoopOp.getLoc();
+      Value alloc = createNewAllocDeallocPairForShapedValue(
+          b, loc, oldOutputTensor, aliasInfo);
+      // If the tensor comes from `linalg::InitTensorOp`, the value is
+      // uninitialized and we do not need to copy.
+      // TODO: more generally, skip the copy if the matching bbArg is not read.
+      if (!oldOutputTensor.getDefiningOp<linalg::InitTensorOp>()) {
+        assert(outputBuffer && "missing buffer for output");
+        b.setInsertionPointAfter(alloc.getDefiningOp());
+        b.create<linalg::CopyOp>(loc, outputBuffer, alloc);
+      }
+      outputBuffer = alloc;
+    }
+    assert(outputBuffer && "missing buffer for output");
+    // Insert mapping and aliasing info.
+    aliasInfo.createAliasInfoEntry(outputBuffer);
+    aliasInfo.insertNewBufferEquivalence(opResult, outputBuffer);
+    map(bvm, opResult, outputBuffer);
+
+    // Insert new operand and bbArg.
+    tiledLoopOp->insertOperands(nextOutputOperandIndex, outputBuffer);
+    BlockArgument newBufferBBArg =
+        body->insertArgument(nextOutputBBArgIndex, outputBuffer.getType());
+    BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
+    // Insert mapping and aliasing info.
+    aliasInfo.createAliasInfoEntry(newBufferBBArg);
+    aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
+    map(bvm, oldTensorBBArg, newBufferBBArg);
+
+    // Yield the old tensor bbArg so `linalg.yield` canonicalizes away later.
+    yieldOperand.set(oldTensorBBArg);
+
+    // Increment indices.
+    ++numNewOutputBuffers;
+    ++resultIndex;
+    ++oldOutputBBArgIndex;
+    ++nextOutputBBArgIndex;
+    ++nextOutputOperandIndex;
+  }
+
+  // Add buffers for inputs and the corresponding block arguments.
+  // Keep separate iterators to increment without further leaking impl. details.
+  int numNewInputBuffers = 0;
+  int oldInputBBArgIndex = numLoops;
+  int nextInputBBArgIndex = numLoops + oldInputs.size();
+  int nextInputOperandIndex = numControlOperands + oldInputs.size();
+  for (Value oldInputTensor : oldInputs) {
+    if (!oldInputTensor.getType().isa<TensorType>()) {
+      // Skip and increment the old bbarg index only.
+      ++oldInputBBArgIndex;
+      continue;
+    }
+
+    Value inputBuffer = lookup(bvm, oldInputTensor);
+    assert(inputBuffer && " missing buffer for operand");
+
+    // Insert new operand and bbArg.
+    tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer);
+    BlockArgument newBufferBBArg =
+        body->insertArgument(nextInputBBArgIndex, inputBuffer.getType());
+    BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
+
+    // Insert mapping and aliasing info.
+    aliasInfo.createAliasInfoEntry(newBufferBBArg);
+    aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
+    map(bvm, oldTensorBBArg, newBufferBBArg);
+
+    // Increment indices.
+    ++numNewInputBuffers;
+    ++oldInputBBArgIndex;
+    ++nextInputBBArgIndex;
+    ++nextInputOperandIndex;
+  }
+
+  // Update segment sizes.
+  // TODO: Helper method to avoid leaking impl details.
+  tiledLoopOp->setAttr(
+      TiledLoopOp::getOperandSegmentSizeAttr(),
+      b.getI32VectorAttr(
+          {numLoops, numLoops, numLoops,
+           static_cast<int>(oldInputs.size()) + numNewInputBuffers,
+           static_cast<int>(oldOutputs.size()) + numNewOutputBuffers}));
+
+  return success();
+}
+
 /// Bufferize ExtractSliceOp to subview with optional alloc + copy depending on
 /// whether or not it is marked inplaceable.
 /// Note that `getInplaceableOpResult` on a ExtractSliceOp always returns null.
@@ -1871,8 +2015,7 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
   /// op.source().
   if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
     Value v = lookup(bvm, op.source());
-    if (!v)
-      return failure();
+    assert(v && "missing buffer");
     readOp.sourceMutable().assign(v);
     return success();
   }
@@ -1891,8 +2034,7 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
     // InPlace write will result in memref.tensor_load(x) which must
     // canonicalize away with one of it uses.
     newInputBuffer = lookup(bvm, writeOp.source());
-    if (!newInputBuffer)
-      return failure();
+    assert(newInputBuffer && "missing buffer");
   }
 
   // Create a new transfer_write on buffer that doesn't have a return value.
@@ -1933,6 +2075,22 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
   return success();
 }
 
+/// Bufferization for linalg::YieldOp is a no-op when no tensors are yielded or
+/// when it terminates a TiledLoopOp (it canonicalizes away later); else error.
+static LogicalResult bufferize(OpBuilder &b, linalg::YieldOp yieldOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  // No ops are created, so no insertion guard is needed. No tensors -> success.
+  if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor))
+    return success();
+  // A linalg::YieldOp terminating a TiledLoopOp must just canonicalize away,
+  // so only the direct parent is considered.
+  if (isa<TiledLoopOp>(yieldOp->getParentOp()))
+    return success();
+  // Report instead of llvm_unreachable, which is UB in release builds.
+  return yieldOp->emitError() << "unexpected yieldOp";
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
@@ -2043,7 +2201,7 @@ bufferizationSanityCheck(scf::YieldOp yieldOp,
                          const BufferizationAliasInfo &aliasInfo) {
   auto parentForOp = yieldOp->getParentOfType<scf::ForOp>();
   if (!parentForOp)
-    return failure();
+    return yieldOp->emitError() << "not nested under ForOp";
 
   for (OpOperand &operand : yieldOp->getOpOperands()) {
     OpResult matchingForOpResult =
@@ -2057,11 +2215,10 @@ bufferizationSanityCheck(scf::YieldOp yieldOp,
         parentForOp.getRegionIterArgForOpOperand(machingForOpOperand);
     if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg,
                                                  operand.get())) {
-      yieldOp->emitError()
-          << "Yield operand #" << operand.getOperandNumber()
-          << " does not bufferize to an equivalent buffer to the matching"
-          << " enclosing scf::for operand -> Fail the pass\n";
-      return failure();
+      return yieldOp->emitError()
+             << "Yield operand #" << operand.getOperandNumber()
+             << " does not bufferize to an equivalent buffer to the matching"
+             << " enclosing scf::for operand -> Fail the pass\n";
     }
   }
 
@@ -2150,10 +2307,10 @@ static LogicalResult bufferizeFuncOpInternals(
   // Walk in PreOrder to ensure ops with regions are handled before their body.
   // Since walk has to be PreOrder, we need to erase ops that require it
   // separately: this is the case for CallOp
+  // clang-format off
   SmallVector<Operation *> toErase;
-  WalkResult result =
-      funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
-        // clang-format off
+  WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op)
+                                                          -> WalkResult {
     WalkResult result =
       TypeSwitch<Operation *, LogicalResult>(op)
       // Skip BufferCast and TensorLoad ops.
@@ -2161,13 +2318,15 @@ static LogicalResult bufferizeFuncOpInternals(
             memref::TensorLoadOp>([&](auto) { return success(); })
       .Case<tensor::CastOp,
             tensor::DimOp,
+            ExtractSliceOp,
             scf::ForOp,
             InitTensorOp,
+            InsertSliceOp,
             LinalgOp,
             ReturnOp,
-            ExtractSliceOp,
-            InsertSliceOp,
+            TiledLoopOp,
             VectorTransferOpInterface,
+            linalg::YieldOp,
             scf::YieldOp>([&](auto op) {
         LDBG("Begin bufferize:\n" << op << '\n');
         return bufferize(b, op, bvm, aliasInfo);
@@ -2182,23 +2341,23 @@ static LogicalResult bufferizeFuncOpInternals(
         LDBG("Begin bufferize:\n" << op << '\n');
         return bufferize(b, op, bvm, aliasInfo, globalCreator);
       })
-      .Default([&](Operation *op) {
+      .Default([&](Operation *op) -> LogicalResult {
         auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
         if (any_of(op->getOperandTypes(), isaTensor) ||
             any_of(op->getResultTypes(), isaTensor))
-          return failure();
+          return op->emitError() << "unsupported op with tensors";
         return success();
       });
-        // clang-format on
 
-        // Register post-walk erasure, if necessary.
-        if (isa<CallOpInterface>(op))
-          if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
-              llvm::any_of(op->getResultTypes(), isaTensor))
-            toErase.push_back(op);
+    // Register post-walk erasure, if necessary.
+    if (isa<CallOpInterface>(op))
+      if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+          llvm::any_of(op->getResultTypes(), isaTensor))
+        toErase.push_back(op);
 
-        return result;
-      });
+    return result;
+  });
+  // clang-format on
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
   for (Operation *op : toErase)

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 78f84cc8540c4..d8257dd172c63 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize -split-input-file -verify-diagnostics
 
 func private @foo() -> tensor<?xf32>
 
@@ -85,3 +85,25 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   // expected-error @+1 {{buffer result #0 not produced by an alloc}}
   return %r0: tensor<4xf32>
 }
+
+// -----
+
+func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
+{
+  %r = scf.if %b -> (tensor<4xf32>) {
+    // expected-error @+1 {{not nested under ForOp}}
+    scf.yield %A : tensor<4xf32>
+  } else {
+    scf.yield %B : tensor<4xf32>
+  }
+  return %r: tensor<4xf32>
+}
+
+// -----
+
+func @unknown_op(%t : tensor<4xf32>) -> tensor<4xf32>
+{
+  // expected-error @+1 {{unsupported op with tensors}}
+  %0 = "marklar"(%t) : (tensor<4xf32>) -> (tensor<4xf32>)
+  return %0 : tensor<4xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index f7f221b2b77fb..b29cf6e81f92c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -498,3 +498,60 @@ func @main() {
 
 //     CHECK:   func private @print_memref_f32(memref<*xf32>)
 func private @print_memref_f32(tensor<*xf32>)
+
+// -----
+
+func private @some_use(memref<?xf32>)
+
+#TILE_MAP = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+
+//  CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+//  CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+//  CHECK-DAG: #[[$TILE_MAP:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+
+//      CHECK:  func @tiled_dot(
+// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:    %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:    %[[c:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.inplaceable = true},
+                %effecting: memref<?xf32>) -> tensor<f32> {
+  %c3 = constant 3 : index
+  %c0 = constant 0 : index
+
+  //     CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, #[[$DYN_1D_MAP]]>
+  %0 = tensor.dim %A, %c0 : tensor<?xf32>
+
+  //     CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]]
+  %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
+       ins (%arg4 = %A: tensor<?xf32>, %use = %effecting : memref<?xf32>, %arg5 = %B: tensor<?xf32>)
+      outs (%arg6 = %c: tensor<f32>)
+      iterators["reduction"]
+  {
+    // CHECK-NOT:   alloc
+
+    %2 = tensor.dim %arg4, %c0 : tensor<?xf32>
+    %3 = affine.min #TILE_MAP(%arg3)[%2]
+
+    //     CHECK:   %[[SV_A:.*]] = memref.subview {{.*}}
+    %4 = tensor.extract_slice %arg4[%arg3] [%3] [1] : tensor<?xf32> to tensor<?xf32>
+    %5 = tensor.dim %arg5, %c0 : tensor<?xf32>
+    %6 = affine.min #TILE_MAP(%arg3)[%5]
+
+    //     CHECK:   %[[SV_B:.*]] = memref.subview {{.*}}
+    %7 = tensor.extract_slice %arg5[%arg3] [%6] [1] : tensor<?xf32> to tensor<?xf32>
+
+    //     CHECK:   linalg.dot ins(%[[SV_A]], %[[SV_B]] : memref<?xf32, #[[$DYN_1D_MAP]]>, memref<?xf32, #[[$DYN_1D_MAP]]>) outs(%{{.*}} : memref<f32, #[[$DYN_0D_MAP]]>)
+    %8 = linalg.dot ins(%4, %7 : tensor<?xf32>, tensor<?xf32>) outs(%arg6 : tensor<f32>) -> tensor<f32>
+
+    //     CHECK:   call @some_use(%{{.*}}) : (memref<?xf32>) -> ()
+    call @some_use(%use) : (memref<?xf32>) -> ()
+
+    linalg.yield %8 : tensor<f32>
+    //     CHECK:   linalg.yield
+    // CHECK-NOT:   tensor
+  }
+
+  //     CHECK: return
+  // CHECK-NOT: tensor
+  return %1 : tensor<f32>
+}


        


More information about the Mlir-commits mailing list