[Mlir-commits] [mlir] fa7c8cb - [mlir][bufferize] Support memrefs with non-standard layout in `finalizing-bufferize`
Matthias Springer
llvmlistbot at llvm.org
Fri Feb 18 02:34:16 PST 2022
Author: Matthias Springer
Date: 2022-02-18T19:34:04+09:00
New Revision: fa7c8cb4d01e9f24816c43d5c44a7fb62564ebc5
URL: https://github.com/llvm/llvm-project/commit/fa7c8cb4d01e9f24816c43d5c44a7fb62564ebc5
DIFF: https://github.com/llvm/llvm-project/commit/fa7c8cb4d01e9f24816c43d5c44a7fb62564ebc5.diff
LOG: [mlir][bufferize] Support memrefs with non-standard layout in `finalizing-bufferize`
Differential Revision: https://reviews.llvm.org/D119935
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 2cbfc901f239..d3e6a5c7f5e3 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -27,4 +27,26 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.h.inc"
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace bufferization {
+/// Try to cast the given ranked MemRef-typed value to the given ranked MemRef
+/// type. Insert a reallocation + copy if it cannot be statically guaranteed
+/// that a direct cast would be valid.
+///
+/// E.g., when casting from a ranked MemRef type with dynamic layout to a ranked
+/// MemRef type with static layout, it is not statically known whether the cast
+/// will succeed or not. Such `memref.cast` ops may fail at runtime. This
+/// function never generates such casts and conservatively inserts a copy.
+///
+/// This function returns `failure()` in case of unsupported casts. E.g., casts
+/// with differing element types or memory spaces.
+FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
+ MemRefType type);
+} // namespace bufferization
+} // namespace mlir
+
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATION_H_
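For illustration, a minimal sketch (not taken from the patch; the function name and
the #dyn_offset layout map are made up) of the "direct cast" outcome described in the
comment above: casting from a static offset/stride to a dynamic one can never fail at
runtime, so the helper may emit a plain memref.cast. The opposite direction, which
requires the alloc + copy fallback, is exercised by the new test cases further down.

#dyn_offset = affine_map<(d0)[s0] -> (d0 + s0)>

func @guaranteed_cast(%m: memref<?xf32>) -> memref<?xf32, #dyn_offset> {
  // Identity layout (offset 0, stride 1) to dynamic offset: always valid.
  %0 = memref.cast %m : memref<?xf32> to memref<?xf32, #dyn_offset>
  return %0 : memref<?xf32, #dyn_offset>
}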
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index f1ec7bbdead2..c5a99d820bc9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -1,4 +1,3 @@
-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -13,6 +12,73 @@
using namespace mlir;
using namespace mlir::bufferization;
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+FailureOr<Value>
+mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
+ MemRefType destType) {
+ auto srcType = value.getType().cast<MemRefType>();
+
+ // Casting to the same type, nothing to do.
+ if (srcType == destType)
+ return value;
+
+ // Element type, rank and memory space must match.
+ if (srcType.getElementType() != destType.getElementType())
+ return failure();
+ if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
+ return failure();
+ if (srcType.getRank() != destType.getRank())
+ return failure();
+
+ // In case the affine maps are different, we may need to use a copy if we go
+ // from dynamic to static offset or stride (the canonicalization cannot know
+ // at this point that it is really cast compatible).
+ auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
+ int64_t sourceOffset, targetOffset;
+ SmallVector<int64_t, 4> sourceStrides, targetStrides;
+ if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
+ failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+ return false;
+ auto dynamicToStatic = [](int64_t a, int64_t b) {
+ return a == MemRefType::getDynamicStrideOrOffset() &&
+ b != MemRefType::getDynamicStrideOrOffset();
+ };
+ if (dynamicToStatic(sourceOffset, targetOffset))
+ return false;
+ for (auto it : zip(sourceStrides, targetStrides))
+ if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
+ return false;
+ return true;
+ };
+
+ // Note: If `areCastCompatible` returns true, a cast is valid but may still
+ // fail at runtime. To ensure that we only generate casts that always succeed
+ // at runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
+ if (memref::CastOp::areCastCompatible(srcType, destType) &&
+ isGuaranteedCastCompatible(srcType, destType)) {
+ Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
+ return casted;
+ }
+
+ auto loc = value.getLoc();
+ SmallVector<Value, 4> dynamicOperands;
+ for (int i = 0; i < destType.getRank(); ++i) {
+ if (destType.getShape()[i] != ShapedType::kDynamicSize)
+ continue;
+ auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
+ Value size = b.create<memref::DimOp>(loc, value, index);
+ dynamicOperands.push_back(size);
+ }
+ // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
+ // BufferizableOpInterface impl of ToMemrefOp.
+ Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
+ b.create<memref::CopyOp>(loc, value, copy);
+ return copy;
+}
+
//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//
@@ -191,67 +257,39 @@ static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
if (!memrefToTensor)
return failure();
- // A memref_to_tensor + tensor_to_memref with same types can be folded without
- // inserting a cast.
- if (memrefToTensor.memref().getType() == toMemref.getType()) {
- if (!allowSameType)
- // Function can be configured to only handle cases where a cast is needed.
+ Type srcType = memrefToTensor.memref().getType();
+ Type destType = toMemref.getType();
+
+ // Function can be configured to only handle cases where a cast is needed.
+ if (!allowSameType && srcType == destType)
+ return failure();
+
+ auto rankedSrcType = srcType.dyn_cast<MemRefType>();
+ auto rankedDestType = destType.dyn_cast<MemRefType>();
+ auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
+
+ // Ranked memref -> Ranked memref cast.
+ if (rankedSrcType && rankedDestType) {
+ FailureOr<Value> replacement = castOrReallocMemRefValue(
+ rewriter, memrefToTensor.memref(), rankedDestType);
+ if (failed(replacement))
return failure();
- rewriter.replaceOp(toMemref, memrefToTensor.memref());
+
+ rewriter.replaceOp(toMemref, *replacement);
return success();
}
- // If types are definitely not cast-compatible, bail.
- if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(),
- toMemref.getType()))
+ // Unranked memref -> Ranked memref cast: May require a copy.
+ // TODO: Not implemented at the moment.
+ if (unrankedSrcType && rankedDestType)
return failure();
- // We already know that the types are potentially cast-compatible. However
- // in case the affine maps are different, we may need to use a copy if we go
- // from dynamic to static offset or stride (the canonicalization cannot know
- // at this point that it is really cast compatible).
- auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
- int64_t sourceOffset, targetOffset;
- SmallVector<int64_t, 4> sourceStrides, targetStrides;
- if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
- failed(getStridesAndOffset(target, targetStrides, targetOffset)))
- return false;
- auto dynamicToStatic = [](int64_t a, int64_t b) {
- return a == MemRefType::getDynamicStrideOrOffset() &&
- b != MemRefType::getDynamicStrideOrOffset();
- };
- if (dynamicToStatic(sourceOffset, targetOffset))
- return false;
- for (auto it : zip(sourceStrides, targetStrides))
- if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
- return false;
- return true;
- };
-
- auto memrefToTensorType =
- memrefToTensor.memref().getType().dyn_cast<MemRefType>();
- auto toMemrefType = toMemref.getType().dyn_cast<MemRefType>();
- if (memrefToTensorType && toMemrefType &&
- !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) {
- MemRefType resultType = toMemrefType;
- auto loc = toMemref.getLoc();
- SmallVector<Value, 4> dynamicOperands;
- for (int i = 0; i < resultType.getRank(); ++i) {
- if (resultType.getShape()[i] != ShapedType::kDynamicSize)
- continue;
- auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
- Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
- dynamicOperands.push_back(size);
- }
- // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
- // BufferizableOpInterface impl of ToMemrefOp.
- auto copy =
- rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
- rewriter.create<memref::CopyOp>(loc, memrefToTensor.memref(), copy);
- rewriter.replaceOp(toMemref, {copy});
- } else
- rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
- memrefToTensor.memref());
+ // Unranked memref -> unranked memref cast
+ // Ranked memref -> unranked memref cast: No copy needed.
+ assert(memref::CastOp::areCastCompatible(srcType, destType) &&
+ "expected that types are cast compatible");
+ rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
+ memrefToTensor.memref());
return success();
}
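To summarize the branch logic of `castOrReallocMemRefValue` as used above, here are a
few illustrative 1-D source/destination type pairs and the expected action (the #dyn
layout map is an assumption for illustration; it has a dynamic offset and stride 1):

// #dyn = affine_map<(d0)[s0] -> (d0 + s0)>
//
//   memref<?xf32>       to memref<?xf32, #dyn> : memref.cast  (static offset -> dynamic offset)
//   memref<?xf32, #dyn> to memref<?xf32>       : alloc + copy (dynamic offset -> static offset)
//   memref<?xf32>       to memref<?xi32>       : failure()    (element types differ)
//   memref<?xf32>       to memref<?x?xf32>     : failure()    (ranks differ)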
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7468da6132f..01b22264e5ba 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -45,9 +45,28 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
addSourceMaterialization(materializeToTensor);
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
ValueRange inputs, Location loc) -> Value {
- assert(inputs.size() == 1);
- assert(inputs[0].getType().isa<TensorType>());
- return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
+ assert(inputs.size() == 1 && "expected exactly one input");
+
+ if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
+ // MemRef to MemRef cast.
+ assert(inputType != type && "expected different types");
+ // Unranked to ranked and ranked to unranked casts must be explicit.
+ auto rankedDestType = type.dyn_cast<MemRefType>();
+ if (!rankedDestType)
+ return nullptr;
+ FailureOr<Value> replacement =
+ castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
+ if (failed(replacement))
+ return nullptr;
+ return *replacement;
+ }
+
+ if (inputs[0].getType().isa<TensorType>()) {
+ // Tensor to MemRef cast.
+ return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
+ }
+
+ llvm_unreachable("only tensor/memref input types supported");
});
}
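Concretely, the new memref-to-memref target materialization (together with the fold
above) is what lets `finalizing-bufferize` handle the non-identity layouts in the tests
below. Reconstructed from the CHECK lines of the first new test case (the types on
memref.dim/memref.copy are implied by the operand types), the bufferized output looks
roughly like:

#map1 = affine_map<(d0)[s0] -> (d0 + s0)>

func @dyn_layout_to_no_layout_cast(%arg0: memref<?xf32, #map1>) -> memref<?xf32> {
  %c0 = arith.constant 0 : index
  %dim = memref.dim %arg0, %c0 : memref<?xf32, #map1>
  %alloc = memref.alloc(%dim) : memref<?xf32>
  memref.copy %arg0, %alloc : memref<?xf32, #map1> to memref<?xf32>
  return %alloc : memref<?xf32>
}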
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index fac685ae7e72..5d70e90b7540 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -26,3 +26,77 @@ func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
"test.sink"(%0) : (tensor<f32>) -> ()
return
}
+
+// -----
+
+// CHECK: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-LABEL: func @dyn_layout_to_no_layout_cast(
+// CHECK-SAME: %[[arg:.*]]: memref<?xf32, #[[$map1]]>)
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+// CHECK: memref.copy %[[arg]], %[[alloc]]
+// CHECK: return %[[alloc]]
+#map1 = affine_map<(d0)[s0] -> (d0 + s0)>
+func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, #map1>) -> memref<?xf32> {
+ %0 = bufferization.to_tensor %m : memref<?xf32, #map1>
+ %1 = bufferization.to_memref %0 : memref<?xf32>
+ return %1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK: #[[$map2:.*]] = affine_map<(d0)[s0] -> (d0 * 100 + s0)>
+// CHECK-LABEL: func @fancy_layout_to_no_layout_cast(
+// CHECK-SAME: %[[arg:.*]]: memref<?xf32, #[[$map2]]>)
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+// CHECK: memref.copy %[[arg]], %[[alloc]]
+// CHECK: return %[[alloc]]
+#map2 = affine_map<(d0)[s0] -> (d0 * 100 + s0)>
+func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, #map2>) -> memref<?xf32> {
+ %0 = bufferization.to_tensor %m : memref<?xf32, #map2>
+ %1 = bufferization.to_memref %0 : memref<?xf32>
+ return %1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK: #[[$map3:.*]] = affine_map<(d0)[s0] -> (d0 + 25)>
+// CHECK-LABEL: func @static_layout_to_no_layout_cast(
+// CHECK-SAME: %[[arg:.*]]: memref<?xf32, #[[$map3]]>)
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+// CHECK: memref.copy %[[arg]], %[[alloc]]
+// CHECK: return %[[alloc]]
+#map3 = affine_map<(d0)[s0] -> (d0 + 25)>
+func @static_layout_to_no_layout_cast(%m: memref<?xf32, #map3>) -> memref<?xf32> {
+ %0 = bufferization.to_tensor %m : memref<?xf32, #map3>
+ %1 = bufferization.to_memref %0 : memref<?xf32>
+ return %1 : memref<?xf32>
+}
+
+// -----
+
+// TODO: to_memref with layout maps not supported yet. This should fold to a
+// memref.cast.
+#map4 = affine_map<(d0)[s0] -> (d0 + s0)>
+func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, #map4> {
+ %0 = bufferization.to_tensor %m : memref<?xf32>
+ // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
+ %1 = bufferization.to_memref %0 : memref<?xf32, #map4>
+ // expected-note @+1 {{see existing live user here}}
+ return %1 : memref<?xf32, #map4>
+}
+
+// -----
+
+func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
+ // expected-note @+1 {{prior use here}}
+ %0 = bufferization.to_tensor %m : memref<*xf32>
+ // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
+ %1 = bufferization.to_memref %0 : memref<?xf32>
+ return %1 : memref<?xf32>
+}