[llvm-branch-commits] [mlir] caf4f2e - [mlir] Handle unknown ops in dynamic_tensor_from_elements bufferization
Sean Silva via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Dec 15 12:55:51 PST 2020
Author: Sean Silva
Date: 2020-12-15T12:50:56-08:00
New Revision: caf4f2e372a7a4d5d8b5a8733e44f002c6dee0d5
URL: https://github.com/llvm/llvm-project/commit/caf4f2e372a7a4d5d8b5a8733e44f002c6dee0d5
DIFF: https://github.com/llvm/llvm-project/commit/caf4f2e372a7a4d5d8b5a8733e44f002c6dee0d5.diff
LOG: [mlir] Handle unknown ops in dynamic_tensor_from_elements bufferization
Due to how the conversion infra works, the "clone" call that this
pattern was using required all the cloned ops to be immediately
legalized as part of this dialect conversion invocation.
That was previously working due to a couple of factors:
- In the test case, there was scf.if, which we happen to mark as legal
as part of marking the entire SCF dialect as legal for the scf.parallel
we generate here.
- Originally, this test case had std.extract_element in the body, which
we happened to have a pattern for in this pass. After I migrated that to
`tensor.extract` (which removed the tensor.extract bufferization from
here), I hacked this up to use `std.dim`, which we still have patterns
for in this pass.
This patch updates the test case to use a truly opaque op `test.source`
that properly stresses this aspect of the pattern.
(This also removes a stray dependency on the `tensor` dialect that I
must have left behind as part of my hacking this pass up when migrating
to `tensor.extract`.)
Differential Revision: https://reviews.llvm.org/D93262
Added:
Modified:
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 6691355d232c..a84934b0ebb8 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -15,7 +15,6 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -70,18 +69,29 @@ class BufferizeDynamicTensorFromElementsOp
upperBounds.push_back(upperBound);
}
- // Generate tensor elements with a parallel loop.
- rewriter.create<scf::ParallelOp>(
- loc, lowerBounds, upperBounds, steps,
- [&](OpBuilder &b, Location loc, ValueRange ivs) {
- BlockAndValueMapping mapping;
- mapping.map(op.body().getArguments(), ivs);
- for (auto &nestedOp : op.getBody()->without_terminator())
- b.clone(nestedOp, mapping);
- auto yieldOp = cast<YieldOp>(op.getBody()->getTerminator());
- b.create<StoreOp>(loc, mapping.lookup(yieldOp.value()), result, ivs);
- b.create<scf::YieldOp>(loc);
- });
+ // Generate tensor elements with a parallel loop that stores into
+ // each element of the resulting memref.
+ //
+ // This is a bit tricky. We cannot simply clone the ops because when an op
+ // is cloned, it must be legalized. However, we want to allow arbitrary ops
+ // in the body that we don't necessarily have legalization patterns for as
+ // part of this dialect conversion invocation.
+ //
+ // To accomplish this, we use mergeBlockBefore to "move" this op's body
+ // into the scf.parallel's body.
+ auto parallel =
+ rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
+ Block *parallelBody = parallel.getBody();
+ rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
+ parallelBody->getArguments());
+ // Replace the inlined yield op with a store op. The scf.parallel's builder
+ // already populated an scf.yield at the end, so we don't need to worry
+ // about creating that.
+ Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
+ rewriter.setInsertionPointAfter(elementYield);
+ rewriter.replaceOpWithNewOp<StoreOp>(elementYield,
+ elementYield->getOperands()[0], result,
+ parallelBody->getArguments());
rewriter.replaceOp(op, {result});
return success();
@@ -168,7 +178,6 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<scf::SCFDialect>();
- target.addLegalDialect<tensor::TensorDialect>();
populateStdBufferizePatterns(context, typeConverter, patterns);
target.addIllegalOp<DynamicTensorFromElementsOp, TensorCastOp,
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 75ff2a9d78f0..8ae10ccf0f3b 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -123,20 +123,20 @@ func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
return %0 : tensor<2xindex>
}
-// The dynamic_tensor_from_elements op clones each op in its body.
-// Make sure that regions nested within such ops are recursively converted.
-// CHECK-LABEL: func @recursively_convert_cloned_regions
-func @recursively_convert_cloned_regions(%arg0: tensor<*xf32>, %arg1: index, %arg2: i1) -> tensor<?xindex> {
- %tensor = dynamic_tensor_from_elements %arg1 {
+// The dynamic_tensor_from_elements op needs to put its body into the
+// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
+// the body because that would require the cloned ops to be legalized
+// immediately, which is usually not possible since they might be from various
+// other dialects.
+//
+// CHECK-LABEL: func @unknown_ops_in_body
+func @unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
+ // CHECK-NOT: dynamic_tensor_from_elements
+ %tensor = dynamic_tensor_from_elements %arg0 {
^bb0(%iv: index):
- %48 = scf.if %arg2 -> (index) {
- scf.yield %iv : index
- } else {
- // CHECK-NOT: dim{{.*}}tensor
- %50 = dim %arg0, %iv : tensor<*xf32>
- scf.yield %50 : index
- }
- yield %48 : index
+ // CHECK: test.source
+ %0 = "test.source"() : () -> index
+ yield %0 : index
} : tensor<?xindex>
return %tensor : tensor<?xindex>
}
More information about the llvm-branch-commits
mailing list