[Mlir-commits] [mlir] 1ce7040 - [mlir] Properly handle recursive bufferization for scf.for/scf.if
Author: Sean Silva
Date: 2020-10-28T14:16:56-07:00
New Revision: 1ce7040359a92a63b7b8cda2b6635627a4428399
URL: https://github.com/llvm/llvm-project/commit/1ce7040359a92a63b7b8cda2b6635627a4428399
DIFF: https://github.com/llvm/llvm-project/commit/1ce7040359a92a63b7b8cda2b6635627a4428399.diff
LOG: [mlir] Properly handle recursive bufferization for scf.for/scf.if
This fixes a subtle issue, described in the comment starting with
"Clone the op without the regions and inline the regions from the old op",
which prevented this conversion from working on non-trivial examples.
Differential Revision: https://reviews.llvm.org/D90203
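
As a sketch of the kind of non-trivial example this unblocks (mirroring the tests added below; `test.munge_tensor` is a test-only op), the body of the scf.if itself contains a tensor op that must be recursively legalized:

  %0 = scf.if %pred -> (tensor<f32>) {
    %1 = "test.munge_tensor"(%t) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %1 : tensor<f32>
  } else {
    %1 = "test.munge_tensor"(%t) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %1 : tensor<f32>
  }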
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
mlir/test/Dialect/SCF/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 7cf0dfabd917..57d605b3491f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -27,6 +27,21 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
OwningRewritePatternList patterns;
ConversionTarget target(*context);
+ // TODO: Move this to BufferizeTypeConverter's constructor.
+ //
+ // This doesn't currently play well with "finalizing" bufferizations (ones
+ // that expect all materializations to be gone). In particular, there seems
+ // to be at least a double-free in the dialect conversion framework when
+ // this materialization gets inserted and then folded away because it is
+ // marked as illegal.
+ typeConverter.addArgumentMaterialization(
+ [](OpBuilder &builder, RankedTensorType type, ValueRange inputs,
+ Location loc) -> Value {
+ assert(inputs.size() == 1);
+ assert(inputs[0].getType().isa<BaseMemRefType>());
+ return builder.create<TensorLoadOp>(loc, type, inputs[0]);
+ });
+
populateBufferizeMaterializationLegality(target);
populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
patterns, target);
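
The argument materialization registered above is what lets converted region entry blocks keep feeding ops that still expect tensors: when a block argument's type changes from tensor to memref, the converter bridges the gap with a tensor_load. A minimal sketch of the resulting IR, with illustrative names:

  // Entry block argument was tensor<f32>; after conversion it is memref<f32>.
  ^bb0(%arg0: memref<f32>):
    // Inserted materialization: recover a tensor value for unconverted uses.
    %t = tensor_load %arg0 : memref<f32>
    ...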
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 30a2272f39a2..4ad6d4116a76 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -31,16 +31,44 @@ class ConvertForOpTypes : public OpConversionPattern<ForOp> {
newResultTypes.push_back(newType);
}
- // Clone and replace.
- ForOp newOp = cast<ForOp>(rewriter.clone(*op.getOperation()));
+ // Clone the op without the regions and inline the regions from the old op.
+ //
+ // This is a little bit tricky. We have two concerns here:
+ //
+ // 1. We cannot update the op in place because the dialect conversion
+ // framework does not track type changes for ops updated in place, so it
+ // won't insert appropriate materializations on the changed result types.
+ // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
+ // clone the op.
+ //
+ // 2. We cannot simply call `op.clone()` to get the cloned op. Besides the
+ // inefficiency of recursively cloning the regions, there is a correctness
+ // issue: if we clone with the regions, then the dialect conversion
+ // framework thinks that we just inserted all the cloned child ops. But what
+ // we want is to "take" the child regions and let the dialect conversion
+ // framework continue recursively into ops inside those regions (which are
+ // already in its worklist; inlining them into the new op's regions doesn't
+ // remove the child ops from the worklist).
+ ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+ // Take the region from the old op and put it in the new op.
+ rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
+ newOp.getLoopBody().end());
+
+ // Now, update all the types.
+
+ // Convert the type of the entry block of the ForOp's body.
+ if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
+ *getTypeConverter()))) {
+ return rewriter.notifyMatchFailure(op, "could not convert body types");
+ }
+ // Change the clone to use the updated operands. We could have cloned with
+ // a BlockAndValueMapping, but this seems a bit more direct.
newOp.getOperation()->setOperands(operands);
+ // Update the result types to the new converted types.
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
std::get<0>(t).setType(std::get<1>(t));
- auto bodyArgs = newOp.getBody()->getArguments();
- for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes))
- std::get<0>(t).setType(std::get<1>(t));
- rewriter.replaceOp(op, newOp.getResults());
+ rewriter.replaceOp(op, newOp.getResults());
return success();
}
};
@@ -71,9 +99,15 @@ class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
newResultTypes.push_back(newType);
}
- // TODO: Write this with updateRootInPlace once the conversion infra
- // supports source materializations on ops updated in place.
- IfOp newOp = cast<IfOp>(rewriter.clone(*op.getOperation()));
+ // See comments in the ForOp pattern for why we clone without regions and
+ // then inline.
+ IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+ rewriter.inlineRegionBefore(op.thenRegion(), newOp.thenRegion(),
+ newOp.thenRegion().end());
+ rewriter.inlineRegionBefore(op.elseRegion(), newOp.elseRegion(),
+ newOp.elseRegion().end());
+
+ // Update the operands and types.
newOp.getOperation()->setOperands(operands);
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
std::get<0>(t).setType(std::get<1>(t));
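
The clone-without-regions recipe used in both patterns generalizes to any op with regions. A condensed sketch under the same API, assuming a hypothetical single-region op `MyOp` with a `getRegion()` accessor and already-computed `operands` and `newResultTypes` (illustrative, not a definitive pattern):

  // Clone only the op itself so the framework doesn't treat the child ops
  // as newly inserted.
  MyOp newOp = cast<MyOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
  // Move (not copy) the region; ops inside stay on the conversion worklist
  // and are legalized recursively.
  rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
                              newOp.getRegion().end());
  // Convert the block argument types, inserting argument materializations
  // where needed.
  if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
                                         *getTypeConverter())))
    return failure();
  // Point the clone at the converted operands, update result types, and
  // replace the old op.
  newOp.getOperation()->setOperands(operands);
  for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
    std::get<0>(t).setType(std::get<1>(t));
  rewriter.replaceOp(op, newOp.getResults());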
diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
index 01b353da83ed..7f42e15107d6 100644
--- a/mlir/test/Dialect/SCF/bufferize.mlir
+++ b/mlir/test/Dialect/SCF/bufferize.mlir
@@ -29,7 +29,9 @@ func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tens
// CHECK-SAME: %[[STEP:.*]]: index) -> tensor<f32> {
// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
// CHECK: %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
-// CHECK: scf.yield %[[ITER]] : memref<f32>
+// CHECK: %[[TENSOR_ITER:.*]] = tensor_load %[[ITER]] : memref<f32>
+// CHECK: %[[MEMREF_YIELDED:.*]] = tensor_to_memref %[[TENSOR_ITER]] : memref<f32>
+// CHECK: scf.yield %[[MEMREF_YIELDED]] : memref<f32>
// CHECK: }
// CHECK: %[[VAL_8:.*]] = tensor_load %[[VAL_9:.*]] : memref<f32>
// CHECK: return %[[VAL_8]] : tensor<f32>
@@ -40,3 +42,40 @@ func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f3
}
return %ret : tensor<f32>
}
+
+// Check that this converts at all.
+//
+// This example previously failed to legalize altogether.
+// CHECK-LABEL: func @if_correct_recursive_legalization_behavior
+// CHECK: "test.munge_tensor"
+func @if_correct_recursive_legalization_behavior(%pred: i1, %tensor: tensor<f32>) -> tensor<f32> {
+ %0 = scf.if %pred -> (tensor<f32>) {
+ %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
+ scf.yield %1: tensor<f32>
+ } else {
+ %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
+ scf.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @for_correct_recursive_legalization_behavior(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME: %[[INDEX:.*]]: index) -> tensor<f32> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK: %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[INDEX]] to %[[INDEX]] step %[[INDEX]] iter_args(%[[MEMREF_ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
+// CHECK: %[[TENSOR_ITER:.*]] = tensor_load %[[MEMREF_ITER]] : memref<f32>
+// CHECK: %[[TENSOR_MUNGED:.*]] = "test.munge_tensor"(%[[TENSOR_ITER]]) : (tensor<f32>) -> tensor<f32>
+// CHECK: %[[MEMREF_MUNGED:.*]] = tensor_to_memref %[[TENSOR_MUNGED]] : memref<f32>
+// CHECK: scf.yield %[[MEMREF_MUNGED]] : memref<f32>
+// CHECK: }
+// CHECK: %[[TENSOR:.*]] = tensor_load %[[RESULT:.*]] : memref<f32>
+// CHECK: return %[[TENSOR]] : tensor<f32>
+// CHECK: }
+func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: index) -> tensor<f32> {
+ %ret = scf.for %iv = %index to %index step %index iter_args(%iter = %arg0) -> tensor<f32> {
+ %0 = "test.munge_tensor"(%iter) : (tensor<f32>) -> (tensor<f32>)
+ scf.yield %0 : tensor<f32>
+ }
+ return %ret : tensor<f32>
+}
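
These tests run under the SCF bufferization pass via the file's existing RUN line, which is not shown in this diff but presumably looks like:

  // RUN: mlir-opt %s -scf-bufferize | FileCheck %s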