[Mlir-commits] [mlir] d0ee094 - [mlir][Bufferize] Fix incorrect bufferization of rank-reducing tensor ops.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jan 10 07:14:59 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-10T10:14:55-05:00
New Revision: d0ee094d6acf72608e927bf2e9ba69c57da59a96

URL: https://github.com/llvm/llvm-project/commit/d0ee094d6acf72608e927bf2e9ba69c57da59a96
DIFF: https://github.com/llvm/llvm-project/commit/d0ee094d6acf72608e927bf2e9ba69c57da59a96.diff

LOG: [mlir][Bufferize] Fix incorrect bufferization of rank-reducing tensor ops.

This revision fixes the construction of SubViewOp, InsertSliceOp and ExtractSliceOp during bufferization
in the rank-reducing case, where not all offset/size/stride operands were properly specified.

A test is introduced that exhibited problematic behavior related to incorrect memref casts.
InitTensorOp elimination is disabled in the test function bufferize pass.

Differential Revision: https://reviews.llvm.org/D116899

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/ViewLikeInterface.td
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index 50ebeaa44a5c3..b829760ba4591 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -483,6 +483,19 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
         ::mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()};
       return names;
    }
+   /// Assume target is a shaped type and offsets/sizes/strides are vectors of
+   /// the same length and no longer than target's rank.
+   /// Complete missing dims `i` with offset=0, size=dim(target, i), stride=1
+   /// until all vectors have size rank. The completion occurs for the most
+   /// minor dimensions (i.e. fastest varying).
+   /// Take a `createDim` lambda that knows how to build the size of a
+   /// particular dimension of `target` (to avoid dialect dependencies).
+   static void expandToRank(
+     Value target,
+     SmallVector<OpFoldResult> &offsets,
+     SmallVector<OpFoldResult> &sizes,
+     SmallVector<OpFoldResult> &strides,
+     llvm::function_ref<OpFoldResult(Value, int64_t)> createDim);
   }];
 
   let verify = [{

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index e64d5ae3dda61..2d4f4b2a6500c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -347,6 +347,14 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
   });
 }
 
+// bufferization.to_memref is not allowed to change the rank.
+static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
+  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
+  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
+                                   rankedTensorType.getRank()) &&
+         "to_memref would be invalid: mismatching ranks");
+}
+
 static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
   assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
 
@@ -364,6 +372,7 @@ static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
     memrefType = getUnrankedMemRefType(
         tensor.getType().cast<TensorType>().getElementType());
   }
+  ensureToMemrefOpIsValid(tensor, memrefType);
   return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                     tensor);
 }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index bf08155076b22..64bc6920da07c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -563,10 +563,26 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
       },
       /*rewriteFunc=*/
       [](OpBuilder &b, Location loc, OpOperand &operand) {
-        auto insertSliceOp = cast<tensor::InsertSliceOp>(operand.getOwner());
+        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
+        // Expand offsets, sizes and strides to the full rank to handle the
+        // rank-reducing case.
+        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
+        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
+        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
+        OffsetSizeAndStrideOpInterface::expandToRank(
+            insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
+            [&](Value target, int64_t dim) -> OpFoldResult {
+              auto shapedType = target.getType().cast<ShapedType>();
+              if (shapedType.isDynamicDim(dim))
+                return b.create<tensor::DimOp>(loc, target, dim).result();
+              return b.getIndexAttr(shapedType.getDimSize(dim));
+            });
+        auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
+            insertOp.getSourceType().getRank(),
+            insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
+            mixedSizes, mixedStrides);
         auto extractOp = b.create<tensor::ExtractSliceOp>(
-            loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
-            insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+            loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
         return extractOp.result();
       },
       newOps);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 01308088bab8e..2eaed56669d7e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -19,6 +19,14 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace scf_ext {
 
+// bufferization.to_memref is not allowed to change the rank.
+static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
+  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
+  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
+                                rankedTensorType.getRank())) &&
+         "to_memref would be invalid: mismatching ranks");
+}
+
 /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
 /// fully implemented at the moment.
 struct ExecuteRegionOpInterface
@@ -159,6 +167,8 @@ struct IfOpInterface
     SmallVector<Value> thenYieldValues;
     for (OpOperand &operand : thenYieldOp->getOpOperands()) {
       if (operand.get().getType().isa<TensorType>()) {
+        ensureToMemrefOpIsValid(operand.get(),
+                                newTypes[operand.getOperandNumber()]);
         Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
             operand.get().getLoc(), newTypes[operand.getOperandNumber()],
             operand.get());
@@ -172,6 +182,8 @@ struct IfOpInterface
     SmallVector<Value> elseYieldValues;
     for (OpOperand &operand : elseYieldOp->getOpOperands()) {
       if (operand.get().getType().isa<TensorType>()) {
+        ensureToMemrefOpIsValid(operand.get(),
+                                newTypes[operand.getOperandNumber()]);
         Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
             operand.get().getLoc(), newTypes[operand.getOperandNumber()],
             operand.get());
@@ -317,6 +329,7 @@ struct ForOpInterface
     rewriter.setInsertionPoint(yieldOp);
     SmallVector<Value> yieldValues =
         convert(yieldOp.getResults(), [&](Value val, int64_t index) {
+          ensureToMemrefOpIsValid(val, initArgs[index].getType());
           return rewriter.create<bufferization::ToMemrefOp>(
               val.getLoc(), initArgs[index].getType(), val);
         });

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index f0f20b433937e..620328799712a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -68,7 +68,7 @@ struct CastOpInterface
 
     // Compute the new memref type.
     Type resultMemRefType;
-    if (auto rankedTensorType = resultTensorType.isa<RankedTensorType>()) {
+    if (resultTensorType.isa<RankedTensorType>()) {
       resultMemRefType =
           getContiguousMemRefType(resultTensorType, layout, memorySpace);
     } else {
@@ -165,16 +165,27 @@ struct ExtractSliceOpInterface
       alloc = *allocOrFailure;
     }
 
+    // Expand offsets, sizes and strides to the full rank to handle the
+    // rank-reducing case.
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    OffsetSizeAndStrideOpInterface::expandToRank(
+        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
+        [&](Value target, int64_t dim) -> OpFoldResult {
+          auto shapedType = target.getType().cast<ShapedType>();
+          if (shapedType.isDynamicDim(dim))
+            return rewriter.create<memref::DimOp>(loc, target, dim).result();
+          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
+        });
     // Bufferize to subview.
-    auto subviewMemRefType =
-        memref::SubViewOp::inferRankReducedResultType(
-            dstTensorType.getRank(), srcMemrefType,
-            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
-            extractSliceOp.getMixedStrides())
-            .cast<MemRefType>();
+    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
+                                 dstTensorType.getRank(), srcMemrefType,
+                                 mixedOffsets, mixedSizes, mixedStrides)
+                                 .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
-        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
+        mixedStrides);
 
     // If not inplaceable, copy.
     if (!inplace) {
@@ -422,17 +433,29 @@ struct InsertSliceOpInterface
     if (failed(dstMemref))
       return failure();
 
+    // Expand offsets, sizes and strides to the full rank to handle the
+    // rank-reducing case.
+    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
+    OffsetSizeAndStrideOpInterface::expandToRank(
+        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
+        [&](Value target, int64_t dim) -> OpFoldResult {
+          auto shapedType = target.getType().cast<ShapedType>();
+          if (shapedType.isDynamicDim(dim))
+            return rewriter.create<memref::DimOp>(loc, target, dim).result();
+          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
+        });
     // Take a subview of the dst.
     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(
             insertSliceOp.getSourceType().getRank(), dstMemrefType,
-            insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-            insertSliceOp.getMixedStrides())
+            mixedOffsets, mixedSizes, mixedStrides)
             .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, *dstMemref, insertSliceOp.getMixedOffsets(),
-        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
+        mixedStrides);
 
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 6bc3ece8693e3..a368f4e1653c7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -96,6 +96,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     options->addPostAnalysisStep<
         linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
   }
+
   if (!allowReturnMemref)
     options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 6394895370e25..cccc8339a6446 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -176,3 +176,22 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
       return false;
   return true;
 }
+
+void OffsetSizeAndStrideOpInterface::expandToRank(
+    Value target, SmallVector<OpFoldResult> &offsets,
+    SmallVector<OpFoldResult> &sizes, SmallVector<OpFoldResult> &strides,
+    llvm::function_ref<OpFoldResult(Value, int64_t)> createOrFoldDim) {
+  auto shapedType = target.getType().cast<ShapedType>();
+  unsigned rank = shapedType.getRank();
+  assert(offsets.size() == sizes.size() && "mismatched lengths");
+  assert(offsets.size() == strides.size() && "mismatched lengths");
+  assert(offsets.size() <= rank && "rank overflow");
+  MLIRContext *ctx = target.getContext();
+  Attribute zero = IntegerAttr::get(IndexType::get(ctx), APInt(64, 0));
+  Attribute one = IntegerAttr::get(IndexType::get(ctx), APInt(64, 1));
+  for (unsigned i = offsets.size(); i < rank; ++i) {
+    offsets.push_back(zero);
+    sizes.push_back(createOrFoldDim(target, i));
+    strides.push_back(one);
+  }
+}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
index 971d6c6e88a2f..fd32430f4e332 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -67,3 +67,32 @@ func private @private_func(tensor<?xf32>) -> ()
 func @empty_func() -> () {
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing
+func @rank_reducing(
+    %i: index, %j: index,
+    %arg0: tensor<8x18x32xf32>) 
+      -> tensor<?x1x6x8xf32> {
+  %c1 = arith.constant 1 : index
+  %c6 = arith.constant 6 : index
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
+  %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
+  %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
+  %5 = scf.for %arg7 = %c0 to %c32 step %c8 iter_args(%arg8 = %1) -> (tensor<?x1x6x8xf32>) {
+    %7 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg7)
+    %8 = tensor.extract_slice %arg0[%i, %j, %arg7] [1, 6, 8] [1, 1, 1] : tensor<8x18x32xf32> to tensor<1x6x8xf32>
+    %9 = scf.for %arg9 = %c0 to %c6 step %c1 iter_args(%arg10 = %2) -> (tensor<1x6x8xf32>) {
+      %11 = tensor.extract_slice %8[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x6x8xf32> to tensor<1x1x8xf32>
+      %12 = tensor.insert_slice %11 into %arg10[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x1x8xf32> into tensor<1x6x8xf32>
+      scf.yield %12 : tensor<1x6x8xf32>
+    }
+    %10 = tensor.insert_slice %9 into %arg8[%7, 0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] : tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
+    scf.yield %10 : tensor<?x1x6x8xf32>
+  }
+  return %5: tensor<?x1x6x8xf32>
+}
\ No newline at end of file

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 46f442b8d297c..96725d16bd16c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1710,26 +1710,3 @@ func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf
   }
   return %1: tensor<?xf32>
 }
-
-// -----
-
-//===----------------------------------------------------------------------===//
-// InitTensorOp elimination would produce SSA violations for the example below.
-//===----------------------------------------------------------------------===//
-
-func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>) 
-    -> tensor<?x1x6x8xf32> {
-  %c0 = arith.constant 0 : index
-  %c32 = arith.constant 32 : index
-  %c8 = arith.constant 8 : index
-  %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
-  %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
-  %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
-  %3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
-    %4 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg3)
-    %5 = tensor.insert_slice %2 into %arg4[%4,0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] :
-      tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
-    scf.yield %5 : tensor<?x1x6x8xf32>
-  }
-  return %3 : tensor<?x1x6x8xf32>
-}
\ No newline at end of file

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 4501a3a075dd2..05c120bcf557d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1199,3 +1199,26 @@ func @op_is_reading_but_following_ops_are_not(
   // CHECK: return %[[ALLOC]]
   return %r1 : tensor<?xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// InitTensorOp elimination would produce SSA violations for the example below.
+//===----------------------------------------------------------------------===//
+
+func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>) 
+    -> tensor<?x1x6x8xf32> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %c8 = arith.constant 8 : index
+  %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
+  %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
+  %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
+  %3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
+    %4 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg3)
+    %5 = tensor.insert_slice %2 into %arg4[%4,0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] :
+      tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
+    scf.yield %5 : tensor<?x1x6x8xf32>
+  }
+  return %3 : tensor<?x1x6x8xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index ae5252b7c3c1a..5ae4efba9e1ac 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -96,9 +96,6 @@ struct TestComprehensiveFunctionBufferize
 void TestComprehensiveFunctionBufferize::runOnFunction() {
   auto options = std::make_unique<BufferizationOptions>();
 
-  // Enable InitTensorOp elimination.
-  options->addPostAnalysisStep<
-      linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
   if (!allowReturnMemref)
     options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 


        


More information about the Mlir-commits mailing list