[Mlir-commits] [mlir] 774f1d3 - [mlir] Small cleanups to func-bufferize/finalizing-bufferize

Sean Silva llvmlistbot at llvm.org
Mon Nov 30 17:05:24 PST 2020


Author: Sean Silva
Date: 2020-11-30T17:04:14-08:00
New Revision: 774f1d3ffd458d6cb82d5039758ef1cf6370957f

URL: https://github.com/llvm/llvm-project/commit/774f1d3ffd458d6cb82d5039758ef1cf6370957f
DIFF: https://github.com/llvm/llvm-project/commit/774f1d3ffd458d6cb82d5039758ef1cf6370957f.diff

LOG: [mlir] Small cleanups to func-bufferize/finalizing-bufferize

- Address TODO in scf-bufferize: the argument materialization issue is
  fixed, and that materialization code now lives in
  Transforms/Bufferize.cpp, so the local copy in scf-bufferize is removed
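- Tighten up finalizing-bufferize so that it no longer creates invalid IR
  when eliminating materializations would change an op's operand types
  (e.g. the operands of a return op without updating the enclosing function)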
- Tidy up the testing of func-bufferize, and move appropriate tests
  to a new finalizing-bufferize.mlir
- The new, stricter checking in finalizing-bufferize revealed that we
  needed a DimOp conversion pattern (found when integrating into npcomp).
  Previously, the conversion infrastructure blindly changed the
  operand type during finalization. That happened to work because of
  DimOp's tensor/memref polymorphism, but it is generally discouraged
  (the new pattern is how we tell the conversion infrastructure that
  it is legal to change that type).

Added: 
    mlir/test/Transforms/finalizing-bufferize.mlir

Modified: 
    mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
    mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
    mlir/lib/Transforms/Bufferize.cpp
    mlir/test/Dialect/Standard/bufferize.mlir
    mlir/test/Dialect/Standard/func-bufferize.mlir

Removed: 
    mlir/test/Dialect/Standard/func-bufferize-partial.mlir


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 57d605b3491f..7cf0dfabd917 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -27,21 +27,6 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
     OwningRewritePatternList patterns;
     ConversionTarget target(*context);
 
-    // TODO: Move this to BufferizeTypeConverter's constructor.
-    //
-    // This doesn't currently play well with "finalizing" bufferizations (ones
-    // that expect all materializations to be gone). In particular, there seems
-    // to at least be a double-free in the dialect conversion framework
-    // when this materialization gets inserted and then folded away because
-    // it is marked as illegal.
-    typeConverter.addArgumentMaterialization(
-        [](OpBuilder &builder, RankedTensorType type, ValueRange inputs,
-           Location loc) -> Value {
-          assert(inputs.size() == 1);
-          assert(inputs[0].getType().isa<BaseMemRefType>());
-          return builder.create<TensorLoadOp>(loc, type, inputs[0]);
-        });
-
     populateBufferizeMaterializationLegality(target);
     populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
                                                     patterns, target);

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 9056fbc25e14..8b47e88677e2 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -20,6 +20,21 @@
 
 using namespace mlir;
 
+namespace {
+class BufferizeDimOp : public OpConversionPattern<DimOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(DimOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    DimOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<DimOp>(op, adaptor.memrefOrTensor(),
+                                       adaptor.index());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class BufferizeDynamicTensorFromElementsOp
     : public OpConversionPattern<DynamicTensorFromElementsOp> {
@@ -148,6 +163,7 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context,
                                         OwningRewritePatternList &patterns) {
   patterns.insert<
       // clang-format off
+      BufferizeDimOp,
       BufferizeDynamicTensorFromElementsOp,
       BufferizeExtractElementOp,
       BufferizeSelectOp,
@@ -178,6 +194,8 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
       return typeConverter.isLegal(op.getType()) ||
              !op.condition().getType().isa<IntegerType>();
     });
+    target.addDynamicallyLegalOp<DimOp>(
+        [&](DimOp op) { return typeConverter.isLegal(op); });
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();

diff  --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index 1811ac8bdfbc..66b1cc65646c 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -105,13 +105,17 @@ struct FinalizingBufferizePass
 
     populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
                                                        patterns);
-    target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
 
     // If all result types are legal, and all block arguments are legal (ensured
     // by func conversion above), then all types in the program are legal.
-    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
-      return typeConverter.isLegal(op->getResultTypes());
-    });
+    //
+    // We also check that the operand types are legal to avoid creating invalid
+    // IR. For example, this prevents
+    // populateEliminateBufferizeMaterializationsPatterns from updating the
+    // types of the operands to a return op without updating the enclosing
+    // function.
+    target.markUnknownOpDynamicallyLegal(
+        [&](Operation *op) { return typeConverter.isLegal(op); });
 
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();

diff  --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 8cc05ff20644..27769c52d9ea 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -1,5 +1,16 @@
 // RUN: mlir-opt %s -std-bufferize | FileCheck %s
 
+// CHECK-LABEL:   func @dim(
+// CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME:              %[[INDEX:.*]]: index) -> index {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           %[[EXTENT:.*]] = dim %[[MEMREF]], %[[INDEX]] : memref<f32>
+// CHECK:           return %[[EXTENT]] : index
+func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
+  %0 = dim %arg0, %arg1 : tensor<f32>
+  return %0 : index
+}
+
 // CHECK-LABEL:   func @dynamic_tensor_from_elements(
 // CHECK-SAME:                                       %[[ARG:.*]]: tensor<*xf32>,
 // CHECK-SAME:                                       %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
@@ -7,7 +18,8 @@
 // CHECK:           %[[C0:.*]] = constant 0 : index
 // CHECK:           %[[C1:.*]] = constant 1 : index
 // CHECK:           scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
-// CHECK:             %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+// CHECK:             %[[ARG_MEMREF:.*]] = tensor_to_memref %[[ARG]] : memref<*xf32>
+// CHECK:             %[[ELEM:.*]] = dim %[[ARG_MEMREF]], %[[I]] : memref<*xf32>
 // CHECK:             store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
 // CHECK:             scf.yield
 // CHECK:           }

diff  --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
deleted file mode 100644
index 43ea4591e4e3..000000000000
--- a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
+++ /dev/null
@@ -1,59 +0,0 @@
-// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL:   func @block_arguments(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
-// CHECK:           %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
-// CHECK:           br ^bb1(%[[M1]] : memref<f32>)
-// CHECK:         ^bb1(%[[BBARG:.*]]: memref<f32>):
-// CHECK:           %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
-// CHECK:           %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
-// CHECK:           return %[[M2]] : memref<f32>
-func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
-  br ^bb1(%arg0: tensor<f32>)
-^bb1(%bbarg: tensor<f32>):
-  return %bbarg : tensor<f32>
-}
-
-// CHECK-LABEL: func @partial()
-// CHECK-SAME: memref<f32>
-func @partial() -> tensor<f32> {
-  // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
-  // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref<f32>
-  %0 = "test.source"() : () -> tensor<f32>
-  // CHECK-NEXT: return %[[MEM]] : memref<f32>
-  return %0 : tensor<f32>
-}
-
-// CHECK-LABEL: func @region_op
-// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref<f32>
-func @region_op(%arg0: i1) -> tensor<f32> {
-  // CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor<f32>)
-  %0 = scf.if %arg0 -> (tensor<f32>) {
-    // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
-    %1 = "test.source"() : () -> tensor<f32>
-    // CHECK-NEXT: scf.yield %[[SRC]] : tensor<f32>
-    scf.yield %1 : tensor<f32>
-  // CHECK-NEXT: else
-  } else {
-    // CHECK-NEXT: %[[OSRC:.*]] = "test.other_source"() : () -> tensor<f32>
-    %1 = "test.other_source"() : () -> tensor<f32>
-    // CHECK-NEXT: scf.yield %[[OSRC]] : tensor<f32>
-    scf.yield %1 : tensor<f32>
-  }
-  // CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref<f32>
-  // CHECK: return %[[MEM]] : memref<f32>
-  return %0 : tensor<f32>
-}
-
-// -----
-
-func @failed_to_legalize(%arg0: tensor<f32>) -> tensor<f32> {
-    %0 = constant true
-    cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
-  ^bb1(%bbarg0: tensor<f32>):
-    // expected-error @+1 {{failed to legalize operation 'test.terminator'}}
-    "test.terminator"() : () -> ()
-  ^bb2(%bbarg1: tensor<f32>):
-    return %bbarg1 : tensor<f32>
-}

diff  --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir
index d02db99aecd8..de2f75c4a293 100644
--- a/mlir/test/Dialect/Standard/func-bufferize.mlir
+++ b/mlir/test/Dialect/Standard/func-bufferize.mlir
@@ -1,39 +1,29 @@
-// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL:   func @identity(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           return %[[ARG]] : memref<f32>
+// CHECK-SAME:                   %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           %[[TENSOR:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           return %[[MEMREF]] : memref<f32>
 func @identity(%arg0: tensor<f32>) -> tensor<f32> {
   return %arg0 : tensor<f32>
 }
 
 // CHECK-LABEL:   func @block_arguments(
 // CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           br ^bb1(%[[ARG]] : memref<f32>)
+// CHECK:           %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK:           %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
+// CHECK:           br ^bb1(%[[M1]] : memref<f32>)
 // CHECK:         ^bb1(%[[BBARG:.*]]: memref<f32>):
-// CHECK:           return %[[BBARG]] : memref<f32>
+// CHECK:           %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
+// CHECK:           %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
+// CHECK:           return %[[M2]] : memref<f32>
 func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
   br ^bb1(%arg0: tensor<f32>)
 ^bb1(%bbarg: tensor<f32>):
   return %bbarg : tensor<f32>
 }
 
-// CHECK-LABEL:   func @eliminate_target_materialization(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           return %[[ARG]] : memref<f32>
-func @eliminate_target_materialization(%arg0: tensor<f32>) -> memref<f32> {
-  %0 = tensor_to_memref %arg0 : memref<f32>
-  return %0 : memref<f32>
-}
-
-// CHECK-LABEL:   func @eliminate_source_materialization(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           return %[[ARG]] : memref<f32>
-func @eliminate_source_materialization(%arg0: memref<f32>) -> tensor<f32> {
-  %0 = tensor_load %arg0 : memref<f32>
-  return %0 : tensor<f32>
-}
-
 // CHECK-LABEL:   func private @source() -> memref<f32>
 // CHECK-LABEL:   func @call_source() -> memref<f32> {
 // CHECK:           %[[RET:.*]] = call @source() : () -> memref<f32>
@@ -43,11 +33,11 @@ func @call_source() -> tensor<f32> {
   %0 = call @source() : () -> tensor<f32>
   return %0 : tensor<f32>
 }
-
-// CHECK-LABEL:   func private @sink(memref<f32>)
 // CHECK-LABEL:   func @call_sink(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) {
-// CHECK:           call @sink(%[[ARG]]) : (memref<f32>) -> ()
+// CHECK-SAME:                    %[[ARG:.*]]: memref<f32>) {
+// CHECK:           %[[TENSOR:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           call @sink(%[[MEMREF]]) : (memref<f32>) -> ()
 // CHECK:           return
 func private @sink(tensor<f32>)
 func @call_sink(%arg0: tensor<f32>) {
@@ -55,10 +45,25 @@ func @call_sink(%arg0: tensor<f32>) {
   return
 }
 
+// CHECK-LABEL:   func @unconverted_op_in_body() -> memref<f32> {
+// CHECK:           %[[TENSOR:.*]] = "test.source"() : () -> tensor<f32>
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           return %[[MEMREF]] : memref<f32>
+func @unconverted_op_in_body() -> tensor<f32> {
+  %0 = "test.source"() : () -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
 // -----
 
-func @failed_to_legalize() -> tensor<f32> {
-  // expected-error @+1 {{failed to legalize operation 'test.source'}}
-  %0 = "test.source"() : () -> (tensor<f32>)
-  return %0 : tensor<f32>
+// Because this pass updates block arguments, it needs to also atomically
+// update all terminators and issue an error if that is not possible.
+func @unable_to_update_terminator(%arg0: tensor<f32>) -> tensor<f32> {
+    %0 = constant true
+    cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
+  ^bb1(%bbarg0: tensor<f32>):
+    // expected-error @+1 {{failed to legalize operation 'test.terminator'}}
+    "test.terminator"() : () -> ()
+  ^bb2(%bbarg1: tensor<f32>):
+    return %bbarg1 : tensor<f32>
 }

diff  --git a/mlir/test/Transforms/finalizing-bufferize.mlir b/mlir/test/Transforms/finalizing-bufferize.mlir
new file mode 100644
index 000000000000..5c09664776ea
--- /dev/null
+++ b/mlir/test/Transforms/finalizing-bufferize.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL:   func @eliminate_materializations(
+// CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           return %[[ARG]] : memref<f32>
+func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
+  %0 = tensor_load %arg0 : memref<f32>
+  %1 = tensor_to_memref %0 : memref<f32>
+  return %1 : memref<f32>
+}
+
+// -----
+
+func @unable_to_convert_lone_tensor_to_memref() -> memref<f32> {
+  // expected-error @+1 {{failed to legalize operation 'test.source'}}
+  %0 = "test.source"() : () -> tensor<f32>
+  %1 = tensor_to_memref %0 : memref<f32>
+  return %1 : memref<f32>
+}
+
+// -----
+
+func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
+  %0 = tensor_load %arg0 : memref<f32>
+  // expected-error @+1 {{failed to legalize operation 'test.sink'}}
+  "test.sink"(%0) : (tensor<f32>) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list