[Mlir-commits] [mlir] d0b282e - [mlir][Linalg] Rewrite PadTensorOp to enable its comprehensive bufferization.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jul 7 05:40:56 PDT 2021
Author: Nicolas Vasilache
Date: 2021-07-07T12:39:22Z
New Revision: d0b282e10bc91ea19a9b1a0ca4ed81d0c65e7cd3
URL: https://github.com/llvm/llvm-project/commit/d0b282e10bc91ea19a9b1a0ca4ed81d0c65e7cd3
DIFF: https://github.com/llvm/llvm-project/commit/d0b282e10bc91ea19a9b1a0ca4ed81d0c65e7cd3.diff
LOG: [mlir][Linalg] Rewrite PadTensorOp to enable its comprehensive bufferization.
Add the rewrite of PadTensorOp to InitTensor + InsertSlice before the
bufferization analysis starts.
This is exercised via a more advanced integration test.
Since the new behavior triggers folding, two tests need to be updated.
One of them appears to exhibit a folding issue with `switch` and is rewritten to avoid that op.
Differential Revision: https://reviews.llvm.org/D105549
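
For context, the enabling rewrite applied before the analysis expresses a PadTensorOp as an
InitTensor of the padded shape plus an InsertSlice of the source. The before/after sketch below
is hand-written for illustration: %src and %cst are made-up values, and the intermediate
linalg.fill of the padding value is an assumption (the commit message only mentions
InitTensor + InsertSlice), so the actual pattern output may differ slightly.

  // Before: pad a tensor<2xf32> by one element on each side with %cst.
  %res = linalg.pad_tensor %src low[1] high[1] {
  ^bb0(%i: index):  // no predecessors
    linalg.yield %cst : f32
  } : tensor<2xf32> to tensor<4xf32>

  // After (sketch): materialize the padded shape, fill it with the padding
  // value, then insert the original tensor at the low offset.
  %init = linalg.init_tensor [4] : tensor<4xf32>
  %fill = linalg.fill(%cst, %init) : f32, tensor<4xf32> -> tensor<4xf32>
  %res  = tensor.insert_slice %src into %fill[1] [2] [1]
      : tensor<2xf32> into tensor<4xf32>

Once pad_tensor is rewritten this way, the comprehensive bufferization only needs to reason
about init_tensor and insert_slice, which is what makes bufferizing padded code possible.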
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index cf4ec5228dd5..8c37bebe1000 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -108,6 +108,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
@@ -117,6 +118,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseSet.h"
@@ -1491,9 +1493,7 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
<< "cannot bufferize bodiless function that returns a tensor";
} else {
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
- return funcOp->emitError() << "cannot bufferize a FuncOp with tensors "
- "and without a unique ReturnOp";
+ assert(returnOp && "expected func with single return op");
// For each FuncOp result, keep track of which inplace argument it reuses.
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
@@ -2474,9 +2474,7 @@ static LogicalResult bufferizeFuncOpBoundary(
// Support only single return-terminated block in the function.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
- return funcOp->emitError() << "cannot bufferize a FuncOp with tensors and "
- "without a unique ReturnOp";
+ assert(returnOp && "expected func with single return op");
// 1. For each FuncOp result, keep track of which inplace argument it reuses.
SmallVector<Value> returnValues;
@@ -2574,7 +2572,15 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
DenseMap<FuncOp, DenseSet<FuncOp>> calledBy;
// For each FuncOp, the number of CallOpInterface it contains.
DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
- WalkResult res = moduleOp.walk([&](FuncOp funcOp) {
+ WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
+ if (!funcOp.body().empty()) {
+ ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ if (!returnOp)
+ return funcOp->emitError()
+ << "cannot bufferize a FuncOp with tensors and "
+ "without a unique ReturnOp";
+ }
+
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
// Only support CallOp for now.
@@ -2622,8 +2628,15 @@ struct LinalgComprehensiveModuleBufferize
};
} // end namespace
+static void applyEnablingTransformations(ModuleOp moduleOp) {
+ RewritePatternSet patterns(moduleOp.getContext());
+ patterns.add<GeneralizePadTensorOpPattern>(moduleOp.getContext());
+ (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
+}
+
void LinalgComprehensiveModuleBufferize::runOnOperation() {
ModuleOp moduleOp = getOperation();
+ applyEnablingTransformations(moduleOp);
SmallVector<FuncOp> orderedFuncOps;
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 2dea6fde4f34..2c2a14ead0f3 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -485,9 +485,11 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
// %r0 must be out of place because one use of %t in the subsequent production
// of %r1 is read.
// CHECK: scf.for
+ // CHECK-NEXT: call
// CHECK-NEXT: scf.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
%r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+ call @some_use(%t) : (tensor<?xf32>) -> ()
scf.yield %t : tensor<?xf32>
}
@@ -504,11 +506,13 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
// %r2 must be out of place because one use of %t in the subsequent production
// of %r3 is read.
// CHECK: linalg.tiled_loop
+ // CHECK-NEXT: call
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
%r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
ins()
outs(%t = %B: tensor<?xf32>) {
+ call @some_use(%t) : (tensor<?xf32>) -> ()
linalg.yield %t : tensor<?xf32>
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index d8257dd172c6..15be096dd86a 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -17,18 +17,20 @@ func private @foo() -> tensor<?xf32>
// -----
// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func @switch(%flag : i32, %caseOperand : i32, %t1 : tensor<f32>, %t2 : tensor<f32>)
- -> (tensor<f32>)
+func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
+ -> (tensor<f32>, tensor<f32>)
{
- switch %flag : i32, [
- default: ^bb1(%caseOperand : i32),
- 42: ^bb2(%caseOperand : i32)
- ]
-
- ^bb1(%bb1arg : i32):
- return %t1 : tensor<f32>
- ^bb2(%bb2arg : i32):
- return %t2 : tensor<f32>
+ cond_br %cond1, ^bb1, ^bb2
+
+ ^bb1:
+ %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
+ scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
+ } else {
+ scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
+ }
+ return %T#0, %T#1 : tensor<f32>, tensor<f32>
+ ^bb2:
+ return %t2, %t1 : tensor<f32>, tensor<f32>
}
// -----
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
index 7a4e134e498f..b0c2baafd94a 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
@@ -6,15 +6,73 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext |\
// RUN: FileCheck %s
-func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
- %v0 = constant 0.0 : f32
+#map0 = affine_map<(d0, d1)[s0] -> ((d1 - d0) ceildiv s0)>
+#map1 = affine_map<(d0, d1)[s0] -> ((d0 - d1) ceildiv s0)>
+
+func @init_and_dot(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<f32> {linalg.inplaceable = true}) -> tensor<f32> {
+ %c64 = constant 64 : index
+ %cst = constant 0.000000e+00 : f32
+ %c2 = constant 2 : index
+ %c0 = constant 0 : index
+ %0 = linalg.fill(%cst, %arg2) : f32, tensor<f32> -> tensor<f32>
+ %1 = affine.apply #map0(%c0, %c64)[%c2]
+ %2 = linalg.init_tensor [%1, 2] : tensor<?x2xf32>
+ %3 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %2) -> (tensor<?x2xf32>) {
+ %8 = affine.apply #map1(%arg3, %c0)[%c2]
+ %9 = tensor.extract_slice %arg1[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
+ %10 = tensor.cast %9 : tensor<2xf32> to tensor<?xf32>
+ %11 = linalg.pad_tensor %10 low[%c0] high[%c0] {
+ ^bb0(%arg5: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<?xf32> to tensor<2xf32>
+ %12 = tensor.insert_slice %11 into %arg4[%8, 0] [1, 2] [1, 1] : tensor<2xf32> into tensor<?x2xf32>
+ scf.yield %12 : tensor<?x2xf32>
+ }
+
+ // %B = tensor.cast %3 : tensor<?x2xf32> to tensor<*xf32>
+ // call @print_memref_f32(%B) : (tensor<*xf32>) -> ()
+
+ %4 = affine.apply #map0(%c0, %c64)[%c2]
+ %5 = linalg.init_tensor [%4, 2] : tensor<?x2xf32>
+ %6 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %5) -> (tensor<?x2xf32>) {
+ %8 = affine.apply #map1(%arg3, %c0)[%c2]
+ %9 = tensor.extract_slice %arg0[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
+ %10 = tensor.cast %9 : tensor<2xf32> to tensor<?xf32>
+ %11 = linalg.pad_tensor %10 low[%c0] high[%c0] {
+ ^bb0(%arg5: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<?xf32> to tensor<2xf32>
+ %12 = tensor.insert_slice %11 into %arg4[%8, 0] [1, 2] [1, 1] : tensor<2xf32> into tensor<?x2xf32>
+ scf.yield %12 : tensor<?x2xf32>
+ }
+
+ // %A = tensor.cast %6 : tensor<?x2xf32> to tensor<*xf32>
+ // call @print_memref_f32(%A) : (tensor<*xf32>) -> ()
+
+ // %C = tensor.cast %0 : tensor<f32> to tensor<*xf32>
+ // call @print_memref_f32(%C) : (tensor<*xf32>) -> ()
- %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+ %7 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %0) -> (tensor<f32>) {
+ %8 = tensor.extract_slice %arg0[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
+ %9 = tensor.cast %8 : tensor<2xf32> to tensor<?xf32>
+ %10 = tensor.extract_slice %arg1[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
+ %11 = tensor.cast %10 : tensor<2xf32> to tensor<?xf32>
+ %12 = affine.apply #map1(%arg3, %c0)[%c2]
+ %13 = tensor.extract_slice %6[%12, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
+ %14 = affine.apply #map1(%arg3, %c0)[%c2]
+ %15 = tensor.extract_slice %3[%14, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
+ %16 = linalg.dot ins(%13, %15 : tensor<2xf32>, tensor<2xf32>) outs(%arg4 : tensor<f32>) -> tensor<f32>
- %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
- outs(%d: tensor<f32>) -> tensor<f32>
+ // %AA = tensor.cast %13 : tensor<2xf32> to tensor<*xf32>
+ // call @print_memref_f32(%AA) : (tensor<*xf32>) -> ()
+ // %BB = tensor.cast %15 : tensor<2xf32> to tensor<*xf32>
+ // call @print_memref_f32(%BB) : (tensor<*xf32>) -> ()
+ // %CC = tensor.cast %16 : tensor<f32> to tensor<*xf32>
+ // call @print_memref_f32(%CC) : (tensor<*xf32>) -> ()
- return %e : tensor<f32>
+ scf.yield %16 : tensor<f32>
+ }
+ return %7 : tensor<f32>
}
func @main() {