[Mlir-commits] [mlir] d0ee094 - [mlir][Bufferize] Fix incorrect bufferization of rank-reducing tensor ops.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jan 10 07:14:59 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-10T10:14:55-05:00
New Revision: d0ee094d6acf72608e927bf2e9ba69c57da59a96
URL: https://github.com/llvm/llvm-project/commit/d0ee094d6acf72608e927bf2e9ba69c57da59a96
DIFF: https://github.com/llvm/llvm-project/commit/d0ee094d6acf72608e927bf2e9ba69c57da59a96.diff
LOG: [mlir][Bufferize] Fix incorrect bufferization of rank-reducing tensor ops.
This revision fixes SubViewOp, InsertSliceOp, and ExtractSliceOp construction during bufferization,
where not all offset/size/stride operands were properly specified.
A test exhibiting the problematic behavior (incorrect memref casts) is introduced.
Init tensor optimization is disabled in the test function bufferize pass.
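For intuition, the completion rule implemented by the new
OffsetSizeAndStrideOpInterface::expandToRank hook can be sketched dialect-free
as follows. This is a hand-written illustration (plain integers standing in
for OpFoldResult and the createDim callback), not code from this patch:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Pad the trailing (most minor, fastest varying) dimensions with
    // offset=0, size=dim(target, i), stride=1 until all three vectors
    // have one entry per dimension of the target shape.
    static void expandToRankSketch(const std::vector<int64_t> &targetShape,
                                   std::vector<int64_t> &offsets,
                                   std::vector<int64_t> &sizes,
                                   std::vector<int64_t> &strides) {
      assert(offsets.size() == sizes.size() && sizes.size() == strides.size());
      assert(offsets.size() <= targetShape.size());
      for (size_t i = offsets.size(), e = targetShape.size(); i < e; ++i) {
        offsets.push_back(0);
        sizes.push_back(targetShape[i]); // the real hook calls createDim here
        strides.push_back(1);
      }
    }

    int main() {
      // A rank-1 slice specification against a rank-3 target.
      std::vector<int64_t> offsets{2}, sizes{4}, strides{1};
      expandToRankSketch(/*targetShape=*/{8, 18, 32}, offsets, sizes, strides);
      // Now: offsets = {2, 0, 0}, sizes = {4, 18, 32}, strides = {1, 1, 1}.
      return 0;
    }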
Differential Revision: https://reviews.llvm.org/D116899
Added:
Modified:
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index 50ebeaa44a5c3..b829760ba4591 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -483,6 +483,19 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
::mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()};
return names;
}
+ /// Assume `target` is a shaped type and offsets/sizes/strides are vectors
+ /// of the same length, at most equal to target's rank.
+ /// Complete each missing dim `i` with offset=0, size=dim(target, i),
+ /// stride=1 until all vectors have size rank. The completion occurs for
+ /// the most minor dimensions (i.e. fastest varying).
+ /// Take a `createDim` lambda that knows how to build the size of a
+ /// particular dimension of `target` (to avoid dialect dependencies).
+ static void expandToRank(
+ Value target,
+ SmallVector<OpFoldResult> &offsets,
+ SmallVector<OpFoldResult> &sizes,
+ SmallVector<OpFoldResult> &strides,
+ llvm::function_ref<OpFoldResult(Value, int64_t)> createDim);
}];
let verify = [{
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index e64d5ae3dda61..2d4f4b2a6500c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -347,6 +347,14 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
});
}
+// bufferization.to_memref is not allowed to change the rank.
+static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
+ auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
+ assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
+ rankedTensorType.getRank()) &&
+ "to_memref would be invalid: mismatching ranks");
+}
+
static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
@@ -364,6 +372,7 @@ static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
memrefType = getUnrankedMemRefType(
tensor.getType().cast<TensorType>().getElementType());
}
+ ensureToMemrefOpIsValid(tensor, memrefType);
return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
tensor);
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index bf08155076b22..64bc6920da07c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -563,10 +563,26 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
},
/*rewriteFunc=*/
[](OpBuilder &b, Location loc, OpOperand &operand) {
- auto insertSliceOp = cast<tensor::InsertSliceOp>(operand.getOwner());
+ auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
+ // Expand offsets, sizes and strides to the full rank to handle the
+ // rank-reducing case.
+ SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
+ SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
+ OffsetSizeAndStrideOpInterface::expandToRank(
+ insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
+ [&](Value target, int64_t dim) -> OpFoldResult {
+ auto shapedType = target.getType().cast<ShapedType>();
+ if (shapedType.isDynamicDim(dim))
+ return b.create<tensor::DimOp>(loc, target, dim).result();
+ return b.getIndexAttr(shapedType.getDimSize(dim));
+ });
+ auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
+ insertOp.getSourceType().getRank(),
+ insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
+ mixedSizes, mixedStrides);
auto extractOp = b.create<tensor::ExtractSliceOp>(
- loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
- insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+ loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
return extractOp.result();
},
newOps);
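For reference, tensor::ExtractSliceOp::inferRankReducedResultType first
computes the full-rank slice type from the (now expanded) sizes, then drops
unit dimensions until the requested rank is reached. A minimal sketch using
the shapes from the rank_reducing test added below; destRankedTensorType and
the mixed* vectors are assumed to be in scope:

    // dest: tensor<?x1x6x8xf32>, sizes [1, 1, 6, 8], requested rank 3.
    // The full-rank slice type is tensor<1x1x6x8xf32>; dropping one unit
    // dim yields tensor<1x6x8xf32>, matching the rank of the source.
    auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
        /*resultRank=*/3, destRankedTensorType, mixedOffsets, mixedSizes,
        mixedStrides);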
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 01308088bab8e..2eaed56669d7e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -19,6 +19,14 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace scf_ext {
+// bufferization.to_memref is not allowed to change the rank.
+static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
+ auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
+ assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
+ rankedTensorType.getRank())) &&
+ "to_memref would be invalid: mismatching ranks");
+}
+
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
@@ -159,6 +167,8 @@ struct IfOpInterface
SmallVector<Value> thenYieldValues;
for (OpOperand &operand : thenYieldOp->getOpOperands()) {
if (operand.get().getType().isa<TensorType>()) {
+ ensureToMemrefOpIsValid(operand.get(),
+ newTypes[operand.getOperandNumber()]);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
operand.get().getLoc(), newTypes[operand.getOperandNumber()],
operand.get());
@@ -172,6 +182,8 @@ struct IfOpInterface
SmallVector<Value> elseYieldValues;
for (OpOperand &operand : elseYieldOp->getOpOperands()) {
if (operand.get().getType().isa<TensorType>()) {
+ ensureToMemrefOpIsValid(operand.get(),
+ newTypes[operand.getOperandNumber()]);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
operand.get().getLoc(), newTypes[operand.getOperandNumber()],
operand.get());
@@ -317,6 +329,7 @@ struct ForOpInterface
rewriter.setInsertionPoint(yieldOp);
SmallVector<Value> yieldValues =
convert(yieldOp.getResults(), [&](Value val, int64_t index) {
+ ensureToMemrefOpIsValid(val, initArgs[index].getType());
return rewriter.create<bufferization::ToMemrefOp>(
val.getLoc(), initArgs[index].getType(), val);
});
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index f0f20b433937e..620328799712a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -68,7 +68,7 @@ struct CastOpInterface
// Compute the new memref type.
Type resultMemRefType;
- if (auto rankedTensorType = resultTensorType.isa<RankedTensorType>()) {
+ if (resultTensorType.isa<RankedTensorType>()) {
resultMemRefType =
getContiguousMemRefType(resultTensorType, layout, memorySpace);
} else {
@@ -165,16 +165,27 @@ struct ExtractSliceOpInterface
alloc = *allocOrFailure;
}
+ // Expand offsets, sizes and strides to the full rank to handle the
+ // rank-reducing case.
+ SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+ OffsetSizeAndStrideOpInterface::expandToRank(
+ srcMemref, mixedOffsets, mixedSizes, mixedStrides,
+ [&](Value target, int64_t dim) -> OpFoldResult {
+ auto shapedType = target.getType().cast<ShapedType>();
+ if (shapedType.isDynamicDim(dim))
+ return rewriter.create<memref::DimOp>(loc, target, dim).result();
+ return rewriter.getIndexAttr(shapedType.getDimSize(dim));
+ });
// Bufferize to subview.
- auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
- dstTensorType.getRank(), srcMemrefType,
- extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
- extractSliceOp.getMixedStrides())
- .cast<MemRefType>();
+ auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
+ dstTensorType.getRank(), srcMemrefType,
+ mixedOffsets, mixedSizes, mixedStrides)
+ .cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
- extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+ loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
+ mixedStrides);
// If not inplaceable, copy.
if (!inplace) {
@@ -422,17 +433,29 @@ struct InsertSliceOpInterface
if (failed(dstMemref))
return failure();
+ // Expand offsets, sizes and strides to the full rank to handle the
+ // rank-reducing case.
+ SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
+ OffsetSizeAndStrideOpInterface::expandToRank(
+ *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
+ [&](Value target, int64_t dim) -> OpFoldResult {
+ auto shapedType = target.getType().cast<ShapedType>();
+ if (shapedType.isDynamicDim(dim))
+ return rewriter.create<memref::DimOp>(loc, target, dim).result();
+ return rewriter.getIndexAttr(shapedType.getDimSize(dim));
+ });
// Take a subview of the dst.
auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getRank(), dstMemrefType,
- insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
- insertSliceOp.getMixedStrides())
+ mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, subviewMemRefType, *dstMemref, insertSliceOp.getMixedOffsets(),
- insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+ loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
+ mixedStrides);
// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 6bc3ece8693e3..a368f4e1653c7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -96,6 +96,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
options->addPostAnalysisStep<
linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
}
+
if (!allowReturnMemref)
options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 6394895370e25..cccc8339a6446 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -176,3 +176,22 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
return false;
return true;
}
+
+void OffsetSizeAndStrideOpInterface::expandToRank(
+ Value target, SmallVector<OpFoldResult> &offsets,
+ SmallVector<OpFoldResult> &sizes, SmallVector<OpFoldResult> &strides,
+ llvm::function_ref<OpFoldResult(Value, int64_t)> createOrFoldDim) {
+ auto shapedType = target.getType().cast<ShapedType>();
+ unsigned rank = shapedType.getRank();
+ assert(offsets.size() == sizes.size() && "mismatched lengths");
+ assert(offsets.size() == strides.size() && "mismatched lengths");
+ assert(offsets.size() <= rank && "rank overflow");
+ MLIRContext *ctx = target.getContext();
+ Attribute zero = IntegerAttr::get(IndexType::get(ctx), APInt(64, 0));
+ Attribute one = IntegerAttr::get(IndexType::get(ctx), APInt(64, 1));
+ for (unsigned i = offsets.size(); i < rank; ++i) {
+ offsets.push_back(zero);
+ sizes.push_back(createOrFoldDim(target, i));
+ strides.push_back(one);
+ }
+}
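All call sites in this patch pass the same flavor of createDim callback; a
minimal sketch, assuming an OpBuilder `b`, a Location `loc`, a shaped Value
`dest`, and an `op` implementing OffsetSizeAndStrideOpInterface are in scope
(memref-side call sites use memref::DimOp instead of tensor::DimOp):

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        dest, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          // Dynamic dims need an SSA value; static dims fold to an attribute.
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return b.create<tensor::DimOp>(loc, target, dim).result();
          return b.getIndexAttr(shapedType.getDimSize(dim));
        });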
diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
index 971d6c6e88a2f..fd32430f4e332 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -67,3 +67,32 @@ func private @private_func(tensor<?xf32>) -> ()
func @empty_func() -> () {
return
}
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing
+func @rank_reducing(
+ %i: index, %j: index,
+ %arg0: tensor<8x18x32xf32>)
+ -> tensor<?x1x6x8xf32> {
+ %c1 = arith.constant 1 : index
+ %c6 = arith.constant 6 : index
+ %c8 = arith.constant 8 : index
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
+ %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
+ %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
+ %5 = scf.for %arg7 = %c0 to %c32 step %c8 iter_args(%arg8 = %1) -> (tensor<?x1x6x8xf32>) {
+ %7 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg7)
+ %8 = tensor.extract_slice %arg0[%i, %j, %arg7] [1, 6, 8] [1, 1, 1] : tensor<8x18x32xf32> to tensor<1x6x8xf32>
+ %9 = scf.for %arg9 = %c0 to %c6 step %c1 iter_args(%arg10 = %2) -> (tensor<1x6x8xf32>) {
+ %11 = tensor.extract_slice %8[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x6x8xf32> to tensor<1x1x8xf32>
+ %12 = tensor.insert_slice %11 into %arg10[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x1x8xf32> into tensor<1x6x8xf32>
+ scf.yield %12 : tensor<1x6x8xf32>
+ }
+ %10 = tensor.insert_slice %9 into %arg8[%7, 0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] : tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
+ scf.yield %10 : tensor<?x1x6x8xf32>
+ }
+ return %5: tensor<?x1x6x8xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 46f442b8d297c..96725d16bd16c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1710,26 +1710,3 @@ func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf
}
return %1: tensor<?xf32>
}
-
-// -----
-
-//===----------------------------------------------------------------------===//
-// InitTensorOp elimination would produce SSA violations for the example below.
-//===----------------------------------------------------------------------===//
-
-func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
- -> tensor<?x1x6x8xf32> {
- %c0 = arith.constant 0 : index
- %c32 = arith.constant 32 : index
- %c8 = arith.constant 8 : index
- %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
- %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
- %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
- %3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
- %4 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg3)
- %5 = tensor.insert_slice %2 into %arg4[%4,0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] :
- tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
- scf.yield %5 : tensor<?x1x6x8xf32>
- }
- return %3 : tensor<?x1x6x8xf32>
-}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 4501a3a075dd2..05c120bcf557d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1199,3 +1199,26 @@ func @op_is_reading_but_following_ops_are_not(
// CHECK: return %[[ALLOC]]
return %r1 : tensor<?xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// InitTensorOp elimination would produce SSA violations for the example below.
+//===----------------------------------------------------------------------===//
+
+func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
+ -> tensor<?x1x6x8xf32> {
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %c8 = arith.constant 8 : index
+ %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
+ %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
+ %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
+ %3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
+ %4 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg3)
+ %5 = tensor.insert_slice %2 into %arg4[%4, 0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] :
+ tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
+ scf.yield %5 : tensor<?x1x6x8xf32>
+ }
+ return %3 : tensor<?x1x6x8xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index ae5252b7c3c1a..5ae4efba9e1ac 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -96,9 +96,6 @@ struct TestComprehensiveFunctionBufferize
void TestComprehensiveFunctionBufferize::runOnFunction() {
auto options = std::make_unique<BufferizationOptions>();
- // Enable InitTensorOp elimination.
- options->addPostAnalysisStep<
- linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
if (!allowReturnMemref)
options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();