[Mlir-commits] [mlir] 8e691e1 - [mlir][SCF][bufferize] Bufferize scf.if/execute_region terminators separately

Matthias Springer llvmlistbot at llvm.org
Mon Jun 27 04:22:36 PDT 2022


Author: Matthias Springer
Date: 2022-06-27T13:22:19+02:00
New Revision: 8e691e1f245ad4b275983749fd6a78120a0bf263

URL: https://github.com/llvm/llvm-project/commit/8e691e1f245ad4b275983749fd6a78120a0bf263
DIFF: https://github.com/llvm/llvm-project/commit/8e691e1f245ad4b275983749fd6a78120a0bf263.diff

LOG: [mlir][SCF][bufferize] Bufferize scf.if/execute_region terminators separately

This allows for better type inference during bufferization and is in preparation for supporting memory spaces.

Differential Revision: https://reviews.llvm.org/D128581

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index fb514a2f2b08..b9236b3573dd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -75,41 +75,17 @@ struct ExecuteRegionOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
-
-    // Compute new result types.
-    SmallVector<Type> newResultTypes;
-    for (Type type : executeRegionOp->getResultTypes()) {
-      if (auto tensorType = type.dyn_cast<TensorType>()) {
-        // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, options));
-      } else {
-        newResultTypes.push_back(type);
-      }
-    }
+    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
+           "only 1 block supported");
+    auto yieldOp =
+        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
+    TypeRange newResultTypes(yieldOp.getResults());
 
     // Create new op and move over region.
     auto newOp =
         rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
     newOp.getRegion().takeBody(executeRegionOp.getRegion());
 
-    // Update terminator.
-    assert(newOp.getRegion().getBlocks().size() == 1 &&
-           "only 1 block supported");
-    Block *newBlock = &newOp.getRegion().front();
-    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector<Value> newYieldValues;
-    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
-      Value val = it.value();
-      if (val.getType().isa<TensorType>()) {
-        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
-            yieldOp.getLoc(), newResultTypes[it.index()], val));
-      } else {
-        newYieldValues.push_back(val);
-      }
-    }
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
-
     // Update all uses of the old op.
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
@@ -184,64 +160,62 @@ struct IfOpInterface
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
     auto ifOp = cast<scf::IfOp>(op);
-
-    // Compute new types of the bufferized scf.if op.
-    SmallVector<Type> newTypes;
-    for (Type returnType : ifOp->getResultTypes()) {
-      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
-        // TODO: Infer the result type instead of computing it.
-        newTypes.push_back(getMemRefType(tensorType, options));
-      } else {
-        newTypes.push_back(returnType);
+    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
+    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
+
+    // Reconcile type mismatches between then/else branches by inserting memref
+    // casts.
+    SmallVector<Value> thenResults, elseResults;
+    bool insertedCast = false;
+    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
+      Value thenValue = thenYieldOp.getResults()[i];
+      Value elseValue = elseYieldOp.getResults()[i];
+      if (thenValue.getType() == elseValue.getType()) {
+        thenResults.push_back(thenValue);
+        elseResults.push_back(elseValue);
+        continue;
       }
+
+      // Type mismatch between then/else yield value. Cast both to a memref type
+      // with a fully dynamic layout map.
+      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
+      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
+      if (thenMemrefType.getMemorySpaceAsInt() !=
+          elseMemrefType.getMemorySpaceAsInt())
+        return op->emitError("inconsistent memory space on then/else branches");
+      rewriter.setInsertionPoint(thenYieldOp);
+      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
+          ifOp.getResultTypes()[i].cast<TensorType>(),
+          thenMemrefType.getMemorySpaceAsInt());
+      thenResults.push_back(rewriter.create<memref::CastOp>(
+          thenYieldOp.getLoc(), memrefType, thenValue));
+      rewriter.setInsertionPoint(elseYieldOp);
+      elseResults.push_back(rewriter.create<memref::CastOp>(
+          elseYieldOp.getLoc(), memrefType, elseValue));
+      insertedCast = true;
+    }
+
+    if (insertedCast) {
+      rewriter.setInsertionPoint(thenYieldOp);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
+      rewriter.setInsertionPoint(elseYieldOp);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
     }
 
     // Create new op.
+    rewriter.setInsertionPoint(ifOp);
+    ValueRange resultsValueRange(thenResults);
+    TypeRange newTypes(resultsValueRange);
     auto newIfOp =
         rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                    /*withElseRegion=*/true);
 
-    // Remove terminators.
-    if (!newIfOp.thenBlock()->empty()) {
-      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
-      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
-    }
-
     // Move over then/else blocks.
     rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
     rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
 
-    // Update scf.yield of new then-block.
-    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
-    rewriter.setInsertionPoint(thenYieldOp);
-    SmallVector<Value> thenYieldValues;
-    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
-      if (operand.get().getType().isa<TensorType>()) {
-        ensureToMemrefOpIsValid(operand.get(),
-                                newTypes[operand.getOperandNumber()]);
-        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
-            operand.get());
-        operand.set(toMemrefOp);
-      }
-    }
-
-    // Update scf.yield of new else-block.
-    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
-    rewriter.setInsertionPoint(elseYieldOp);
-    SmallVector<Value> elseYieldValues;
-    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
-      if (operand.get().getType().isa<TensorType>()) {
-        ensureToMemrefOpIsValid(operand.get(),
-                                newTypes[operand.getOperandNumber()]);
-        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
-            operand.get());
-        operand.set(toMemrefOp);
-      }
-    }
-
     // Replace op results.
     replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
 
@@ -869,6 +843,24 @@ struct YieldOpInterface
     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
             yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
+
+    // TODO: Bufferize scf.yield inside scf.while/scf.for here.
+    // (Currently bufferized together with scf.while/scf.for.)
+    if (isa<scf::ForOp, scf::WhileOp>(yieldOp->getParentOp()))
+      return success();
+
+    SmallVector<Value> newResults;
+    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+      Value value = it.value();
+      if (value.getType().isa<TensorType>()) {
+        Value buffer = getBuffer(rewriter, value, options);
+        newResults.push_back(buffer);
+      } else {
+        newResults.push_back(value);
+      }
+    }
+
+    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
index f6da6dc8a5ae..ab9360e45c65 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
@@ -8,6 +8,7 @@
 // CHECK-LABEL: func @buffer_not_deallocated(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
 func.func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32> {
+  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
   // CHECK: %[[r:.*]] = scf.if %{{.*}} {
   %r = scf.if %c -> tensor<?xf32> {
     // CHECK: %[[some_op:.*]] = "test.some_op"
@@ -20,7 +21,6 @@ func.func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32>
     scf.yield %0 : tensor<?xf32>
   } else {
   // CHECK: } else {
-    // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
     // CHECK: %[[cloned:.*]] = bufferization.clone %[[m]]
     // CHECK: scf.yield %[[cloned]]
     scf.yield %t : tensor<?xf32>
@@ -40,8 +40,8 @@ func.func @write_to_alloc_tensor_or_readonly_tensor(%arg0: tensor<i32>,
                                                     %cond: i1, %val: i32)
   -> tensor<i32>
 {
+  // CHECK: %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]]
   // CHECK: %[[r:.*]] = scf.if {{.*}} {
-  // CHECK:   %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]]
   // CHECK:   %[[clone:.*]] = bufferization.clone %[[arg0_m]]
   // CHECK:   scf.yield %[[clone]]
   // CHECK: } else {

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
index 0874912323c5..090e7c61239f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
@@ -206,9 +206,9 @@ func.func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
 //  CHECK-SCF-SAME:     %[[t1:.*]]: tensor<?xf32> {bufferization.writable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index
 func.func @simple_scf_if(%t1: tensor<?xf32> {bufferization.writable = true}, %c: i1, %pos: index, %f: f32)
     -> (tensor<?xf32>, index) {
+  // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
   // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32, #{{.*}}>) {
   %r1, %r2 = scf.if %c -> (tensor<?xf32>, index) {
-    // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
     // CHECK-SCF: scf.yield %[[t1_memref]]
     scf.yield %t1, %pos : tensor<?xf32>, index
   // CHECK-SCF: } else {

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index cc0357a055d7..5ea23b8e65b4 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -124,11 +124,10 @@ func.func @execute_region_with_conflict(
     scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
   }
 
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: %[[load:.*]] = memref.load %[[m1]]
   %3 = tensor.extract %t1[%idx] : tensor<?xf32>
 
-  // CHECK: return %{{.*}}, %[[casted]], %[[load]] : f32, memref<?xf32, #{{.*}}>, f32
+  // CHECK: return %{{.*}}, %[[alloc]], %[[load]] : f32, memref<?xf32>, f32
   return %0, %1, %3 : f32, tensor<?xf32>, f32
 }
 


        


More information about the Mlir-commits mailing list