[Mlir-commits] [mlir] a592773 - [mlir][linalg][bufferize] Reimplementation of scf.if bufferization

Matthias Springer llvmlistbot at llvm.org
Wed Dec 15 01:41:08 PST 2021


Author: Matthias Springer
Date: 2021-12-15T18:40:54+09:00
New Revision: a5927737daeb1d1a6e954fbac16f4d570c3d7496

URL: https://github.com/llvm/llvm-project/commit/a5927737daeb1d1a6e954fbac16f4d570c3d7496
DIFF: https://github.com/llvm/llvm-project/commit/a5927737daeb1d1a6e954fbac16f4d570c3d7496.diff

LOG: [mlir][linalg][bufferize] Reimplementation of scf.if bufferization

Instead of modifying the existing scf.if op, create a new op with memref OpOperands/OpResults and delete the old op.

New allocations and other memrefs can now be yielded from the op. This functionality is disabled by default: the AssertDestinationPassingStyle post-analysis step rejects such IR unless `allow-return-memref` is set.

Differential Revision: https://reviews.llvm.org/D115491

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index d6c36f3b98a53..028be806236c4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -461,8 +461,6 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
   // Certain buffers are not writeable:
   //   1. A function bbArg that is not inplaceable or
   //   2. A constant op.
-  assert(!aliasesNonWritableBuffer(opResult, aliasInfo, state) &&
-         "expected that opResult does not alias non-writable buffer");
   bool nonWritable =
       aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
   if (!nonWritable)

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index ec9a315a93718..edded005a1ee4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -131,27 +131,74 @@ struct IfOpInterface
                           BufferizationState &state) const {
     auto ifOp = cast<scf::IfOp>(op);
 
-    // Bufferize then/else blocks.
-    if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state)))
-      return failure();
-    if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state)))
-      return failure();
+    // Use IRRewriter instead of OpBuilder because it has additional helper
+    // functions.
+    IRRewriter rewriter(op->getContext());
+    rewriter.setInsertionPoint(ifOp);
+
+    // Compute new types of the bufferized scf.if op.
+    SmallVector<Type> newTypes;
+    for (Type returnType : ifOp->getResultTypes()) {
+      if (returnType.isa<TensorType>()) {
+        assert(returnType.isa<RankedTensorType>() &&
+               "unsupported unranked tensor");
+        newTypes.push_back(
+            getDynamicMemRefType(returnType.cast<RankedTensorType>()));
+      } else {
+        newTypes.push_back(returnType);
+      }
+    }
 
-    for (OpResult opResult : ifOp->getResults()) {
-      if (!opResult.getType().isa<TensorType>())
-        continue;
-      // TODO: Atm we bail on unranked TensorType because we don't know how to
-      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-      assert(opResult.getType().isa<RankedTensorType>() &&
-             "unsupported unranked tensor");
+    // Create new op.
+    auto newIfOp =
+        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.condition(),
+                                   /*withElseRegion=*/true);
 
-      Value resultBuffer = state.getResultBuffer(opResult);
-      if (!resultBuffer)
-        return failure();
+    // Remove terminators.
+    if (!newIfOp.thenBlock()->empty()) {
+      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
+      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
+    }
 
-      state.mapBuffer(opResult, resultBuffer);
+    // Move over then/else blocks.
+    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
+    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
+
+    // Update scf.yield of new then-block.
+    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
+    rewriter.setInsertionPoint(thenYieldOp);
+    SmallVector<Value> thenYieldValues;
+    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
+      if (operand.get().getType().isa<TensorType>()) {
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
+            operand.get());
+        operand.set(toMemrefOp);
+      }
     }
 
+    // Update scf.yield of new else-block.
+    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
+    rewriter.setInsertionPoint(elseYieldOp);
+    SmallVector<Value> elseYieldValues;
+    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
+      if (operand.get().getType().isa<TensorType>()) {
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
+            operand.get());
+        operand.set(toMemrefOp);
+      }
+    }
+
+    // Replace op results.
+    state.replaceOp(op, newIfOp->getResults());
+
+    // Bufferize then/else blocks.
+    if (failed(comprehensive_bufferize::bufferize(newIfOp.thenBlock(), state)))
+      return failure();
+    if (failed(comprehensive_bufferize::bufferize(newIfOp.elseBlock(), state)))
+      return failure();
+
     return success();
   }
 
@@ -293,33 +340,65 @@ struct ForOpInterface
   }
 };
 
+// TODO: Evolve toward matching ReturnLike ops. Check for aliasing values that
+// do not bufferize inplace. (Requires a few more changes for ConstantOp,
+// InitTensorOp, CallOp.)
 LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
     AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state,
                                        BufferizationAliasInfo &aliasInfo,
                                        SmallVector<Operation *> &newOps) {
   LogicalResult status = success();
   op->walk([&](scf::YieldOp yieldOp) {
-    auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
-    if (!forOp)
-      return WalkResult::advance();
-
-    for (OpOperand &operand : yieldOp->getOpOperands()) {
-      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-      if (!tensorType)
-        continue;
-
-      OpOperand &forOperand = forOp.getOpOperandForResult(
-          forOp->getResult(operand.getOperandNumber()));
-      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
-        // TODO: this could get resolved with copies but it can also turn into
-        // swaps so we need to be careful about order of copies.
-        status =
-            yieldOp->emitError()
-            << "Yield operand #" << operand.getOperandNumber()
-            << " does not bufferize to an equivalent buffer to the matching"
-            << " enclosing scf::for operand";
-        return WalkResult::interrupt();
+    if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+        if (!tensorType)
+          continue;
+
+        OpOperand &forOperand = forOp.getOpOperandForResult(
+            forOp->getResult(operand.getOperandNumber()));
+        auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+        if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
+          // TODO: this could get resolved with copies but it can also turn into
+          // swaps so we need to be careful about order of copies.
+          status =
+              yieldOp->emitError()
+              << "Yield operand #" << operand.getOperandNumber()
+              << " does not bufferize to an equivalent buffer to the matching"
+              << " enclosing scf::for operand";
+          return WalkResult::interrupt();
+        }
+      }
+    }
+
+    if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
+      // IfOps are in destination passing style if all yielded tensors are
+      // a value or equivalent to a value that is defined outside of the IfOp.
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+        if (!tensorType)
+          continue;
+
+        bool foundOutsideEquivalent = false;
+        aliasInfo.applyOnEquivalenceClass(operand.get(), [&](Value value) {
+          Operation *valueOp = value.getDefiningOp();
+          if (value.isa<BlockArgument>())
+            valueOp = value.cast<BlockArgument>().getOwner()->getParentOp();
+
+          bool inThenBlock = ifOp.thenBlock()->findAncestorOpInBlock(*valueOp);
+          bool inElseBlock = ifOp.elseBlock()->findAncestorOpInBlock(*valueOp);
+
+          if (!inThenBlock && !inElseBlock)
+            foundOutsideEquivalent = true;
+        });
+
+        if (!foundOutsideEquivalent) {
+          status = yieldOp->emitError()
+                   << "Yield operand #" << operand.getOperandNumber()
+                   << " does not bufferize to a buffer that is equivalent to a"
+                   << " buffer defined outside of the scf::if op";
+          return WalkResult::interrupt();
+        }
       }
     }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 2ce198b86bcec..5255bd2d7b000 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -97,7 +97,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   // TODO: Find a way to enable this step automatically when bufferizing tensor
   // dialect ops.
   options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
-  options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+  if (!allowReturnMemref)
+    options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index f2fa7ce3e4bf4..8705cd1f1b1e0 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=test-analysis-only -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref" -split-input-file | FileCheck %s
 
 // Run fuzzer with different seeds.
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=91" -split-input-file -o /dev/null
 
 //===----------------------------------------------------------------------===//
 // Simple cases

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index edeb0c07da0f2..02431d9175a9c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -38,12 +38,12 @@ func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
 func @scf_if_not_equivalent(
     %cond: i1, %t1: tensor<?xf32> {linalg.inplaceable = true},
     %idx: index) -> tensor<?xf32> {
-  // expected-error @+1 {{result buffer is ambiguous}}
   %r = scf.if %cond -> (tensor<?xf32>) {
     scf.yield %t1 : tensor<?xf32>
   } else {
     // This buffer aliases, but is not equivalent.
     %t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor<?xf32> to tensor<?xf32>
+    // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is equivalent to a buffer defined outside of the scf::if op}}
     scf.yield %t2 : tensor<?xf32>
   }
   return %r : tensor<?xf32>
@@ -127,9 +127,9 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
 
 // -----
 
+// expected-error @+1 {{memref return type is unsupported}}
 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
-  // expected-error @+1 {{result buffer is ambiguous}}
   %r = scf.if %b -> (tensor<4xf32>) {
     scf.yield %A : tensor<4xf32>
   } else {

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index ec2c33f3c949a..de8717f22d72d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -194,3 +194,28 @@ func @simple_scf_for(
   // CHECK-SCF: return %[[scf_for_tensor]]
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-SCF-LABEL: func @simple_scf_if(
+//  CHECK-SCF-SAME:     %[[t1:.*]]: tensor<?xf32> {linalg.inplaceable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index
+func @simple_scf_if(%t1: tensor<?xf32> {linalg.inplaceable = true}, %c: i1, %pos: index, %f: f32)
+    -> (tensor<?xf32>, index) {
+  // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32, #{{.*}}>) {
+  %r1, %r2 = scf.if %c -> (tensor<?xf32>, index) {
+    // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+    // CHECK-SCF: scf.yield %[[t1_memref]]
+    scf.yield %t1, %pos : tensor<?xf32>, index
+  // CHECK-SCF: } else {
+  } else {
+    // CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[t1]][{{.*}}]
+    // CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]]
+    %1 = tensor.insert %f into %t1[%pos] : tensor<?xf32>
+    // CHECK-SCF: scf.yield %[[insert_memref]]
+    scf.yield %1, %pos : tensor<?xf32>, index
+  }
+
+  // CHECK-SCF: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK-SCF: return %[[r_tensor]], %[[pos]]
+  return %r1, %r2 : tensor<?xf32>, index
+}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 970a9b54b2883..1094c21ed0537 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -921,6 +921,22 @@ func @scf_if_inside_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
 
 // -----
 
+// CHECK-LABEL: func @scf_if_non_equiv_yields(
+//  CHECK-SAME:     %[[cond:.*]]: i1, %[[A:.*]]: memref<{{.*}}>, %[[B:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
+func @scf_if_non_equiv_yields(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
+{
+  // CHECK: %[[r:.*]] = select %[[cond]], %[[A]], %[[B]]
+  %r = scf.if %b -> (tensor<4xf32>) {
+    scf.yield %A : tensor<4xf32>
+  } else {
+    scf.yield %B : tensor<4xf32>
+  }
+  // CHECK: return %[[r]]
+  return %r: tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_op
 //  CHECK-SAME:     %[[t1:.*]]: memref<?xf32, {{.*}}>, %[[s:.*]]: f32, %[[i:.*]]: index
 func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index d5140eaf91652..b69e8a3738aea 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -101,6 +101,8 @@ void TestComprehensiveFunctionBufferize::runOnFunction() {
   // TODO: Find a way to enable this step automatically when bufferizing
   // tensor dialect ops.
   options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  if (!allowReturnMemref)
+    options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 
   options.allowReturnMemref = allowReturnMemref;
   options.allowUnknownOps = allowUnknownOps;


        


More information about the Mlir-commits mailing list