[Mlir-commits] [mlir] 9d6096c - [mlir][SCF][bufferize][NFC] Move scf.if buffer type computation to getBufferType

Matthias Springer llvmlistbot at llvm.org
Tue Aug 30 07:49:43 PDT 2022


Author: Matthias Springer
Date: 2022-08-30T16:48:10+02:00
New Revision: 9d6096c56fcafbd882d5f688cbd8d62ec2f2ac71

URL: https://github.com/llvm/llvm-project/commit/9d6096c56fcafbd882d5f688cbd8d62ec2f2ac71
DIFF: https://github.com/llvm/llvm-project/commit/9d6096c56fcafbd882d5f688cbd8d62ec2f2ac71.diff

LOG: [mlir][SCF][bufferize][NFC] Move scf.if buffer type computation to getBufferType

Part of the functionality of `bufferize` is extracted into `getBufferType`. Also, bufferized `scf.yield` ops inside `scf.if` are now created with the correct bufferized type from the get-go.

Differential Revision: https://reviews.llvm.org/D132862

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b92b131616fbc..7e458a1f37626 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -163,52 +163,22 @@ struct IfOpInterface
                           const BufferizationOptions &options) const {
     OpBuilder::InsertionGuard g(rewriter);
     auto ifOp = cast<scf::IfOp>(op);
-    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
-    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
 
-    // Reconcile type mismatches between then/else branches by inserting memref
-    // casts.
-    SmallVector<Value> thenResults, elseResults;
-    bool insertedCast = false;
-    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
-      Value thenValue = thenYieldOp.getResults()[i];
-      Value elseValue = elseYieldOp.getResults()[i];
-      if (thenValue.getType() == elseValue.getType()) {
-        thenResults.push_back(thenValue);
-        elseResults.push_back(elseValue);
+    // Compute bufferized result types.
+    SmallVector<Type> newTypes;
+    for (Value result : ifOp.getResults()) {
+      if (!result.getType().isa<TensorType>()) {
+        newTypes.push_back(result.getType());
         continue;
       }
-
-      // Type mismatch between then/else yield value. Cast both to a memref type
-      // with a fully dynamic layout map.
-      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
-      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
-      if (thenMemrefType.getMemorySpaceAsInt() !=
-          elseMemrefType.getMemorySpaceAsInt())
-        return op->emitError("inconsistent memory space on then/else branches");
-      rewriter.setInsertionPoint(thenYieldOp);
-      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
-          ifOp.getResultTypes()[i].cast<TensorType>(),
-          thenMemrefType.getMemorySpaceAsInt());
-      thenResults.push_back(rewriter.create<memref::CastOp>(
-          thenYieldOp.getLoc(), memrefType, thenValue));
-      rewriter.setInsertionPoint(elseYieldOp);
-      elseResults.push_back(rewriter.create<memref::CastOp>(
-          elseYieldOp.getLoc(), memrefType, elseValue));
-      insertedCast = true;
-    }
-
-    if (insertedCast) {
-      rewriter.setInsertionPoint(thenYieldOp);
-      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
-      rewriter.setInsertionPoint(elseYieldOp);
-      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
+      auto bufferType = bufferization::getBufferType(result, options);
+      if (failed(bufferType))
+        return failure();
+      newTypes.push_back(*bufferType);
     }
 
     // Create new op.
     rewriter.setInsertionPoint(ifOp);
-    ValueRange resultsValueRange(thenResults);
-    TypeRange newTypes(resultsValueRange);
     auto newIfOp =
         rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                    /*withElseRegion=*/true);
@@ -223,6 +193,55 @@ struct IfOpInterface
     return success();
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto ifOp = cast<scf::IfOp>(op);
+    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
+    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
+    assert(value.getDefiningOp() == op && "invalid valid");
+
+    // Determine buffer types of the true/false branches.
+    auto opResult = value.cast<OpResult>();
+    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
+    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
+    BaseMemRefType thenBufferType, elseBufferType;
+    if (thenValue.getType().isa<BaseMemRefType>()) {
+      // True branch was already bufferized.
+      thenBufferType = thenValue.getType().cast<BaseMemRefType>();
+    } else {
+      auto maybeBufferType =
+          bufferization::getBufferType(thenValue, options, fixedTypes);
+      if (failed(maybeBufferType))
+        return failure();
+      thenBufferType = *maybeBufferType;
+    }
+    if (elseValue.getType().isa<BaseMemRefType>()) {
+      // False branch was already bufferized.
+      elseBufferType = elseValue.getType().cast<BaseMemRefType>();
+    } else {
+      auto maybeBufferType =
+          bufferization::getBufferType(elseValue, options, fixedTypes);
+      if (failed(maybeBufferType))
+        return failure();
+      elseBufferType = *maybeBufferType;
+    }
+
+    // Best case: Both branches have the exact same buffer type.
+    if (thenBufferType == elseBufferType)
+      return thenBufferType;
+
+    // Memory space mismatch.
+    if (thenBufferType.getMemorySpaceAsInt() !=
+        elseBufferType.getMemorySpaceAsInt())
+      return op->emitError("inconsistent memory space on then/else branches");
+
+    // Layout maps are different: Promote to fully dynamic layout map.
+    return getMemRefTypeWithFullyDynamicLayout(
+        opResult.getType().cast<TensorType>(),
+        thenBufferType.getMemorySpaceAsInt());
+  }
+
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                 const AnalysisState &state) const {
     // IfOp results are equivalent to their corresponding yield values if both
@@ -973,9 +992,12 @@ struct YieldOpInterface
         if (failed(maybeBuffer))
           return failure();
         Value buffer = *maybeBuffer;
-        if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+        // In case of scf::ForOp / scf::IfOp, we may have to cast the value
+        // before yielding it.
+        // TODO: Do the same for scf::WhileOp.
+        if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp())) {
           FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
-              forOp.getRegionIterArgs()[it.index()], options);
+              yieldOp->getParentOp()->getResult(it.index()), options);
           if (failed(resultType))
             return failure();
           buffer = castBuffer(rewriter, buffer, *resultType);

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
index 52338d0701be3..66a9807fb1086 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
@@ -5,9 +5,9 @@ func.func @inconsistent_memory_space_scf_if(%c: i1) -> tensor<10xf32> {
   // bufferized.
   %0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<10xf32>
   %1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<10xf32>
-  // expected-error @+2 {{inconsistent memory space on then/else branches}}
-  // expected-error @+1 {{failed to bufferize op}}
+  // expected-error @+1 {{inconsistent memory space on then/else branches}}
   %r = scf.if %c -> tensor<10xf32> {
+    // expected-error @+1 {{failed to bufferize op}}
     scf.yield %0 : tensor<10xf32>
   } else {
     scf.yield %1 : tensor<10xf32>


        


More information about the Mlir-commits mailing list