[Mlir-commits] [mlir] fa7c8cb - [mlir][bufferize] Support memrefs with non-standard layout in `finalizing-bufferize`

Matthias Springer llvmlistbot at llvm.org
Fri Feb 18 02:34:16 PST 2022


Author: Matthias Springer
Date: 2022-02-18T19:34:04+09:00
New Revision: fa7c8cb4d01e9f24816c43d5c44a7fb62564ebc5

URL: https://github.com/llvm/llvm-project/commit/fa7c8cb4d01e9f24816c43d5c44a7fb62564ebc5
DIFF: https://github.com/llvm/llvm-project/commit/fa7c8cb4d01e9f24816c43d5c44a7fb62564ebc5.diff

LOG: [mlir][bufferize] Support memrefs with non-standard layout in `finalizing-bufferize`

Differential Revision: https://reviews.llvm.org/D119935
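
To illustrate the new behavior (mirroring the `dyn_layout_to_no_layout_cast` test added
below), a to_tensor/to_memref pair whose source memref has a non-identity layout is now
bufferized to an allocation plus copy instead of an invalid layout-dropping cast:

    #map1 = affine_map<(d0)[s0] -> (d0 + s0)>
    func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, #map1>) -> memref<?xf32> {
      %0 = bufferization.to_tensor %m : memref<?xf32, #map1>
      %1 = bufferization.to_memref %0 : memref<?xf32>
      return %1 : memref<?xf32>
    }

    // After finalizing-bufferize (paraphrased from the CHECK lines of the test):
    //   %c0    = arith.constant 0 : index
    //   %dim   = memref.dim %m, %c0 : memref<?xf32, #map1>
    //   %alloc = memref.alloc(%dim) : memref<?xf32>
    //   memref.copy %m, %alloc
    //   return %alloc : memref<?xf32>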

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 2cbfc901f239..d3e6a5c7f5e3 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -27,4 +27,26 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.h.inc"
 
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace bufferization {
+/// Try to cast the given ranked MemRef-typed value to the given ranked MemRef
+/// type. Insert a reallocation + copy if it cannot be statically guaranteed
+/// that a direct cast would be valid.
+///
+/// E.g., when casting from a ranked MemRef type with dynamic layout to a ranked
+/// MemRef type with static layout, it is not statically known whether the cast
+/// will succeed or not. Such `memref.cast` ops may fail at runtime. This
+/// function never generates such casts and conservatively inserts a copy.
+///
+/// This function returns `failure()` in case of unsupported casts. E.g., casts
+/// with differing element types or memory spaces.
+FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
+                                          MemRefType type);
+} // namespace bufferization
+} // namespace mlir
+
 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATION_H_
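
A sketch of the two directions the helper distinguishes, using hypothetical IR that is
not part of this patch: relaxing a static offset/stride to a dynamic one is always valid
and becomes a plain cast, while the opposite direction cannot be verified statically and
is handled with an allocation plus copy.

    #strided = affine_map<(d0)[s0] -> (d0 + s0)>

    // Static -> dynamic offset: guaranteed to succeed, so a memref.cast is emitted.
    %cast = memref.cast %a : memref<?xf32> to memref<?xf32, #strided>

    // Dynamic -> static offset (memref<?xf32, #strided> to memref<?xf32>): not
    // statically verifiable, so the helper emits memref.alloc + memref.copy
    // instead of a cast that could fail at runtime.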

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index f1ec7bbdead2..c5a99d820bc9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -1,4 +1,3 @@
-
 //===----------------------------------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -13,6 +12,73 @@
 using namespace mlir;
 using namespace mlir::bufferization;
 
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+FailureOr<Value>
+mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
+                                              MemRefType destType) {
+  auto srcType = value.getType().cast<MemRefType>();
+
+  // Casting to the same type, nothing to do.
+  if (srcType == destType)
+    return value;
+
+  // Element type, rank and memory space must match.
+  if (srcType.getElementType() != destType.getElementType())
+    return failure();
+  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
+    return failure();
+  if (srcType.getRank() != destType.getRank())
+    return failure();
+
+  // In case the affine maps are different, we may need to use a copy if we go
+  // from dynamic to static offset or stride (the canonicalization cannot know
+  // at this point that it is really cast compatible).
+  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
+    int64_t sourceOffset, targetOffset;
+    SmallVector<int64_t, 4> sourceStrides, targetStrides;
+    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
+        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+      return false;
+    auto dynamicToStatic = [](int64_t a, int64_t b) {
+      return a == MemRefType::getDynamicStrideOrOffset() &&
+             b != MemRefType::getDynamicStrideOrOffset();
+    };
+    if (dynamicToStatic(sourceOffset, targetOffset))
+      return false;
+    for (auto it : zip(sourceStrides, targetStrides))
+      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
+        return false;
+    return true;
+  };
+
+  // Note: If `areCastCompatible` returns true, a cast is valid but may still
+  // fail at runtime. To ensure that we only generate casts that always succeed
+  // at runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
+  if (memref::CastOp::areCastCompatible(srcType, destType) &&
+      isGuaranteedCastCompatible(srcType, destType)) {
+    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
+    return casted;
+  }
+
+  auto loc = value.getLoc();
+  SmallVector<Value, 4> dynamicOperands;
+  for (int i = 0; i < destType.getRank(); ++i) {
+    if (destType.getShape()[i] != ShapedType::kDynamicSize)
+      continue;
+    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
+    Value size = b.create<memref::DimOp>(loc, value, index);
+    dynamicOperands.push_back(size);
+  }
+  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
+  // BufferizableOpInterface impl of ToMemrefOp.
+  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
+  b.create<memref::CopyOp>(loc, value, copy);
+  return copy;
+}
+
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//
@@ -191,67 +257,39 @@ static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
   if (!memrefToTensor)
     return failure();
 
-  // A memref_to_tensor + tensor_to_memref with same types can be folded without
-  // inserting a cast.
-  if (memrefToTensor.memref().getType() == toMemref.getType()) {
-    if (!allowSameType)
-      // Function can be configured to only handle cases where a cast is needed.
+  Type srcType = memrefToTensor.memref().getType();
+  Type destType = toMemref.getType();
+
+  // Function can be configured to only handle cases where a cast is needed.
+  if (!allowSameType && srcType == destType)
+    return failure();
+
+  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
+  auto rankedDestType = destType.dyn_cast<MemRefType>();
+  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
+
+  // Ranked memref -> Ranked memref cast.
+  if (rankedSrcType && rankedDestType) {
+    FailureOr<Value> replacement = castOrReallocMemRefValue(
+        rewriter, memrefToTensor.memref(), rankedDestType);
+    if (failed(replacement))
       return failure();
-    rewriter.replaceOp(toMemref, memrefToTensor.memref());
+
+    rewriter.replaceOp(toMemref, *replacement);
     return success();
   }
 
-  // If types are definitely not cast-compatible, bail.
-  if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(),
-                                         toMemref.getType()))
+  // Unranked memref -> Ranked memref cast: May require a copy.
+  // TODO: Not implemented at the moment.
+  if (unrankedSrcType && rankedDestType)
     return failure();
 
-  // We already know that the types are potentially cast-compatible. However
-  // in case the affine maps are different, we may need to use a copy if we go
-  // from dynamic to static offset or stride (the canonicalization cannot know
-  // at this point that it is really cast compatible).
-  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
-    int64_t sourceOffset, targetOffset;
-    SmallVector<int64_t, 4> sourceStrides, targetStrides;
-    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
-        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
-      return false;
-    auto dynamicToStatic = [](int64_t a, int64_t b) {
-      return a == MemRefType::getDynamicStrideOrOffset() &&
-             b != MemRefType::getDynamicStrideOrOffset();
-    };
-    if (dynamicToStatic(sourceOffset, targetOffset))
-      return false;
-    for (auto it : zip(sourceStrides, targetStrides))
-      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
-        return false;
-    return true;
-  };
-
-  auto memrefToTensorType =
-      memrefToTensor.memref().getType().dyn_cast<MemRefType>();
-  auto toMemrefType = toMemref.getType().dyn_cast<MemRefType>();
-  if (memrefToTensorType && toMemrefType &&
-      !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) {
-    MemRefType resultType = toMemrefType;
-    auto loc = toMemref.getLoc();
-    SmallVector<Value, 4> dynamicOperands;
-    for (int i = 0; i < resultType.getRank(); ++i) {
-      if (resultType.getShape()[i] != ShapedType::kDynamicSize)
-        continue;
-      auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
-      Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
-      dynamicOperands.push_back(size);
-    }
-    // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
-    // BufferizableOpInterface impl of ToMemrefOp.
-    auto copy =
-        rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
-    rewriter.create<memref::CopyOp>(loc, memrefToTensor.memref(), copy);
-    rewriter.replaceOp(toMemref, {copy});
-  } else
-    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
-                                                memrefToTensor.memref());
+  // Unranked memref -> unranked memref cast
+  // Ranked memref -> unranked memref cast: No copy needed.
+  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
+         "expected that types are cast compatible");
+  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
+                                              memrefToTensor.memref());
   return success();
 }
 

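In the restructured fold, the ranked -> ranked branch can also reduce to a single cast
when the cast is statically guaranteed to succeed, e.g. going from an identity layout to
a dynamic-offset layout. A hypothetical sketch (note that finalizing-bufferize itself
does not accept to_memref ops with layout maps yet, see the TODO'd
`no_layout_to_dyn_layout_cast` test below):

    #strided = affine_map<(d0)[s0] -> (d0 + s0)>
    %0 = bufferization.to_tensor %m : memref<?xf32>
    %1 = bufferization.to_memref %0 : memref<?xf32, #strided>

    // foldToMemrefToTensorPair replaces %1 with:
    //   %1 = memref.cast %m : memref<?xf32> to memref<?xf32, #strided>
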
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7468da6132f..01b22264e5ba 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -45,9 +45,28 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   addSourceMaterialization(materializeToTensor);
   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                               ValueRange inputs, Location loc) -> Value {
-    assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<TensorType>());
-    return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
+    assert(inputs.size() == 1 && "expected exactly one input");
+
+    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
+      // MemRef to MemRef cast.
+      assert(inputType != type && "expected different types");
+      // Unranked to ranked and ranked to unranked casts must be explicit.
+      auto rankedDestType = type.dyn_cast<MemRefType>();
+      if (!rankedDestType)
+        return nullptr;
+      FailureOr<Value> replacement =
+          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
+      if (failed(replacement))
+        return nullptr;
+      return *replacement;
+    }
+
+    if (inputs[0].getType().isa<TensorType>()) {
+      // Tensor to MemRef cast.
+      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
+    }
+
+    llvm_unreachable("only tensor/memref input types supported");
   });
 }
 

diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index fac685ae7e72..5d70e90b7540 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -26,3 +26,77 @@ func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
   "test.sink"(%0) : (tensor<f32>) -> ()
   return
 }
+
+// -----
+
+//       CHECK: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-LABEL: func @dyn_layout_to_no_layout_cast(
+//  CHECK-SAME:     %[[arg:.*]]: memref<?xf32, #[[$map1]]>)
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+//       CHECK:   memref.copy %[[arg]], %[[alloc]]
+//       CHECK:   return %[[alloc]]
+#map1 = affine_map<(d0)[s0] -> (d0 + s0)>
+func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, #map1>) -> memref<?xf32> {
+  %0 = bufferization.to_tensor %m : memref<?xf32, #map1>
+  %1 = bufferization.to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// -----
+
+//       CHECK: #[[$map2:.*]] = affine_map<(d0)[s0] -> (d0 * 100 + s0)>
+// CHECK-LABEL: func @fancy_layout_to_no_layout_cast(
+//  CHECK-SAME:     %[[arg:.*]]: memref<?xf32, #[[$map2]]>)
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+//       CHECK:   memref.copy %[[arg]], %[[alloc]]
+//       CHECK:   return %[[alloc]]
+#map2 = affine_map<(d0)[s0] -> (d0 * 100 + s0)>
+func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, #map2>) -> memref<?xf32> {
+  %0 = bufferization.to_tensor %m : memref<?xf32, #map2>
+  %1 = bufferization.to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// -----
+
+//       CHECK: #[[$map3:.*]] = affine_map<(d0)[s0] -> (d0 + 25)>
+// CHECK-LABEL: func @static_layout_to_no_layout_cast(
+//  CHECK-SAME:     %[[arg:.*]]: memref<?xf32, #[[$map3]]>)
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+//       CHECK:   memref.copy %[[arg]], %[[alloc]]
+//       CHECK:   return %[[alloc]]
+#map3 = affine_map<(d0)[s0] -> (d0 + 25)>
+func @static_layout_to_no_layout_cast(%m: memref<?xf32, #map3>) -> memref<?xf32> {
+  %0 = bufferization.to_tensor %m : memref<?xf32, #map3>
+  %1 = bufferization.to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// -----
+
+// TODO: to_memref with layout maps not supported yet. This should fold to a
+// memref.cast.
+#map4 = affine_map<(d0)[s0] -> (d0 + s0)>
+func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, #map4> {
+  %0 = bufferization.to_tensor %m : memref<?xf32>
+  // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
+  %1 = bufferization.to_memref %0 : memref<?xf32, #map4>
+  // expected-note @+1 {{see existing live user here}}
+  return %1 : memref<?xf32, #map4>
+}
+
+// -----
+
+func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
+  // expected-note @+1 {{prior use here}}
+  %0 = bufferization.to_tensor %m : memref<*xf32>
+  // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
+  %1 = bufferization.to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}

