[Mlir-commits] [mlir] d2608ad - [mlir][bufferize] Do not insert useless casts for newly allocated buffers

Matthias Springer llvmlistbot at llvm.org
Fri Apr 8 02:13:34 PDT 2022


Author: Matthias Springer
Date: 2022-04-08T18:12:02+09:00
New Revision: d2608adf490c10afc71d57141d61a9df5464fd82

URL: https://github.com/llvm/llvm-project/commit/d2608adf490c10afc71d57141d61a9df5464fd82
DIFF: https://github.com/llvm/llvm-project/commit/d2608adf490c10afc71d57141d61a9df5464fd82.diff

LOG: [mlir][bufferize] Do not insert useless casts for newly allocated buffers

Differential Revision: https://reviews.llvm.org/D123369
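
This removes casts like the one below, which createAlloc used to insert
unconditionally whenever the given shaped value had a memref type that
differed from the allocated type. A simplified before/after sketch,
mirroring the updated CHECK lines in one-shot-bufferize.mlir (the #map
layout and value names are illustrative, not verbatim from the tests):

  // Before: every new allocation was cast to the memref type of the
  // shaped value, even if no user needed that type.
  %alloc = memref.alloc(%dim) : memref<?xf32>
  %casted = memref.cast %alloc : memref<?xf32> to memref<?xf32, #map>
  memref.copy %A, %alloc : memref<?xf32, #map> to memref<?xf32>
  %r = bufferization.to_tensor %casted : memref<?xf32, #map>

  // After: users are rewired to the allocation directly.
  %alloc = memref.alloc(%dim) : memref<?xf32>
  memref.copy %A, %alloc : memref<?xf32, #map> to memref<?xf32>
  %r = bufferization.to_tensor %alloc : memref<?xf32>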

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8c68db885de63..091462f1ed73a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -466,7 +466,6 @@ FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
 
   // Compute allocation memref type.
   assert(shapedValue.getType().isa<ShapedType>());
-  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
   SmallVector<Value> dynShape;
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
@@ -485,17 +484,7 @@ FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
   }
 
   // Create the buffer allocation.
-  Value alloc =
-      createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
-
-  // Insert a cast if a different type was requested.
-  if (memRefType && memRefType != allocMemRefType) {
-    assert(memref::CastOp::areCastCompatible(allocMemRefType, memRefType) &&
-           "createAlloc: cast incompatible");
-    alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
-  }
-
-  return alloc;
+  return createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
 }
 
 /// Create a memory copy between two memref buffers.

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index fde24fd8b7ab4..7c6bbcd414d71 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -411,7 +411,22 @@ struct ForOpInterface
                                           *yieldedAlloc, state.getOptions());
           (void)copyStatus;
           assert(succeeded(copyStatus) && "could not create memcpy");
-          return *yieldedAlloc;
+
+          if (yieldedVal.getType() == yieldedAlloc->getType())
+            return *yieldedAlloc;
+
+          // The iter_arg memref type has a layout map. Cast the new buffer to
+          // the same type.
+          // TODO: In case the iter_arg has a layout map that is not the fully
+          // dynamic one, we cannot cast the new buffer. In that case, the
+          // iter_arg must be changed to the fully dynamic layout map. (And then
+          // the new buffer can be casted.)
+          assert(memref::CastOp::areCastCompatible(yieldedAlloc->getType(),
+                                                   yieldedVal.getType()) &&
+                 "scf.for op bufferization: cast incompatible");
+          Value casted = rewriter.create<memref::CastOp>(
+              val.getLoc(), yieldedVal.getType(), *yieldedAlloc);
+          return casted;
         });
     yieldOp.getResultsMutable().assign(yieldValues);
 

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
index 2fdb0a45e9d34..efc3038820ace 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
@@ -193,12 +193,11 @@ func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
   // CHECK-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
   %c0 = arith.constant 0 : index
   // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
-  // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
-  // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
+  // CHECK-TENSOR: %[[casted_alloc:.*]] = bufferization.to_tensor %[[alloc]]
   // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[alloc]]
   // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
   %0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
-  // CHECK-TENSOR: return %[[casted_tensor]]
+  // CHECK-TENSOR: return %[[casted_alloc]]
   return %0 : tensor<?xf32>
 }
 

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index ac2249da4282c..3d8d09460484a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -29,10 +29,9 @@ func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
   // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
   // CHECK: %[[dim:.*]] = tensor.dim %[[A]]
   // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: memref.copy %[[A_memref]], %[[alloc]]
   // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
-  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
   %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
 
   // CHECK: memref.dealloc %[[alloc]]

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index b8f9dcf0149b8..7ed7bd13321b1 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -308,16 +308,16 @@ func @insert_slice_fun_not_inplace(
 
 // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
-// CHECK-LABEL: func @scf_for_yield_only
-//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
+// CHECK-LABEL: func @scf_for_yield_only(
+//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>,
 //  CHECK-SAME:   %[[t:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
+//  CHECK-SAME:   ) -> memref<?xf32> {
 func @scf_for_yield_only(%A : tensor<?xf32> {linalg.inplaceable = false},
                          %B : tensor<?xf32> {linalg.inplaceable = true},
                          %lb : index, %ub : index, %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   //     CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
-  //     CHECK:   %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
   //     CHECK:   memref.copy %[[A]], %[[ALLOC_FOR_A]]
 
   // The first scf.for remains but just turns into dead code.
@@ -330,7 +330,7 @@ func @scf_for_yield_only(%A : tensor<?xf32> {linalg.inplaceable = false},
     scf.yield %t : tensor<?xf32>
   }
 
-  //     CHECK:   return %[[CASTED]] : memref<?xf32, #[[$map_1d_dyn]]>
+  //     CHECK:   return %[[ALLOC_FOR_A]] : memref<?xf32>
   // CHECK-NOT:   dealloc
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
@@ -373,7 +373,6 @@ func @scf_for_with_tensor.insert_slice(
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   //     CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
-  //     CHECK:   %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
   //     CHECK:   memref.copy %[[A]], %[[ALLOC_FOR_A]]
 
   //     CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
@@ -396,7 +395,7 @@ func @scf_for_with_tensor.insert_slice(
     scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
   }
 
-  //     CHECK:  return %[[CASTED]] : memref<?xf32, #[[$map_1d_dyn]]>
+  //     CHECK:  return %[[ALLOC_FOR_A]] : memref<?xf32>
   return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -418,7 +417,6 @@ func @execute_region_with_conflict(%t1 : tensor<?xf32> {linalg.inplaceable = "tr
   // memref.store is left over.
 
   // CHECK: %[[alloc:.*]] = memref.alloc
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: memref.copy %[[m1]], %[[alloc]]
   // CHECK: memref.store %{{.*}}, %[[alloc]][%{{.*}}]
   %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
@@ -426,6 +424,7 @@ func @execute_region_with_conflict(%t1 : tensor<?xf32> {linalg.inplaceable = "tr
     scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
   }
 
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: %[[load:.*]] = memref.load %[[m1]]
   %3 = tensor.extract %t1[%idx] : tensor<?xf32>
 
@@ -783,8 +782,8 @@ func @write_after_select_read_one(
   %idx = arith.constant 0 : index
 
   // CHECK: %[[alloc:.*]] = memref.alloc
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
-  // CHECK: memref.copy %[[t1]], %[[alloc]]
+  // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK-DAG: memref.copy %[[t1]], %[[alloc]]
   // CHECK: %[[select:.*]] = arith.select %{{.*}}, %[[casted]], %[[t2]]
   %s = arith.select %c, %t1, %t2 : tensor<?xf32>
 
@@ -859,13 +858,11 @@ func @scf_execute_region_yield_non_equivalent(%i: index, %j: index) -> f32 {
 // CHECK-LABEL: func @scf_for_yield_non_equivalent(
 //  CHECK-SAME:     %[[t:.*]]: memref<?xf32
 //       CHECK:   %[[alloc:.*]] = memref.alloc(%{{.*}})
-//       CHECK:   %[[casted:.*]] = memref.cast %[[alloc]]
-//       CHECK:   %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[casted]])
+//       CHECK:   %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[alloc]])
 //       CHECK:     memref.dealloc %[[iter]]
 //       CHECK:     %[[alloc2:.*]] = memref.alloc(%{{.*}})
 //       CHECK:     memref.copy %[[t]], %[[alloc2]]
-//       CHECK:     %[[casted2:.*]] = memref.cast %[[alloc2]]
-//       CHECK:     scf.yield %[[casted2]]
+//       CHECK:     scf.yield %[[alloc2]]
 //       CHECK:   return %[[for]]
 func @scf_for_yield_non_equivalent(
     %t: tensor<?xf32>, %lb : index, %ub : index, %step : index) -> tensor<?xf32> {

diff --git a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
index 72661713a5b44..24fe44ca58a34 100644
--- a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
@@ -358,7 +358,7 @@ func @bar(
   // %r0#0 requires a copy because we have no idea what the function is doing.
 //  CHECK-DAG:   %[[alloc:.*]] = memref.alloc
 //  CHECK-DAG:   %[[casted:.*]] = memref.cast %[[alloc]]
-//      CHECK:   memref.copy %[[B]], %[[alloc]]
+//  CHECK-DAG:   memref.copy %[[B]], %[[alloc]]
 // CHECK-NEXT:   call @some_external_func(%[[casted]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
   call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
 
@@ -475,15 +475,15 @@ func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] ->
 // conflict. However, inside `entry`, the writes do cause a conflict because
 // %A, %B and %C are not inplaceable. This test case shows that this kind of
 // conflict detection has a "transitive" nature.
-//      CHECK: %[[ALLOC_C:.*]] = memref.alloc
-//      CHECK: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
-//      CHECK: %[[ALLOC_B:.*]] = memref.alloc
-//      CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
-//      CHECK: %[[ALLOC_A:.*]] = memref.alloc
-//      CHECK: memref.copy %[[A]], %[[ALLOC_A]]
-//      CHECK: memref.copy %[[B]], %[[ALLOC_B]]
-//      CHECK: memref.copy %[[C]], %[[ALLOC_C]]
-//      CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+//  CHECK-DAG: %[[ALLOC_C:.*]] = memref.alloc
+//  CHECK-DAG: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
+//  CHECK-DAG: %[[ALLOC_B:.*]] = memref.alloc
+//  CHECK-DAG: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
+//  CHECK-DAG: %[[ALLOC_A:.*]] = memref.alloc
+//  CHECK-DAG: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+//  CHECK-DAG: memref.copy %[[A]], %[[ALLOC_A]]
+//  CHECK-DAG: memref.copy %[[B]], %[[ALLOC_B]]
+//  CHECK-DAG: memref.copy %[[C]], %[[ALLOC_C]]
 // CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]])
   call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
   return
@@ -539,8 +539,8 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
   // CHECK: scf.for {{.*}} {
   %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
     // CHECK: %[[alloc:.*]] = memref.alloc
-    // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
-    // CHECK: memref.copy %[[arg0]], %[[alloc]]
+    // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]]
+    // CHECK-DAG: memref.copy %[[arg0]], %[[alloc]]
     // CHECK: call @inner_func_2(%[[casted]])
     // CHECK: memref.dealloc %[[alloc]]
     // CHECK-NOT: scf.yield

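A note on the scf.for change above: a memref.cast may still be created,
but only when the yielded allocation's type actually differs from the
iter_arg type, typically because the iter_arg memref carries a layout
map. A hedged sketch of that remaining case (layout map and value names
illustrative, not verbatim from the tests):

  #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
  // The new buffer has an identity layout, the iter_arg does not, so
  // the buffer is cast to the iter_arg type before being yielded.
  %alloc = memref.alloc(%sz) : memref<?xf32>
  %casted = memref.cast %alloc : memref<?xf32> to memref<?xf32, #map>
  scf.yield %casted : memref<?xf32, #map>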

        

