[Mlir-commits] [mlir] a592773 - [mlir][linalg][bufferize] Reimplementation of scf.if bufferization
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 15 01:41:08 PST 2021
Author: Matthias Springer
Date: 2021-12-15T18:40:54+09:00
New Revision: a5927737daeb1d1a6e954fbac16f4d570c3d7496
URL: https://github.com/llvm/llvm-project/commit/a5927737daeb1d1a6e954fbac16f4d570c3d7496
DIFF: https://github.com/llvm/llvm-project/commit/a5927737daeb1d1a6e954fbac16f4d570c3d7496.diff
LOG: [mlir][linalg][bufferize] Reimplementation of scf.if bufferization
Instead of modifying the existing scf.if op, create a new op with memref OpOperands/OpResults and delete the old op.
New allocations / other memrefs can now be yielded from the op. This functionality is deactivated by default: the AssertDestinationPassingStyle post-analysis step rejects such IR unless the allow-return-memref option is set.
Differential Revision: https://reviews.llvm.org/D115491
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index d6c36f3b98a53..028be806236c4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -461,8 +461,6 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
// Certain buffers are not writeable:
// 1. A function bbArg that is not inplaceable or
// 2. A constant op.
- assert(!aliasesNonWritableBuffer(opResult, aliasInfo, state) &&
- "expected that opResult does not alias non-writable buffer");
bool nonWritable =
aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
if (!nonWritable)
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index ec9a315a93718..edded005a1ee4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -131,27 +131,74 @@ struct IfOpInterface
BufferizationState &state) const {
auto ifOp = cast<scf::IfOp>(op);
- // Bufferize then/else blocks.
- if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state)))
- return failure();
- if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state)))
- return failure();
+ // Use IRRewriter instead of OpBuilder because it has additional helper
+ // functions.
+ IRRewriter rewriter(op->getContext());
+ rewriter.setInsertionPoint(ifOp);
+
+ // Compute new types of the bufferized scf.if op.
+ SmallVector<Type> newTypes;
+ for (Type returnType : ifOp->getResultTypes()) {
+ if (returnType.isa<TensorType>()) {
+ assert(returnType.isa<RankedTensorType>() &&
+ "unsupported unranked tensor");
+ newTypes.push_back(
+ getDynamicMemRefType(returnType.cast<RankedTensorType>()));
+ } else {
+ newTypes.push_back(returnType);
+ }
+ }
- for (OpResult opResult : ifOp->getResults()) {
- if (!opResult.getType().isa<TensorType>())
- continue;
- // TODO: Atm we bail on unranked TensorType because we don't know how to
- // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
- assert(opResult.getType().isa<RankedTensorType>() &&
- "unsupported unranked tensor");
+ // Create new op.
+ auto newIfOp =
+ rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.condition(),
+ /*withElseRegion=*/true);
- Value resultBuffer = state.getResultBuffer(opResult);
- if (!resultBuffer)
- return failure();
+ // Remove terminators.
+ if (!newIfOp.thenBlock()->empty()) {
+ rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
+ rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
+ }
- state.mapBuffer(opResult, resultBuffer);
+ // Move over then/else blocks.
+ rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
+ rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
+
+ // Update scf.yield of new then-block.
+ auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
+ rewriter.setInsertionPoint(thenYieldOp);
+ SmallVector<Value> thenYieldValues;
+ for (OpOperand &operand : thenYieldOp->getOpOperands()) {
+ if (operand.get().getType().isa<TensorType>()) {
+ Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+ operand.get().getLoc(), newTypes[operand.getOperandNumber()],
+ operand.get());
+ operand.set(toMemrefOp);
+ }
}
+ // Update scf.yield of new else-block.
+ auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
+ rewriter.setInsertionPoint(elseYieldOp);
+ SmallVector<Value> elseYieldValues;
+ for (OpOperand &operand : elseYieldOp->getOpOperands()) {
+ if (operand.get().getType().isa<TensorType>()) {
+ Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+ operand.get().getLoc(), newTypes[operand.getOperandNumber()],
+ operand.get());
+ operand.set(toMemrefOp);
+ }
+ }
+
+ // Replace op results.
+ state.replaceOp(op, newIfOp->getResults());
+
+ // Bufferize then/else blocks.
+ if (failed(comprehensive_bufferize::bufferize(newIfOp.thenBlock(), state)))
+ return failure();
+ if (failed(comprehensive_bufferize::bufferize(newIfOp.elseBlock(), state)))
+ return failure();
+
return success();
}
@@ -293,33 +340,65 @@ struct ForOpInterface
}
};
+// TODO: Evolve toward matching ReturnLike ops. Check for aliasing values that
+// do not bufferize inplace. (Requires a few more changes for ConstantOp,
+// InitTensorOp, CallOp.)
LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
LogicalResult status = success();
op->walk([&](scf::YieldOp yieldOp) {
- auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
- if (!forOp)
- return WalkResult::advance();
-
- for (OpOperand &operand : yieldOp->getOpOperands()) {
- auto tensorType = operand.get().getType().dyn_cast<TensorType>();
- if (!tensorType)
- continue;
-
- OpOperand &forOperand = forOp.getOpOperandForResult(
- forOp->getResult(operand.getOperandNumber()));
- auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
- if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
- // TODO: this could get resolved with copies but it can also turn into
- // swaps so we need to be careful about order of copies.
- status =
- yieldOp->emitError()
- << "Yield operand #" << operand.getOperandNumber()
- << " does not bufferize to an equivalent buffer to the matching"
- << " enclosing scf::for operand";
- return WalkResult::interrupt();
+ if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+ for (OpOperand &operand : yieldOp->getOpOperands()) {
+ auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+ if (!tensorType)
+ continue;
+
+ OpOperand &forOperand = forOp.getOpOperandForResult(
+ forOp->getResult(operand.getOperandNumber()));
+ auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+ if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
+ // TODO: this could get resolved with copies but it can also turn into
+ // swaps so we need to be careful about order of copies.
+ status =
+ yieldOp->emitError()
+ << "Yield operand #" << operand.getOperandNumber()
+ << " does not bufferize to an equivalent buffer to the matching"
+ << " enclosing scf::for operand";
+ return WalkResult::interrupt();
+ }
+ }
+ }
+
+ if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
+ // IfOps are in destination passing style if all yielded tensors are
+ // a value or equivalent to a value that is defined outside of the IfOp.
+ for (OpOperand &operand : yieldOp->getOpOperands()) {
+ auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+ if (!tensorType)
+ continue;
+
+ bool foundOutsideEquivalent = false;
+ aliasInfo.applyOnEquivalenceClass(operand.get(), [&](Value value) {
+ Operation *valueOp = value.getDefiningOp();
+ if (value.isa<BlockArgument>())
+ valueOp = value.cast<BlockArgument>().getOwner()->getParentOp();
+
+ bool inThenBlock = ifOp.thenBlock()->findAncestorOpInBlock(*valueOp);
+ bool inElseBlock = ifOp.elseBlock()->findAncestorOpInBlock(*valueOp);
+
+ if (!inThenBlock && !inElseBlock)
+ foundOutsideEquivalent = true;
+ });
+
+ if (!foundOutsideEquivalent) {
+ status = yieldOp->emitError()
+ << "Yield operand #" << operand.getOperandNumber()
+ << " does not bufferize to a buffer that is equivalent to a"
+ << " buffer defined outside of the scf::if op";
+ return WalkResult::interrupt();
+ }
}
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 2ce198b86bcec..5255bd2d7b000 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -97,7 +97,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
// TODO: Find a way to enable this step automatically when bufferizing tensor
// dialect ops.
options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
- options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+ if (!allowReturnMemref)
+ options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index f2fa7ce3e4bf4..8705cd1f1b1e0 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=test-analysis-only -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref" -split-input-file | FileCheck %s
// Run fuzzer with different seeds.
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=91" -split-input-file -o /dev/null
//===----------------------------------------------------------------------===//
// Simple cases
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index edeb0c07da0f2..02431d9175a9c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -38,12 +38,12 @@ func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
func @scf_if_not_equivalent(
%cond: i1, %t1: tensor<?xf32> {linalg.inplaceable = true},
%idx: index) -> tensor<?xf32> {
- // expected-error @+1 {{result buffer is ambiguous}}
%r = scf.if %cond -> (tensor<?xf32>) {
scf.yield %t1 : tensor<?xf32>
} else {
// This buffer aliases, but is not equivalent.
%t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor<?xf32> to tensor<?xf32>
+ // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is equivalent to a buffer defined outside of the scf::if op}}
scf.yield %t2 : tensor<?xf32>
}
return %r : tensor<?xf32>
@@ -127,9 +127,9 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
// -----
+// expected-error @+1 {{memref return type is unsupported}}
func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
{
- // expected-error @+1 {{result buffer is ambiguous}}
%r = scf.if %b -> (tensor<4xf32>) {
scf.yield %A : tensor<4xf32>
} else {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index ec2c33f3c949a..de8717f22d72d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -194,3 +194,28 @@ func @simple_scf_for(
// CHECK-SCF: return %[[scf_for_tensor]]
return %0 : tensor<?xf32>
}
+
+// -----
+
+// CHECK-SCF-LABEL: func @simple_scf_if(
+// CHECK-SCF-SAME: %[[t1:.*]]: tensor<?xf32> {linalg.inplaceable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index
+func @simple_scf_if(%t1: tensor<?xf32> {linalg.inplaceable = true}, %c: i1, %pos: index, %f: f32)
+ -> (tensor<?xf32>, index) {
+ // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32, #{{.*}}>) {
+ %r1, %r2 = scf.if %c -> (tensor<?xf32>, index) {
+ // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+ // CHECK-SCF: scf.yield %[[t1_memref]]
+ scf.yield %t1, %pos : tensor<?xf32>, index
+ // CHECK-SCF: } else {
+ } else {
+ // CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[t1]][{{.*}}]
+ // CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]]
+ %1 = tensor.insert %f into %t1[%pos] : tensor<?xf32>
+ // CHECK-SCF: scf.yield %[[insert_memref]]
+ scf.yield %1, %pos : tensor<?xf32>, index
+ }
+
+ // CHECK-SCF: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+ // CHECK-SCF: return %[[r_tensor]], %[[pos]]
+ return %r1, %r2 : tensor<?xf32>, index
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 970a9b54b2883..1094c21ed0537 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -921,6 +921,22 @@ func @scf_if_inside_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
// -----
+// CHECK-LABEL: func @scf_if_non_equiv_yields(
+// CHECK-SAME: %[[cond:.*]]: i1, %[[A:.*]]: memref<{{.*}}>, %[[B:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
+func @scf_if_non_equiv_yields(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
+{
+ // CHECK: %[[r:.*]] = select %[[cond]], %[[A]], %[[B]]
+ %r = scf.if %b -> (tensor<4xf32>) {
+ scf.yield %A : tensor<4xf32>
+ } else {
+ scf.yield %B : tensor<4xf32>
+ }
+ // CHECK: return %[[r]]
+ return %r: tensor<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @insert_op
// CHECK-SAME: %[[t1:.*]]: memref<?xf32, {{.*}}>, %[[s:.*]]: f32, %[[i:.*]]: index
func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index d5140eaf91652..b69e8a3738aea 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -101,6 +101,8 @@ void TestComprehensiveFunctionBufferize::runOnFunction() {
// TODO: Find a way to enable this step automatically when bufferizing
// tensor dialect ops.
options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+ if (!allowReturnMemref)
+ options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
options.allowReturnMemref = allowReturnMemref;
options.allowUnknownOps = allowUnknownOps;
More information about the Mlir-commits mailing list