[Mlir-commits] [mlir] 231b9dd - [mlir][Linalg] Add comprehensive bufferization support for linalg::InitTensor and tensor::CastOp (11/n)
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jul 1 04:30:39 PDT 2021
Author: Nicolas Vasilache
Date: 2021-07-01T11:26:01Z
New Revision: 231b9dd9de87f15170850e7d752dd6bd19799449
URL: https://github.com/llvm/llvm-project/commit/231b9dd9de87f15170850e7d752dd6bd19799449
DIFF: https://github.com/llvm/llvm-project/commit/231b9dd9de87f15170850e7d752dd6bd19799449.diff
LOG: [mlir][Linalg] Add comprehensive bufferization support for linalg::InitTensor and tensor::CastOp (11/n)
Also add an integration test that connects all the dots end to end, including a cast to an unranked tensor for external library calls.
Differential Revision: https://reviews.llvm.org/D105106
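For readers skimming the diff, a minimal hand-written sketch (not taken verbatim from the tests below) of how the two newly supported ops bufferize: linalg.init_tensor becomes an allocation, and tensor.cast becomes memref.cast, producing an unranked memref when the destination tensor type is unranked:

  // Tensor form, before -linalg-comprehensive-module-bufferize:
  %t = linalg.init_tensor [64] : tensor<64xf32>
  %u = tensor.cast %t : tensor<64xf32> to tensor<*xf32>

  // Memref form after bufferization (roughly; deallocation elided):
  %m = memref.alloc() : memref<64xf32>
  %c = memref.cast %m : memref<64xf32> to memref<*xf32>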
Added:
mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 824092df292ca..03191a85e506c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -357,7 +357,9 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
return
// clang-format off
isa<CallOpInterface,
+ tensor::CastOp,
scf::ForOp,
+ InitTensorOp,
LinalgOp,
ReturnOp,
ExtractSliceOp,
@@ -418,6 +420,14 @@ static OpResult getInplaceableOpResult(InsertSliceOp op, OpOperand &opOperand) {
return op->getResult(0);
}
+/// Return the OpResult that may bufferize into the same buffer as `opOperand`
+/// when the op is bufferized inplace.
+/// Return null if no such result exists.
+static OpResult getInplaceableOpResult(tensor::CastOp op,
+ OpOperand &opOperand) {
+ return op->getResult(0);
+}
+
/// Return the OpResult that may bufferize into the same buffer as `opOperand`
/// when the op is bufferized inplace.
/// The inplace analysis uses this information along with interfering read
@@ -428,7 +438,8 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
// clang-format off
// Ops that perform destructive updates on operand(s) to produce
// result(s).
- .Case<scf::ForOp,
+ .Case<tensor::CastOp,
+ scf::ForOp,
LinalgOp,
InsertSliceOp,
VectorTransferOpInterface>(
@@ -455,6 +466,7 @@ static Optional<OpOperand *> getAliasingOpOperand(OpResult result) {
if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp()))
return None;
return TypeSwitch<Operation *, OpOperand *>(result.getDefiningOp())
+ .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); })
.Case([&](LinalgOp op) {
return op.getOutputTensorOperands()[result.getResultNumber()];
})
@@ -1559,6 +1571,35 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
return success();
}
+/// tensor::CastOp bufferizes to memref::CastOp.
+static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
+ BlockAndValueMapping &bvm,
+ BufferizationAliasInfo &aliasInfo) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(castOp);
+
+ Type sourceType = lookup(bvm, castOp.source()).getType();
+ auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
+ auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
+ assert(rankedMemRefType || unrankedMemRefType);
+ unsigned memorySpace = rankedMemRefType
+ ? rankedMemRefType.getMemorySpaceAsInt()
+ : unrankedMemRefType.getMemorySpaceAsInt();
+ TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
+ ArrayRef<AffineMap> affineMaps =
+ rankedMemRefType && tensorType.isa<RankedTensorType>()
+ ? rankedMemRefType.getAffineMaps()
+ : ArrayRef<AffineMap>{};
+ Type memRefType = getContiguousOrUnrankedMemRefType(
+ castOp.getResult().getType(), {}, memorySpace);
+ Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType,
+ lookup(bvm, castOp.source()));
+ aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
+ map(bvm, castOp.getResult(), res);
+ return success();
+}
+
/// DimOp tensor operand is modified inplace. This allows leaving dead
/// tensors behind that will get DCE'd.
static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
@@ -1635,6 +1676,21 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
return success();
}
+/// InitTensor always allocates.
+/// TODO: consider hoisting across function boundaries prior to bufferization.
+static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
+ BlockAndValueMapping &bvm,
+ BufferizationAliasInfo &aliasInfo) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(initTensorOp);
+
+ Value alloc = createNewAllocDeallocPairForShapedValue(
+ b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
+ map(bvm, initTensorOp.result(), alloc);
+ return success();
+}
+
/// ReturnOp always creates memref::TensorLoadOp.
static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
BlockAndValueMapping &bvm,
@@ -2070,16 +2126,18 @@ static LogicalResult bufferizeFuncOpInternals(
// Since walk has to be PreOrder, we need to erase ops that require it
// separately: this is the case for CallOp
SmallVector<Operation *> toErase;
- WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op)
- -> WalkResult {
- // clang-format off
+ WalkResult result =
+ funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
+ // clang-format off
WalkResult result =
TypeSwitch<Operation *, LogicalResult>(op)
// Skip BufferCast and TensorLoad ops.
.Case<memref::BufferCastOp,
memref::TensorLoadOp>([&](auto) { return success(); })
- .Case<tensor::DimOp,
+ .Case<tensor::CastOp,
+ tensor::DimOp,
scf::ForOp,
+ InitTensorOp,
LinalgOp,
ReturnOp,
ExtractSliceOp,
@@ -2100,16 +2158,16 @@ static LogicalResult bufferizeFuncOpInternals(
return failure();
return success();
});
- // clang-format on
+ // clang-format on
- // Register post-walk erasure, if necessary.
- if (isa<CallOpInterface>(op))
- if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
- llvm::any_of(op->getResultTypes(), isaTensor))
- toErase.push_back(op);
+ // Register post-walk erasure, if necessary.
+ if (isa<CallOpInterface>(op))
+ if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+ llvm::any_of(op->getResultTypes(), isaTensor))
+ toErase.push_back(op);
- return result;
- });
+ return result;
+ });
LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
for (Operation *op : toErase)
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 7756587560ead..b71f6f92d51ed 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -58,3 +58,73 @@ func @bar(
// CHECK-NEXT: return
return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
}
+
+// -----
+
+// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func @init_and_dot(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: %[[C0:.*]] = constant 0{{.*}} : f32
+ %v0 = constant 0.0 : f32
+
+ // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
+ %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+
+ // CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
+ %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
+ outs(%d: tensor<f32>) -> tensor<f32>
+
+ // CHECK-NEXT: return
+ return %e : tensor<f32>
+}
+
+// CHECK: func @main()
+func @main() {
+ // CHECK-DAG: %[[C0:.*]] = constant 0{{.*}} : f32
+ // CHECK-DAG: %[[C1:.*]] = constant 1{{.*}} : f32
+ // CHECK-DAG: %[[C2:.*]] = constant 2{{.*}} : f32
+ %v0 = constant 0.0 : f32
+ %v1 = constant 1.0 : f32
+ %v2 = constant 2.0 : f32
+
+ // CHECK-NEXT: %[[A:.*]] = memref.alloc() : memref<64xf32>
+ // CHECK-NEXT: %[[B:.*]] = memref.alloc() : memref<64xf32>
+ // CHECK-NEXT: %[[C:.*]] = memref.alloc() : memref<f32>
+ %A = linalg.init_tensor [64] : tensor<64xf32>
+ %B = linalg.init_tensor [64] : tensor<64xf32>
+ %C = linalg.init_tensor [] : tensor<f32>
+
+ // CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
+ // CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
+ // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
+ %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
+ %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
+ %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+
+ // CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+ // CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+ // CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
+ // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
+ %res = call @init_and_dot(%AA, %BB, %CC) :
+ (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
+
+ // CHECK-NEXT: %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
+ %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
+
+ // CHECK-NEXT: call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
+
+ // CHECK-DAG: memref.dealloc %[[A]] : memref<64xf32>
+ // CHECK-DAG: memref.dealloc %[[B]] : memref<64xf32>
+ // CHECK-DAG: memref.dealloc %[[C]] : memref<f32>
+ // CHECK-NEXT: return
+ return
+}
+
+// CHECK: func private @print_memref_f32(memref<*xf32>)
+func private @print_memref_f32(tensor<*xf32>)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
new file mode 100644
index 0000000000000..7a4e134e498f8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -canonicalize -cse -linalg-comprehensive-module-bufferize |\
+// RUN: mlir-opt -convert-vector-to-scf -lower-affine -convert-linalg-to-loops |\
+// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext |\
+// RUN: FileCheck %s
+
+func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+ %v0 = constant 0.0 : f32
+
+ %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+
+ %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
+ outs(%d: tensor<f32>) -> tensor<f32>
+
+ return %e : tensor<f32>
+}
+
+func @main() {
+ %v0 = constant 0.0 : f32
+ %v1 = constant 1.0 : f32
+ %v2 = constant 2.0 : f32
+
+ %A = linalg.init_tensor [64] : tensor<64xf32>
+ %B = linalg.init_tensor [64] : tensor<64xf32>
+ %C = linalg.init_tensor [] : tensor<f32>
+ %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
+ %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
+ %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+
+ %res = call @init_and_dot(%AA, %BB, %CC) :
+ (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
+
+ %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
+
+// CHECK: Unranked Memref base@ = {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data =
+// CHECK-NEXT: [128]
+ call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
+
+ return
+}
+
+func private @print_memref_f32(tensor<*xf32>) attributes { llvm.emit_c_interface }
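A quick sanity check of the FileCheck expectation above: the test fills the two 64-element vectors with 1.0 and 2.0 respectively, so the dot product accumulated into the f32 output is 64 * (1.0 * 2.0) = 128, which is exactly the [128] value printed by print_memref_f32.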