[Mlir-commits] [mlir] b83c67d - [mlir][linalg][bufferize] Support scf.execute_region bufferization
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 19 01:17:26 PST 2022
Author: Matthias Springer
Date: 2022-01-19T18:17:09+09:00
New Revision: b83c67d978942320b1fb7d814ae43c98f95ebe44
URL: https://github.com/llvm/llvm-project/commit/b83c67d978942320b1fb7d814ae43c98f95ebe44
DIFF: https://github.com/llvm/llvm-project/commit/b83c67d978942320b1fb7d814ae43c98f95ebe44.diff
LOG: [mlir][linalg][bufferize] Support scf.execute_region bufferization
This op is needed for unit testing in a subsequent revision. (This is the first op whose block yields values that are equivalent to the op's own results.)
Note: Bufferization of scf.execute_region ops with multiple blocks is not yet supported.
Differential Revision: https://reviews.llvm.org/D117424
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index be55eca51986d..2fd89de1e4aa0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -44,6 +44,7 @@ struct ExecuteRegionOpInterface
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
size_t resultNum = std::distance(op->getOpResults().begin(),
llvm::find(op->getOpResults(), opResult));
+ // TODO: Support multiple blocks.
assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
"expected exactly 1 block");
auto yieldOp = dyn_cast<scf::YieldOp>(
@@ -66,13 +67,59 @@ struct ExecuteRegionOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
- // TODO: Add bufferization support when needed. scf.execute_region should be
- // bufferized similar to scf.if.
- bool hasTensorReturnType = any_of(
- op->getResultTypes(), [](Type t) { return t.isa<TensorType>(); });
- if (hasTensorReturnType)
- return op->emitError(
- "scf.execute_region with tensor result not supported");
+ auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
+
+ // Compute new result types.
+ SmallVector<Type> newResultTypes;
+ for (Type type : executeRegionOp->getResultTypes()) {
+ if (auto rankedTensorType = type.dyn_cast<RankedTensorType>()) {
+ newResultTypes.push_back(getDynamicMemRefType(rankedTensorType));
+ } else if (auto tensorType = type.dyn_cast<TensorType>()) {
+ newResultTypes.push_back(
+ getUnrankedMemRefType(tensorType.getElementType()));
+ } else {
+ newResultTypes.push_back(type);
+ }
+ }
+
+ // Create new op and move over region.
+ auto newOp =
+ rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
+ newOp.getRegion().takeBody(executeRegionOp.getRegion());
+
+ // Update terminator.
+ assert(newOp.getRegion().getBlocks().size() == 1 &&
+ "only 1 block supported");
+ Block *newBlock = &newOp.getRegion().front();
+ auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
+ rewriter.setInsertionPoint(yieldOp);
+ SmallVector<Value> newYieldValues;
+ for (auto it : llvm::enumerate(yieldOp.getResults())) {
+ Value val = it.value();
+ if (val.getType().isa<TensorType>()) {
+ newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
+ yieldOp.getLoc(), newResultTypes[it.index()], val));
+ } else {
+ newYieldValues.push_back(val);
+ }
+ }
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
+
+ // Update all uses of the old op.
+ rewriter.setInsertionPointAfter(newOp);
+ SmallVector<Value> newResults;
+ for (auto it : llvm::enumerate(executeRegionOp->getResultTypes())) {
+ if (it.value().isa<TensorType>()) {
+ newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
+ executeRegionOp.getLoc(), newOp->getResult(it.index())));
+ } else {
+ newResults.push_back(newOp->getResult(it.index()));
+ }
+ }
+
+ // Replace old op.
+ rewriter.replaceOp(executeRegionOp, newResults);
+
return success();
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 5cf7612b58a68..f9a809ea15784 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -159,8 +159,8 @@ func @mini_test_case1() -> tensor<10x20xf32> {
// -----
+// expected-error @+1 {{memref return type is unsupported}}
func @main() -> tensor<4xi32> {
- // expected-error @+1 {{scf.execute_region with tensor result not supported}}
%r = scf.execute_region -> tensor<4xi32> {
%A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
scf.yield %A: tensor<4xi32>
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index b37c3066a0d65..e6c521ddfdc24 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -446,6 +446,59 @@ func @main() {
// -----
+// CHECK-LABEL: func @execute_region_test(
+// CHECK-SAME: %[[m1:.*]]: memref<?xf32
+func @execute_region_test(%t1 : tensor<?xf32> {linalg.inplaceable = "true"})
+ -> (f32, tensor<?xf32>, f32)
+{
+ %f1 = arith.constant 0.0 : f32
+ %f2 = arith.constant 1.0 : f32
+ %idx = arith.constant 7 : index
+
+ // scf.execute_region is canonicalized away after bufferization. So just the
+ // memref.store is left over.
+
+ // CHECK: memref.store %{{.*}}, %[[m1]][%{{.*}}]
+ %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
+ %t2 = tensor.insert %f2 into %t1[%idx] : tensor<?xf32>
+ scf.yield %f1, %t2, %f2 : f32, tensor<?xf32>, f32
+ }
+
+ // CHECK: return %{{.*}}, %{{.*}} : f32, f32
+ return %0, %1, %2 : f32, tensor<?xf32>, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @execute_region_with_conflict(
+// CHECK-SAME: %[[m1:.*]]: memref<?xf32
+func @execute_region_with_conflict(%t1 : tensor<?xf32> {linalg.inplaceable = "true"})
+ -> (f32, tensor<?xf32>, f32)
+{
+ %f1 = arith.constant 0.0 : f32
+ %idx = arith.constant 7 : index
+
+ // scf.execute_region is canonicalized away after bufferization. So just the
+ // memref.store is left over.
+
+ // CHECK: %[[alloc:.*]] = memref.alloc
+ // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK: memref.copy %[[m1]], %[[alloc]]
+ // CHECK: memref.store %{{.*}}, %[[alloc]][%{{.*}}]
+ %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
+ %t2 = tensor.insert %f1 into %t1[%idx] : tensor<?xf32>
+ scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
+ }
+
+ // CHECK: %[[load:.*]] = memref.load %[[m1]]
+ %3 = tensor.extract %t1[%idx] : tensor<?xf32>
+
+ // CHECK: return %{{.*}}, %[[casted]], %[[load]] : f32, memref<?xf32, #{{.*}}>, f32
+ return %0, %1, %3 : f32, tensor<?xf32>, f32
+}
+
+// -----
+
// CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// CHECK: func private @some_external_func(memref<?xf32, #[[$DYN_1D_MAP]]>)
More information about the Mlir-commits
mailing list