[Mlir-commits] [mlir] 1ce7040 - [mlir] Properly handle recursive bufferization for scf.for/scf.if

Sean Silva llvmlistbot at llvm.org
Wed Oct 28 14:21:07 PDT 2020


Author: Sean Silva
Date: 2020-10-28T14:16:56-07:00
New Revision: 1ce7040359a92a63b7b8cda2b6635627a4428399

URL: https://github.com/llvm/llvm-project/commit/1ce7040359a92a63b7b8cda2b6635627a4428399
DIFF: https://github.com/llvm/llvm-project/commit/1ce7040359a92a63b7b8cda2b6635627a4428399.diff

LOG: [mlir] Properly handle recursive bufferization for scf.for/scf.if

This fixes a subtle issue that prevented this conversion from working on
non-trivial examples. The issue is described in the comment starting with
"Clone the op without the regions and inline the regions from the old op".

Differential Revision: https://reviews.llvm.org/D90203
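
For reference, the core of the new ForOp pattern boils down to the sequence
below (a condensed sketch of the code in the patch that follows; error
handling context and the analogous IfOp changes are omitted):

    // Clone the op *without* its regions, then move (not copy) the old body
    // into the clone, so the child ops stay on the conversion worklist.
    ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                newOp.getLoopBody().end());
    // Convert the entry block argument types of the inlined body.
    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
                                           *getTypeConverter())))
      return rewriter.notifyMatchFailure(op, "could not convert body types");
    // Point the clone at the already-converted operands and result types.
    newOp.getOperation()->setOperands(operands);
    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
      std::get<0>(t).setType(std::get<1>(t));
    rewriter.replaceOp(op, newOp.getResults());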

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
    mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
    mlir/test/Dialect/SCF/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 7cf0dfabd917..57d605b3491f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -27,6 +27,21 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
     OwningRewritePatternList patterns;
     ConversionTarget target(*context);
 
+    // TODO: Move this to BufferizeTypeConverter's constructor.
+    //
+    // This doesn't currently play well with "finalizing" bufferizations (ones
+    // that expect all materializations to be gone). In particular, there seems
+    // to be at least a double-free in the dialect conversion framework when
+    // this materialization is inserted and then folded away because it is
+    // marked as illegal.
+    typeConverter.addArgumentMaterialization(
+        [](OpBuilder &builder, RankedTensorType type, ValueRange inputs,
+           Location loc) -> Value {
+          assert(inputs.size() == 1);
+          assert(inputs[0].getType().isa<BaseMemRefType>());
+          return builder.create<TensorLoadOp>(loc, type, inputs[0]);
+        });
+
     populateBufferizeMaterializationLegality(target);
     populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
                                                     patterns, target);

diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 30a2272f39a2..4ad6d4116a76 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -31,16 +31,44 @@ class ConvertForOpTypes : public OpConversionPattern<ForOp> {
       newResultTypes.push_back(newType);
     }
 
-    // Clone and replace.
-    ForOp newOp = cast<ForOp>(rewriter.clone(*op.getOperation()));
+    // Clone the op without the regions and inline the regions from the old op.
+    //
+    // This is a little bit tricky. We have two concerns here:
+    //
+    // 1. We cannot update the op in place because the dialect conversion
+    // framework does not track type changes for ops updated in place, so it
+    // won't insert appropriate materializations on the changed result types.
+    // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
+    // clone the op.
+    //
+    // 2. We cannot simply call `op.clone()` to get the cloned op. Besides the
+    // inefficiency of recursively cloning the regions, there is a correctness
+    // issue: if we clone with the regions, then the dialect conversion
+    // framework thinks that we just inserted all the cloned child ops. But what
+    // we want is to "take" the child regions and let the dialect conversion
+    // framework continue recursively into ops inside those regions (which are
+    // already in its worklist; inlining them into the new op's regions doesn't
+    // remove the child ops from the worklist).
+    ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+    // Take the region from the old op and put it in the new op.
+    rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
+                                newOp.getLoopBody().end());
+
+    // Now, update all the types.
+
+    // Convert the type of the entry block of the ForOp's body.
+    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
+                                           *getTypeConverter()))) {
+      return rewriter.notifyMatchFailure(op, "could not convert body types");
+    }
+    // Change the clone to use the updated operands. We could have cloned with
+    // a BlockAndValueMapping, but this seems a bit more direct.
     newOp.getOperation()->setOperands(operands);
+    // Update the result types to the new converted types.
     for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
       std::get<0>(t).setType(std::get<1>(t));
-    auto bodyArgs = newOp.getBody()->getArguments();
-    for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes))
-      std::get<0>(t).setType(std::get<1>(t));
-    rewriter.replaceOp(op, newOp.getResults());
 
+    rewriter.replaceOp(op, newOp.getResults());
     return success();
   }
 };
@@ -71,9 +99,15 @@ class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
       newResultTypes.push_back(newType);
     }
 
-    // TODO: Write this with updateRootInPlace once the conversion infra
-    // supports source materializations on ops updated in place.
-    IfOp newOp = cast<IfOp>(rewriter.clone(*op.getOperation()));
+    // See comments in the ForOp pattern for why we clone without regions and
+    // then inline.
+    IfOp newOp = cast<IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+    rewriter.inlineRegionBefore(op.thenRegion(), newOp.thenRegion(),
+                                newOp.thenRegion().end());
+    rewriter.inlineRegionBefore(op.elseRegion(), newOp.elseRegion(),
+                                newOp.elseRegion().end());
+
+    // Update the operands and types.
     newOp.getOperation()->setOperands(operands);
     for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
       std::get<0>(t).setType(std::get<1>(t));

diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
index 01b353da83ed..7f42e15107d6 100644
--- a/mlir/test/Dialect/SCF/bufferize.mlir
+++ b/mlir/test/Dialect/SCF/bufferize.mlir
@@ -29,7 +29,9 @@ func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tens
 // CHECK-SAME:              %[[STEP:.*]]: index) -> tensor<f32> {
 // CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
 // CHECK:           %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
-// CHECK:             scf.yield %[[ITER]] : memref<f32>
+// CHECK:             %[[TENSOR_ITER:.*]] = tensor_load %[[ITER]] : memref<f32>
+// CHECK:             %[[MEMREF_YIELDED:.*]] = tensor_to_memref %[[TENSOR_ITER]] : memref<f32>
+// CHECK:             scf.yield %[[MEMREF_YIELDED]] : memref<f32>
 // CHECK:           }
 // CHECK:           %[[VAL_8:.*]] = tensor_load %[[VAL_9:.*]] : memref<f32>
 // CHECK:           return %[[VAL_8]] : tensor<f32>
@@ -40,3 +42,40 @@ func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f3
   }
   return %ret : tensor<f32>
 }
+
+// Check that this converts at all.
+//
+// This example previously failed to legalize altogether.
+// CHECK-LABEL:   func @if_correct_recursive_legalization_behavior
+// CHECK: "test.munge_tensor"
+func @if_correct_recursive_legalization_behavior(%pred: i1, %tensor: tensor<f32>) -> tensor<f32> {
+  %0 = scf.if %pred -> (tensor<f32>) {
+    %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
+    scf.yield %1 : tensor<f32>
+  } else {
+    %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
+    scf.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL:   func @for_correct_recursive_legalization_behavior(
+// CHECK-SAME:                                                      %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME:                                                      %[[INDEX:.*]]: index) -> tensor<f32> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[INDEX]] to %[[INDEX]] step %[[INDEX]] iter_args(%[[MEMREF_ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
+// CHECK:             %[[TENSOR_ITER:.*]] = tensor_load %[[MEMREF_ITER]] : memref<f32>
+// CHECK:             %[[TENSOR_MUNGED:.*]] = "test.munge_tensor"(%[[TENSOR_ITER]]) : (tensor<f32>) -> tensor<f32>
+// CHECK:             %[[MEMREF_MUNGED:.*]] = tensor_to_memref %[[TENSOR_MUNGED]] : memref<f32>
+// CHECK:             scf.yield %[[MEMREF_MUNGED]] : memref<f32>
+// CHECK:           }
+// CHECK:           %[[TENSOR:.*]] = tensor_load %[[RESULT:.*]] : memref<f32>
+// CHECK:           return %[[TENSOR]] : tensor<f32>
+// CHECK:         }
+func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: index) -> tensor<f32> {
+  %ret = scf.for %iv = %index to %index step %index iter_args(%iter = %arg0) -> tensor<f32> {
+    %0 = "test.munge_tensor"(%iter) : (tensor<f32>) -> (tensor<f32>)
+    scf.yield %0 : tensor<f32>
+  }
+  return %ret : tensor<f32>
+}
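
For context, the SCFBufferizePass hosting the new argument materialization is
wired up roughly as follows. This is a sketch reconstructed from the
Bufferize.cpp hunk above; the surrounding pass boilerplate (runOnFunction and
the final applyPartialConversion call) is not part of this diff, so treat
those exact signatures as assumptions that may differ across MLIR revisions:

    void SCFBufferizePass::runOnFunction() {
      auto *context = &getContext();
      BufferizeTypeConverter typeConverter;
      OwningRewritePatternList patterns;
      ConversionTarget target(*context);

      // The memref -> tensor argument materialization added in this patch.
      typeConverter.addArgumentMaterialization(
          [](OpBuilder &builder, RankedTensorType type, ValueRange inputs,
             Location loc) -> Value {
            assert(inputs.size() == 1);
            assert(inputs[0].getType().isa<BaseMemRefType>());
            return builder.create<TensorLoadOp>(loc, type, inputs[0]);
          });

      populateBufferizeMaterializationLegality(target);
      populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
                                                      patterns, target);
      if (failed(applyPartialConversion(getFunction(), target,
                                        std::move(patterns))))
        signalPassFailure();
    }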