[llvm-branch-commits] [mlir] 774f1d3 - [mlir] Small cleanups to func-bufferize/finalizing-bufferize
Sean Silva via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Nov 30 17:09:49 PST 2020
Author: Sean Silva
Date: 2020-11-30T17:04:14-08:00
New Revision: 774f1d3ffd458d6cb82d5039758ef1cf6370957f
URL: https://github.com/llvm/llvm-project/commit/774f1d3ffd458d6cb82d5039758ef1cf6370957f
DIFF: https://github.com/llvm/llvm-project/commit/774f1d3ffd458d6cb82d5039758ef1cf6370957f.diff
LOG: [mlir] Small cleanups to func-bufferize/finalizing-bufferize
- Address TODO in scf-bufferize: the argument materialization issue is
fixed, and the materialization code now lives in Transforms/Bufferize.cpp
- Tighten up finalizing-bufferize to avoid creating invalid IR when
operand types potentially change
- Tidy up the testing of func-bufferize, and move appropriate tests
to a new finalizing-bufferize.mlir
- The new stricter checking in finalizing-bufferize revealed that we
needed a DimOp conversion pattern (found when integrating into npcomp).
Previously, the conversion infrastructure was blindly changing the
operand type during finalization, which happened to work due to
DimOp's tensor/memref polymorphism, but is generally not encouraged
(the new pattern is the way to tell the conversion infrastructure that
it is legal to change that type).
Added:
mlir/test/Transforms/finalizing-bufferize.mlir
Modified:
mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/lib/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/bufferize.mlir
mlir/test/Dialect/Standard/func-bufferize.mlir
Removed:
mlir/test/Dialect/Standard/func-bufferize-partial.mlir
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 57d605b3491f7..7cf0dfabd9174 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -27,21 +27,6 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
OwningRewritePatternList patterns;
ConversionTarget target(*context);
- // TODO: Move this to BufferizeTypeConverter's constructor.
- //
- // This doesn't currently play well with "finalizing" bufferizations (ones
- // that expect all materializations to be gone). In particular, there seems
- // to at least be a double-free in the dialect conversion framework
- // when this materialization gets inserted and then folded away because
- // it is marked as illegal.
- typeConverter.addArgumentMaterialization(
- [](OpBuilder &builder, RankedTensorType type, ValueRange inputs,
- Location loc) -> Value {
- assert(inputs.size() == 1);
- assert(inputs[0].getType().isa<BaseMemRefType>());
- return builder.create<TensorLoadOp>(loc, type, inputs[0]);
- });
-
populateBufferizeMaterializationLegality(target);
populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
patterns, target);
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 9056fbc25e14d..8b47e88677e2d 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -20,6 +20,21 @@
using namespace mlir;
+namespace {
+class BufferizeDimOp : public OpConversionPattern<DimOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(DimOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ DimOp::Adaptor adaptor(operands);
+ rewriter.replaceOpWithNewOp<DimOp>(op, adaptor.memrefOrTensor(),
+ adaptor.index());
+ return success();
+ }
+};
+} // namespace
+
namespace {
class BufferizeDynamicTensorFromElementsOp
: public OpConversionPattern<DynamicTensorFromElementsOp> {
@@ -148,6 +163,7 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
patterns.insert<
// clang-format off
+ BufferizeDimOp,
BufferizeDynamicTensorFromElementsOp,
BufferizeExtractElementOp,
BufferizeSelectOp,
@@ -178,6 +194,8 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
return typeConverter.isLegal(op.getType()) ||
!op.condition().getType().isa<IntegerType>();
});
+ target.addDynamicallyLegalOp<DimOp>(
+ [&](DimOp op) { return typeConverter.isLegal(op); });
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index 1811ac8bdfbca..66b1cc65646c1 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -105,13 +105,17 @@ struct FinalizingBufferizePass
populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
patterns);
- target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
// If all result types are legal, and all block arguments are legal (ensured
// by func conversion above), then all types in the program are legal.
- target.markUnknownOpDynamicallyLegal([&](Operation *op) {
- return typeConverter.isLegal(op->getResultTypes());
- });
+ //
+ // We also check that the operand types are legal to avoid creating invalid
+ // IR. For example, this prevents
+ // populateEliminateBufferizeMaterializationsPatterns from updating the
+ // types of the operands to a return op without updating the enclosing
+ // function.
+ target.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) { return typeConverter.isLegal(op); });
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 8cc05ff20644b..27769c52d9ea4 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -1,5 +1,16 @@
// RUN: mlir-opt %s -std-bufferize | FileCheck %s
+// CHECK-LABEL: func @dim(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK: %[[EXTENT:.*]] = dim %[[MEMREF]], %[[INDEX]] : memref<f32>
+// CHECK: return %[[EXTENT]] : index
+func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
+ %0 = dim %arg0, %arg1 : tensor<f32>
+ return %0 : index
+}
+
// CHECK-LABEL: func @dynamic_tensor_from_elements(
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
@@ -7,7 +18,8 @@
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
-// CHECK: %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+// CHECK: %[[ARG_MEMREF:.*]] = tensor_to_memref %[[ARG]] : memref<*xf32>
+// CHECK: %[[ELEM:.*]] = dim %[[ARG_MEMREF]], %[[I]] : memref<*xf32>
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
// CHECK: scf.yield
// CHECK: }
diff --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
deleted file mode 100644
index 43ea4591e4e35..0000000000000
--- a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
+++ /dev/null
@@ -1,59 +0,0 @@
-// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL: func @block_arguments(
-// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
-// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
-// CHECK: br ^bb1(%[[M1]] : memref<f32>)
-// CHECK: ^bb1(%[[BBARG:.*]]: memref<f32>):
-// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
-// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
-// CHECK: return %[[M2]] : memref<f32>
-func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
- br ^bb1(%arg0: tensor<f32>)
-^bb1(%bbarg: tensor<f32>):
- return %bbarg : tensor<f32>
-}
-
-// CHECK-LABEL: func @partial()
-// CHECK-SAME: memref<f32>
-func @partial() -> tensor<f32> {
- // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
- // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref<f32>
- %0 = "test.source"() : () -> tensor<f32>
- // CHECK-NEXT: return %[[MEM]] : memref<f32>
- return %0 : tensor<f32>
-}
-
-// CHECK-LABEL: func @region_op
-// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref<f32>
-func @region_op(%arg0: i1) -> tensor<f32> {
- // CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor<f32>)
- %0 = scf.if %arg0 -> (tensor<f32>) {
- // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
- %1 = "test.source"() : () -> tensor<f32>
- // CHECK-NEXT: scf.yield %[[SRC]] : tensor<f32>
- scf.yield %1 : tensor<f32>
- // CHECK-NEXT: else
- } else {
- // CHECK-NEXT: %[[OSRC:.*]] = "test.other_source"() : () -> tensor<f32>
- %1 = "test.other_source"() : () -> tensor<f32>
- // CHECK-NEXT: scf.yield %[[OSRC]] : tensor<f32>
- scf.yield %1 : tensor<f32>
- }
- // CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref<f32>
- // CHECK: return %[[MEM]] : memref<f32>
- return %0 : tensor<f32>
-}
-
-// -----
-
-func @failed_to_legalize(%arg0: tensor<f32>) -> tensor<f32> {
- %0 = constant true
- cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
- ^bb1(%bbarg0: tensor<f32>):
- // expected-error @+1 {{failed to legalize operation 'test.terminator'}}
- "test.terminator"() : () -> ()
- ^bb2(%bbarg1: tensor<f32>):
- return %bbarg1 : tensor<f32>
-}
diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir
index d02db99aecd83..de2f75c4a293b 100644
--- a/mlir/test/Dialect/Standard/func-bufferize.mlir
+++ b/mlir/test/Dialect/Standard/func-bufferize.mlir
@@ -1,39 +1,29 @@
-// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @identity(
-// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK: return %[[ARG]] : memref<f32>
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: %[[TENSOR:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK: return %[[MEMREF]] : memref<f32>
func @identity(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}
// CHECK-LABEL: func @block_arguments(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK: br ^bb1(%[[ARG]] : memref<f32>)
+// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
+// CHECK: br ^bb1(%[[M1]] : memref<f32>)
// CHECK: ^bb1(%[[BBARG:.*]]: memref<f32>):
-// CHECK: return %[[BBARG]] : memref<f32>
+// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
+// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
+// CHECK: return %[[M2]] : memref<f32>
func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
br ^bb1(%arg0: tensor<f32>)
^bb1(%bbarg: tensor<f32>):
return %bbarg : tensor<f32>
}
-// CHECK-LABEL: func @eliminate_target_materialization(
-// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK: return %[[ARG]] : memref<f32>
-func @eliminate_target_materialization(%arg0: tensor<f32>) -> memref<f32> {
- %0 = tensor_to_memref %arg0 : memref<f32>
- return %0 : memref<f32>
-}
-
-// CHECK-LABEL: func @eliminate_source_materialization(
-// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK: return %[[ARG]] : memref<f32>
-func @eliminate_source_materialization(%arg0: memref<f32>) -> tensor<f32> {
- %0 = tensor_load %arg0 : memref<f32>
- return %0 : tensor<f32>
-}
-
// CHECK-LABEL: func private @source() -> memref<f32>
// CHECK-LABEL: func @call_source() -> memref<f32> {
// CHECK: %[[RET:.*]] = call @source() : () -> memref<f32>
@@ -43,11 +33,11 @@ func @call_source() -> tensor<f32> {
%0 = call @source() : () -> tensor<f32>
return %0 : tensor<f32>
}
-
-// CHECK-LABEL: func private @sink(memref<f32>)
// CHECK-LABEL: func @call_sink(
-// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
-// CHECK: call @sink(%[[ARG]]) : (memref<f32>) -> ()
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
+// CHECK: %[[TENSOR:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK: call @sink(%[[MEMREF]]) : (memref<f32>) -> ()
// CHECK: return
func private @sink(tensor<f32>)
func @call_sink(%arg0: tensor<f32>) {
@@ -55,10 +45,25 @@ func @call_sink(%arg0: tensor<f32>) {
return
}
+// CHECK-LABEL: func @unconverted_op_in_body() -> memref<f32> {
+// CHECK: %[[TENSOR:.*]] = "test.source"() : () -> tensor<f32>
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK: return %[[MEMREF]] : memref<f32>
+func @unconverted_op_in_body() -> tensor<f32> {
+ %0 = "test.source"() : () -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
// -----
-func @failed_to_legalize() -> tensor<f32> {
- // expected-error @+1 {{failed to legalize operation 'test.source'}}
- %0 = "test.source"() : () -> (tensor<f32>)
- return %0 : tensor<f32>
+// Because this pass updates block arguments, it needs to also atomically
+// update all terminators and issue an error if that is not possible.
+func @unable_to_update_terminator(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = constant true
+ cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
+ ^bb1(%bbarg0: tensor<f32>):
+ // expected-error @+1 {{failed to legalize operation 'test.terminator'}}
+ "test.terminator"() : () -> ()
+ ^bb2(%bbarg1: tensor<f32>):
+ return %bbarg1 : tensor<f32>
}
diff --git a/mlir/test/Transforms/finalizing-bufferize.mlir b/mlir/test/Transforms/finalizing-bufferize.mlir
new file mode 100644
index 0000000000000..5c09664776ead
--- /dev/null
+++ b/mlir/test/Transforms/finalizing-bufferize.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @eliminate_materializations(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK: return %[[ARG]] : memref<f32>
+func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
+ %0 = tensor_load %arg0 : memref<f32>
+ %1 = tensor_to_memref %0 : memref<f32>
+ return %1 : memref<f32>
+}
+
+// -----
+
+func @unable_to_convert_lone_tensor_to_memref() -> memref<f32> {
+ // expected-error @+1 {{failed to legalize operation 'test.source'}}
+ %0 = "test.source"() : () -> tensor<f32>
+ %1 = tensor_to_memref %0 : memref<f32>
+ return %1 : memref<f32>
+}
+
+// -----
+
+func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
+ %0 = tensor_load %arg0 : memref<f32>
+ // expected-error @+1 {{failed to legalize operation 'test.sink'}}
+ "test.sink"(%0) : (tensor<f32>) -> ()
+ return
+}
More information about the llvm-branch-commits mailing list