[Mlir-commits] [mlir] 97a2bd8 - Revert "[mlir][loops] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 (#97607)"
Alexander Belyaev
llvmlistbot at llvm.org
Thu Jul 4 00:26:45 PDT 2024
Author: Alexander Belyaev
Date: 2024-07-04T09:24:23+02:00
New Revision: 97a2bd8415dc6792b99ec0f091ad7570673c3f37
URL: https://github.com/llvm/llvm-project/commit/97a2bd8415dc6792b99ec0f091ad7570673c3f37
DIFF: https://github.com/llvm/llvm-project/commit/97a2bd8415dc6792b99ec0f091ad7570673c3f37.diff
LOG: Revert "[mlir][loops] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 (#97607)"
This reverts commit edbc0e30a9e587cee1189be023b9385adc2f239a.
Reason for rollback: ASAN reports a heap-use-after-free with this PR:
==4320==ERROR: AddressSanitizer: heap-use-after-free on address 0x502000006cd8 at pc 0x55e2978d63cf bp 0x7ffe6431c2b0 sp 0x7ffe6431c2a8
READ of size 8 at 0x502000006cd8 thread T0
#0 0x55e2978d63ce in map<llvm::MutableArrayRef<mlir::BlockArgument> &, llvm::MutableArrayRef<mlir::BlockArgument>, nullptr> mlir/include/mlir/IR/IRMapping.h:40:11
#1 0x55e2978d63ce in mlir::createFused(mlir::LoopLikeOpInterface, mlir::LoopLikeOpInterface, mlir::RewriterBase&, std::__u::function<llvm::SmallVector<mlir::Value, 6u> (mlir::OpBuilder&, mlir::Location, llvm::ArrayRef<mlir::BlockArgument>)>, llvm::function_ref<void (mlir::RewriterBase&, mlir::LoopLikeOpInterface, mlir::LoopLikeOpInterface&, mlir::IRMapping)>) mlir/lib/Interfaces/LoopLikeInterface.cpp:156:11
#2 0x55e2952a614b in mlir::fuseIndependentSiblingForLoops(mlir::scf::ForOp, mlir::scf::ForOp, mlir::RewriterBase&) mlir/lib/Dialect/SCF/Utils/Utils.cpp:1398:43
#3 0x55e291480c6f in mlir::transform::LoopFuseSiblingOp::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp:482:17
#4 0x55e29149ed5e in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model<mlir::transform::LoopFuseSiblingOp>::apply(mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477:56
#5 0x55e297494a60 in apply blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61:14
#6 0x55e297494a60 in mlir::transform::TransformState::applyTransform(mlir::transform::TransformOpInterface) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:953:48
#7 0x55e294646a8d in applySequenceBlock(mlir::Block&, mlir::transform::FailurePropagationMode, mlir::transform::TransformState&, mlir::transform::TransformResults&) mlir/lib/Dialect/Transform/IR/TransformOps.cpp:1788:15
#8 0x55e29464f927 in mlir::transform::NamedSequenceOp::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) mlir/lib/Dialect/Transform/IR/TransformOps.cpp:2155:10
#9 0x55e2945d28ee in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model<mlir::transform::NamedSequenceOp>::apply(mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477:56
#10 0x55e297494a60 in apply blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61:14
#11 0x55e297494a60 in mlir::transform::TransformState::applyTransform(mlir::transform::TransformOpInterface) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:953:48
#12 0x55e2974a5fe2 in mlir::transform::applyTransforms(mlir::Operation*, mlir::transform::TransformOpInterface, mlir::RaggedArray<llvm::PointerUnion<mlir::Operation*, mlir::Attribute, mlir::Value>> const&, mlir::transform::TransformOptions const&, bool) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:2016:16
#13 0x55e2945888d7 in mlir::transform::applyTransformNamedSequence(mlir::RaggedArray<llvm::PointerUnion<mlir::Operation*, mlir::Attribute, mlir::Value>>, mlir::transform::TransformOpInterface, mlir::ModuleOp, mlir::transform::TransformOptions const&) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp:234:10
#14 0x55e294582446 in (anonymous namespace)::InterpreterPass::runOnOperation() mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp:147:16
#15 0x55e2978e93c6 in operator() mlir/lib/Pass/Pass.cpp:527:17
#16 0x55e2978e93c6 in void llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#17 0x55e2978e207a in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#18 0x55e2978e207a in executeAction<mlir::PassExecutionAction, mlir::Pass &> mlir/include/mlir/IR/MLIRContext.h:275:7
#19 0x55e2978e207a in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) mlir/lib/Pass/Pass.cpp:521:21
#20 0x55e2978e5fbf in runPipeline mlir/lib/Pass/Pass.cpp:593:16
#21 0x55e2978e5fbf in mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) mlir/lib/Pass/Pass.cpp:904:10
#22 0x55e2978e5b65 in mlir::PassManager::run(mlir::Operation*) mlir/lib/Pass/Pass.cpp:884:60
#23 0x55e291ebb460 in performActions(llvm::raw_ostream&, std::__u::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:408:17
#24 0x55e291ebabd9 in processBuffer mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:481:9
#25 0x55e291ebabd9 in operator() mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:548:12
#26 0x55e291ebabd9 in llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#27 0x55e297b1cffe in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#28 0x55e297b1cffe in mlir::splitAndProcessBuffer(std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef)::$_0::operator()(llvm::StringRef) const mlir/lib/Support/ToolUtilities.cpp:86:16
#29 0x55e297b1c9c5 in interleave<const llvm::StringRef *, (lambda at mlir/lib/Support/ToolUtilities.cpp:79:23), (lambda at llvm/include/llvm/ADT/STLExtras.h:2147:49), void> llvm/include/llvm/ADT/STLExtras.h:2125:3
#30 0x55e297b1c9c5 in interleave<llvm::SmallVector<llvm::StringRef, 8U>, (lambda at mlir/lib/Support/ToolUtilities.cpp:79:23), llvm::raw_ostream, llvm::StringRef> llvm/include/llvm/ADT/STLExtras.h:2147:3
#31 0x55e297b1c9c5 in mlir::splitAndProcessBuffer(std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) mlir/lib/Support/ToolUtilities.cpp:89:3
#32 0x55e291eb0cf0 in mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:551:10
#33 0x55e291eb115c in mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:589:14
#34 0x55e291eb15f8 in mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:605:10
#35 0x55e29130d1be in main mlir/tools/mlir-opt/mlir-opt.cpp:311:33
#36 0x7fbcf3fff3d3 in __libc_start_main (/usr/grte/v5/lib64/libc.so.6+0x613d3) (BuildId: 9a996398ce14a94560b0c642eb4f6e94)
#37 0x55e2912365a9 in _start /usr/grte/v5/debug-src/src/csu/../sysdeps/x86_64/start.S:120
0x502000006cd8 is located 8 bytes inside of 16-byte region [0x502000006cd0,0x502000006ce0)
freed by thread T0 here:
#0 0x55e29130b7e2 in operator delete(void*, unsigned long) compiler-rt/lib/asan/asan_new_delete.cpp:155:3
#1 0x55e2979eb657 in __libcpp_operator_delete<void *, unsigned long>
#2 0x55e2979eb657 in __do_deallocate_handle_size<>
#3 0x55e2979eb657 in __libcpp_deallocate
#4 0x55e2979eb657 in deallocate
#5 0x55e2979eb657 in deallocate
#6 0x55e2979eb657 in operator()
#7 0x55e2979eb657 in ~vector
#8 0x55e2979eb657 in mlir::Block::~Block() mlir/lib/IR/Block.cpp:24:1
#9 0x55e2979ebc17 in deleteNode llvm/include/llvm/ADT/ilist.h:42:39
#10 0x55e2979ebc17 in erase llvm/include/llvm/ADT/ilist.h:205:5
#11 0x55e2979ebc17 in erase llvm/include/llvm/ADT/ilist.h:209:39
#12 0x55e2979ebc17 in mlir::Block::erase() mlir/lib/IR/Block.cpp:67:28
#13 0x55e297aef978 in mlir::RewriterBase::eraseBlock(mlir::Block*) mlir/lib/IR/PatternMatch.cpp:245:10
#14 0x55e297af0563 in mlir::RewriterBase::inlineBlockBefore(mlir::Block*, mlir::Block*, llvm::ilist_iterator<llvm::ilist_detail::node_options<mlir::Operation, false, false, void, false, void>, false, false>, mlir::ValueRange) mlir/lib/IR/PatternMatch.cpp:331:3
#15 0x55e297af06d8 in mlir::RewriterBase::mergeBlocks(mlir::Block*, mlir::Block*, mlir::ValueRange) mlir/lib/IR/PatternMatch.cpp:341:3
#16 0x55e297036608 in mlir::scf::ForOp::replaceWithAdditionalYields(mlir::RewriterBase&, mlir::ValueRange, bool, std::__u::function<llvm::SmallVector<mlir::Value, 6u> (mlir::OpBuilder&, mlir::Location, llvm::ArrayRef<mlir::BlockArgument>)> const&) mlir/lib/Dialect/SCF/IR/SCF.cpp:575:12
#17 0x55e2970673ca in mlir::detail::LoopLikeOpInterfaceInterfaceTraits::Model<mlir::scf::ForOp>::replaceWithAdditionalYields(mlir::detail::LoopLikeOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::RewriterBase&, mlir::ValueRange, bool, std::__u::function<llvm::SmallVector<mlir::Value, 6u> (mlir::OpBuilder&, mlir::Location, llvm::ArrayRef<mlir::BlockArgument>)> const&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Interfaces/LoopLikeInterface.h.inc:658:56
#18 0x55e2978d5feb in replaceWithAdditionalYields blaze-out/k8-opt-asan/bin/mlir/include/mlir/Interfaces/LoopLikeInterface.cpp.inc:105:14
#19 0x55e2978d5feb in mlir::createFused(mlir::LoopLikeOpInterface, mlir::LoopLikeOpInterface, mlir::RewriterBase&, std::__u::function<llvm::SmallVector<mlir::Value, 6u> (mlir::OpBuilder&, mlir::Location, llvm::ArrayRef<mlir::BlockArgument>)>, llvm::function_ref<void (mlir::RewriterBase&, mlir::LoopLikeOpInterface, mlir::LoopLikeOpInterface&, mlir::IRMapping)>) mlir/lib/Interfaces/LoopLikeInterface.cpp:135:14
#20 0x55e2952a614b in mlir::fuseIndependentSiblingForLoops(mlir::scf::ForOp, mlir::scf::ForOp, mlir::RewriterBase&) mlir/lib/Dialect/SCF/Utils/Utils.cpp:1398:43
#21 0x55e291480c6f in mlir::transform::LoopFuseSiblingOp::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp:482:17
#22 0x55e29149ed5e in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model<mlir::transform::LoopFuseSiblingOp>::apply(mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477:56
#23 0x55e297494a60 in apply blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61:14
#24 0x55e297494a60 in mlir::transform::TransformState::applyTransform(mlir::transform::TransformOpInterface) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:953:48
#25 0x55e294646a8d in applySequenceBlock(mlir::Block&, mlir::transform::FailurePropagationMode, mlir::transform::TransformState&, mlir::transform::TransformResults&) mlir/lib/Dialect/Transform/IR/TransformOps.cpp:1788:15
#26 0x55e29464f927 in mlir::transform::NamedSequenceOp::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) mlir/lib/Dialect/Transform/IR/TransformOps.cpp:2155:10
#27 0x55e2945d28ee in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model<mlir::transform::NamedSequenceOp>::apply(mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477:56
#28 0x55e297494a60 in apply blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61:14
#29 0x55e297494a60 in mlir::transform::TransformState::applyTransform(mlir::transform::TransformOpInterface) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:953:48
#30 0x55e2974a5fe2 in mlir::transform::applyTransforms(mlir::Operation*, mlir::transform::TransformOpInterface, mlir::RaggedArray<llvm::PointerUnion<mlir::Operation*, mlir::Attribute, mlir::Value>> const&, mlir::transform::TransformOptions const&, bool) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:2016:16
#31 0x55e2945888d7 in mlir::transform::applyTransformNamedSequence(mlir::RaggedArray<llvm::PointerUnion<mlir::Operation*, mlir::Attribute, mlir::Value>>, mlir::transform::TransformOpInterface, mlir::ModuleOp, mlir::transform::TransformOptions const&) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp:234:10
#32 0x55e294582446 in (anonymous namespace)::InterpreterPass::runOnOperation() mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp:147:16
#33 0x55e2978e93c6 in operator() mlir/lib/Pass/Pass.cpp:527:17
#34 0x55e2978e93c6 in void llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#35 0x55e2978e207a in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#36 0x55e2978e207a in executeAction<mlir::PassExecutionAction, mlir::Pass &> mlir/include/mlir/IR/MLIRContext.h:275:7
#37 0x55e2978e207a in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) mlir/lib/Pass/Pass.cpp:521:21
#38 0x55e2978e5fbf in runPipeline mlir/lib/Pass/Pass.cpp:593:16
#39 0x55e2978e5fbf in mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) mlir/lib/Pass/Pass.cpp:904:10
#40 0x55e2978e5b65 in mlir::PassManager::run(mlir::Operation*) mlir/lib/Pass/Pass.cpp:884:60
#41 0x55e291ebb460 in performActions(llvm::raw_ostream&, std::__u::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:408:17
#42 0x55e291ebabd9 in processBuffer mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:481:9
#43 0x55e291ebabd9 in operator() mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:548:12
#44 0x55e291ebabd9 in llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#45 0x55e297b1cffe in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#46 0x55e297b1cffe in mlir::splitAndProcessBuffer(std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef)::$_0::operator()(llvm::StringRef) const mlir/lib/Support/ToolUtilities.cpp:86:16
#47 0x55e297b1c9c5 in interleave<const llvm::StringRef *, (lambda at mlir/lib/Support/ToolUtilities.cpp:79:23), (lambda at llvm/include/llvm/ADT/STLExtras.h:2147:49), void> llvm/include/llvm/ADT/STLExtras.h:2125:3
#48 0x55e297b1c9c5 in interleave<llvm::SmallVector<llvm::StringRef, 8U>, (lambda at mlir/lib/Support/ToolUtilities.cpp:79:23), llvm::raw_ostream, llvm::StringRef> llvm/include/llvm/ADT/STLExtras.h:2147:3
#49 0x55e297b1c9c5 in mlir::splitAndProcessBuffer(std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) mlir/lib/Support/ToolUtilities.cpp:89:3
#50 0x55e291eb0cf0 in mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:551:10
#51 0x55e291eb115c in mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:589:14
previously allocated by thread T0 here:
#0 0x55e29130ab5d in operator new(unsigned long) compiler-rt/lib/asan/asan_new_delete.cpp:86:3
#1 0x55e2979ed5d4 in __libcpp_operator_new<unsigned long>
#2 0x55e2979ed5d4 in __libcpp_allocate
#3 0x55e2979ed5d4 in allocate
#4 0x55e2979ed5d4 in __allocate_at_least<std::__u::allocator<mlir::BlockArgument> >
#5 0x55e2979ed5d4 in __split_buffer
#6 0x55e2979ed5d4 in mlir::BlockArgument* std::__u::vector<mlir::BlockArgument, std::__u::allocator<mlir::BlockArgument>>::__push_back_slow_path<mlir::BlockArgument const&>(mlir::BlockArgument const&)
#7 0x55e2979ec0f2 in push_back
#8 0x55e2979ec0f2 in mlir::Block::addArgument(mlir::Type, mlir::Location) mlir/lib/IR/Block.cpp:154:13
#9 0x55e29796e457 in parseRegionBody mlir/lib/AsmParser/Parser.cpp:2172:34
#10 0x55e29796e457 in (anonymous namespace)::OperationParser::parseRegion(mlir::Region&, llvm::ArrayRef<mlir::OpAsmParser::Argument>, bool) mlir/lib/AsmParser/Parser.cpp:2121:7
#11 0x55e29796b25e in (anonymous namespace)::CustomOpAsmParser::parseRegion(mlir::Region&, llvm::ArrayRef<mlir::OpAsmParser::Argument>, bool) mlir/lib/AsmParser/Parser.cpp:1785:16
#12 0x55e297035742 in mlir::scf::ForOp::parse(mlir::OpAsmParser&, mlir::OperationState&) mlir/lib/Dialect/SCF/IR/SCF.cpp:521:14
#13 0x55e291322c18 in llvm::ParseResult llvm::detail::UniqueFunctionBase<llvm::ParseResult, mlir::OpAsmParser&, mlir::OperationState&>::CallImpl<llvm::ParseResult (*)(mlir::OpAsmParser&, mlir::OperationState&)>(void*, mlir::OpAsmParser&, mlir::OperationState&) llvm/include/llvm/ADT/FunctionExtras.h:220:12
#14 0x55e29795bea3 in operator() llvm/include/llvm/ADT/FunctionExtras.h:384:12
#15 0x55e29795bea3 in callback_fn<llvm::unique_function<llvm::ParseResult (mlir::OpAsmParser &, mlir::OperationState &)> > llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#16 0x55e29795bea3 in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#17 0x55e29795bea3 in parseOperation mlir/lib/AsmParser/Parser.cpp:1521:9
#18 0x55e29795bea3 in parseCustomOperation mlir/lib/AsmParser/Parser.cpp:2017:19
#19 0x55e29795bea3 in (anonymous namespace)::OperationParser::parseOperation() mlir/lib/AsmParser/Parser.cpp:1174:10
#20 0x55e297971d20 in parseBlockBody mlir/lib/AsmParser/Parser.cpp:2296:9
#21 0x55e297971d20 in (anonymous namespace)::OperationParser::parseBlock(mlir::Block*&) mlir/lib/AsmParser/Parser.cpp:2226:12
#22 0x55e29796e4f5 in parseRegionBody mlir/lib/AsmParser/Parser.cpp:2184:7
#23 0x55e29796e4f5 in (anonymous namespace)::OperationParser::parseRegion(mlir::Region&, llvm::ArrayRef<mlir::OpAsmParser::Argument>, bool) mlir/lib/AsmParser/Parser.cpp:2121:7
#24 0x55e29796b25e in (anonymous namespace)::CustomOpAsmParser::parseRegion(mlir::Region&, llvm::ArrayRef<mlir::OpAsmParser::Argument>, bool) mlir/lib/AsmParser/Parser.cpp:1785:16
#25 0x55e29796b2cf in (anonymous namespace)::CustomOpAsmParser::parseOptionalRegion(mlir::Region&, llvm::ArrayRef<mlir::OpAsmParser::Argument>, bool) mlir/lib/AsmParser/Parser.cpp:1796:12
#26 0x55e2978d89ff in mlir::function_interface_impl::parseFunctionOp(mlir::OpAsmParser&, mlir::OperationState&, bool, mlir::StringAttr, llvm::function_ref<mlir::Type (mlir::Builder&, llvm::ArrayRef<mlir::Type>, llvm::ArrayRef<mlir::Type>, mlir::function_interface_impl::VariadicFlag, std::__u::basic_string<char, std::__u::char_traits<char>, std::__u::allocator<char>>&)>, mlir::StringAttr, mlir::StringAttr) mlir/lib/Interfaces/FunctionImplementation.cpp:232:14
#27 0x55e2969ba41d in mlir::func::FuncOp::parse(mlir::OpAsmParser&, mlir::OperationState&) mlir/lib/Dialect/Func/IR/FuncOps.cpp:203:10
#28 0x55e291322c18 in llvm::ParseResult llvm::detail::UniqueFunctionBase<llvm::ParseResult, mlir::OpAsmParser&, mlir::OperationState&>::CallImpl<llvm::ParseResult (*)(mlir::OpAsmParser&, mlir::OperationState&)>(void*, mlir::OpAsmParser&, mlir::OperationState&) llvm/include/llvm/ADT/FunctionExtras.h:220:12
#29 0x55e29795bea3 in operator() llvm/include/llvm/ADT/FunctionExtras.h:384:12
#30 0x55e29795bea3 in callback_fn<llvm::unique_function<llvm::ParseResult (mlir::OpAsmParser &, mlir::OperationState &)> > llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#31 0x55e29795bea3 in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#32 0x55e29795bea3 in parseOperation mlir/lib/AsmParser/Parser.cpp:1521:9
#33 0x55e29795bea3 in parseCustomOperation mlir/lib/AsmParser/Parser.cpp:2017:19
#34 0x55e29795bea3 in (anonymous namespace)::OperationParser::parseOperation() mlir/lib/AsmParser/Parser.cpp:1174:10
#35 0x55e297959b78 in parse mlir/lib/AsmParser/Parser.cpp:2725:20
#36 0x55e297959b78 in mlir::parseAsmSourceFile(llvm::SourceMgr const&, mlir::Block*, mlir::ParserConfig const&, mlir::AsmParserState*, mlir::AsmParserCodeCompleteContext*) mlir/lib/AsmParser/Parser.cpp:2785:41
#37 0x55e29790d5c2 in mlir::parseSourceFile(std::__u::shared_ptr<llvm::SourceMgr> const&, mlir::Block*, mlir::ParserConfig const&, mlir::LocationAttr*) mlir/lib/Parser/Parser.cpp:46:10
#38 0x55e291ebbfe2 in parseSourceFile<mlir::ModuleOp, const std::__u::shared_ptr<llvm::SourceMgr> &> mlir/include/mlir/Parser/Parser.h:159:14
#39 0x55e291ebbfe2 in parseSourceFile<mlir::ModuleOp> mlir/include/mlir/Parser/Parser.h:189:10
#40 0x55e291ebbfe2 in mlir::parseSourceFileForTool(std::__u::shared_ptr<llvm::SourceMgr> const&, mlir::ParserConfig const&, bool) mlir/include/mlir/Tools/ParseUtilities.h:31:12
#41 0x55e291ebb263 in performActions(llvm::raw_ostream&, std::__u::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:383:33
#42 0x55e291ebabd9 in processBuffer mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:481:9
#43 0x55e291ebabd9 in operator() mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:548:12
#44 0x55e291ebabd9 in llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#45 0x55e297b1cffe in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#46 0x55e297b1cffe in mlir::splitAndProcessBuffer(std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef)::$_0::operator()(llvm::StringRef) const mlir/lib/Support/ToolUtilities.cpp:86:16
#47 0x55e297b1c9c5 in interleave<const llvm::StringRef *, (lambda at mlir/lib/Support/ToolUtilities.cpp:79:23), (lambda at llvm/include/llvm/ADT/STLExtras.h:2147:49), void> llvm/include/llvm/ADT/STLExtras.h:2125:3
#48 0x55e297b1c9c5 in interleave<llvm::SmallVector<llvm::StringRef, 8U>, (lambda at mlir/lib/Support/ToolUtilities.cpp:79:23), llvm::raw_ostream, llvm::StringRef> llvm/include/llvm/ADT/STLExtras.h:2147:3
#49 0x55e297b1c9c5 in mlir::splitAndProcessBuffer(std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) mlir/lib/Support/ToolUtilities.cpp:89:3
#50 0x55e291eb0cf0 in mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr<llvm::MemoryBuffer, std::__u::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:551:10
#51 0x55e291eb115c in mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:589:14
#52 0x55e291eb15f8 in mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:605:10
#53 0x55e29130d1be in main mlir/tools/mlir-opt/mlir-opt.cpp:311:33
#54 0x7fbcf3fff3d3 in __libc_start_main (/usr/grte/v5/lib64/libc.so.6+0x613d3) (BuildId: 9a996398ce14a94560b0c642eb4f6e94)
#55 0x55e2912365a9 in _start /usr/grte/v5/debug-src/src/csu/../sysdeps/x86_64/start.S:120
SUMMARY: AddressSanitizer: heap-use-after-free mlir/include/mlir/IR/IRMapping.h:40:11 in map<llvm::MutableArrayRef<mlir::BlockArgument> &, llvm::MutableArrayRef<mlir::BlockArgument>, nullptr>
Shadow bytes around the buggy address:
0x502000006a00: fa fa 00 fa fa fa 00 00 fa fa 00 fa fa fa 00 fa
0x502000006a80: fa fa 00 fa fa fa 00 00 fa fa 00 00 fa fa 00 00
0x502000006b00: fa fa 00 00 fa fa 00 00 fa fa 00 fa fa fa 00 fa
0x502000006b80: fa fa 00 fa fa fa 00 fa fa fa 00 00 fa fa 00 00
0x502000006c00: fa fa 00 00 fa fa 00 00 fa fa 00 00 fa fa fd fa
=>0x502000006c80: fa fa fd fa fa fa fd fd fa fa fd[fd]fa fa fd fd
0x502000006d00: fa fa 00 fa fa fa 00 fa fa fa 00 fa fa fa 00 fa
0x502000006d80: fa fa 00 fa fa fa 00 fa fa fa 00 fa fa fa 00 fa
0x502000006e00: fa fa 00 fa fa fa 00 fa fa fa 00 00 fa fa 00 fa
0x502000006e80: fa fa 00 fa fa fa 00 00 fa fa 00 fa fa fa 00 fa
0x502000006f00: fa fa 00 fa fa fa 00 fa fa fa 00 fa fa fa 00 fa
Shadow byte legend (one shadow byte represents 8 application bytes):
Addressable: 00
Partially addressable: 01 02 03 04 05 06 07
Heap left redzone: fa
Freed heap region: fd
Stack left redzone: f1
Stack mid redzone: f2
Stack right redzone: f3
Stack after return: f5
Stack use after scope: f8
Global redzone: f9
Global init order: f6
Poisoned by user: f7
Container overflow: fc
Array cookie: ac
Intra object redzone: bb
ASan internal: fe
Left alloca redzone: ca
Right alloca redzone: cb
==4320==ABORTING
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/include/mlir/Interfaces/LoopLikeInterface.h
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/lib/Interfaces/LoopLikeInterface.cpp
mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index bf95fbe6721cf..f35ea962bea16 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -303,8 +303,7 @@ def ForallOp : SCF_Op<"forall", [
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
- "replaceWithAdditionalYields", "promoteIfSingleIteration",
- "yieldTiledValuesAndReplace"]>,
+ "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 6a40304e2eeba..de807c3e4e1f8 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -181,16 +181,6 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
scf::ForOp root);
-//===----------------------------------------------------------------------===//
-// Fusion related helpers
-//===----------------------------------------------------------------------===//
-
-/// Check structural compatibility between two loops such as iteration space
-/// and dominance.
-bool checkFusionStructuralLegality(LoopLikeOpInterface target,
- LoopLikeOpInterface source,
- Diagnostic &diag);
-
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
/// `source`. Assumes that the given loops are siblings and are independent of
/// each other.
@@ -212,16 +202,6 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
RewriterBase &rewriter);
-/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
-/// `source`. Assumes that the given loops are siblings and are independent of
-/// each other.
-///
-/// This function does not perform any legality checks and simply fuses the
-/// loops. The caller is responsible for ensuring that the loops are legal to
-/// fuse.
-scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
- scf::ParallelOp source,
- RewriterBase &rewriter);
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index d08e097a9b4af..9925fc6ce6ca9 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -90,24 +90,4 @@ struct JamBlockGatherer {
/// Include the generated interface declarations.
#include "mlir/Interfaces/LoopLikeInterface.h.inc"
-namespace mlir {
-/// A function that rewrites `target`'s terminator as a teminator obtained by
-/// fusing `source` into `target`.
-using FuseTerminatorFn =
- function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source,
- LoopLikeOpInterface &target, IRMapping mapping)>;
-
-/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
-/// `target`. The `NewYieldValuesFn` callback is used to pass to the
-/// `replaceWithAdditionalYields` interface method to replace the loop with a
-/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
-/// callback is repsonsible for updating the fused loop terminator.
-LoopLikeOpInterface createFused(LoopLikeOpInterface target,
- LoopLikeOpInterface source,
- RewriterBase &rewriter,
- NewYieldValuesFn newYieldValuesFn,
- FuseTerminatorFn fuseTerminatorFn);
-
-} // namespace mlir
-
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index cb15e0ecebf05..907d7f794593d 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -618,44 +618,6 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
-FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
- RewriterBase &rewriter, ValueRange newInitOperands,
- bool replaceInitOperandUsesInLoop,
- const NewYieldValuesFn &newYieldValuesFn) {
- // Create a new loop before the existing one, with the extra operands.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(getOperation());
- SmallVector<Value> inits(getOutputs());
- llvm::append_range(inits, newInitOperands);
- scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
- getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
- inits, getMapping(),
- /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
-
- // Move the loop body to the new op.
- rewriter.mergeBlocks(getBody(), newLoop.getBody(),
- newLoop.getBody()->getArguments().take_front(
- getBody()->getNumArguments()));
-
- if (replaceInitOperandUsesInLoop) {
- // Replace all uses of `newInitOperands` with the corresponding basic block
- // arguments.
- for (auto &&[newOperand, oldOperand] :
- llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
- newInitOperands.size()))) {
- rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) {
- Operation *user = use.getOwner();
- return newLoop->isProperAncestor(user);
- });
- }
- }
-
- // Replace the old loop.
- rewriter.replaceOp(getOperation(),
- newLoop->getResults().take_front(getNumResults()));
- return cast<LoopLikeOpInterface>(newLoop.getOperation());
-}
-
/// Promotes the loop body of a forallOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 41834fea3bb84..56ff2709a589e 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -261,10 +261,8 @@ loopScheduling(scf::ForOp forOp,
return 1;
};
- std::optional<int64_t> ubConstant =
- getConstantIntValue(forOp.getUpperBound());
- std::optional<int64_t> lbConstant =
- getConstantIntValue(forOp.getLowerBound());
+ std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
+ std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
DenseMap<Operation *, unsigned> opCycles;
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
for (Operation &op : forOp.getBody()->getOperations()) {
@@ -449,6 +447,113 @@ void transform::TakeAssumedBranchOp::getEffects(
// LoopFuseSiblingOp
//===----------------------------------------------------------------------===//
+/// Check if `target` and `source` are siblings, in the context that `target`
+/// is being fused into `source`.
+///
+/// This is a simple check that just checks if both operations are in the same
+/// block and some checks to ensure that the fused IR does not violate
+/// dominance.
+static DiagnosedSilenceableFailure isOpSibling(Operation *target,
+ Operation *source) {
+ // Check if both operations are same.
+ if (target == source)
+ return emitSilenceableFailure(source)
+ << "target and source need to be different loops";
+
+ // Check if both operations are in the same block.
+ if (target->getBlock() != source->getBlock())
+ return emitSilenceableFailure(source)
+ << "target and source are not in the same block";
+
+ // Check if fusion will violate dominance.
+ DominanceInfo domInfo(source);
+ if (target->isBeforeInBlock(source)) {
+ // Since `target` is before `source`, all users of results of `target`
+ // need to be dominated by `source`.
+ for (Operation *user : target->getUsers()) {
+ if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
+ return emitSilenceableFailure(target)
+ << "user of results of target should be properly dominated by "
+ "source";
+ }
+ }
+ } else {
+ // Since `target` is after `source`, all values used by `target` need
+ // to dominate `source`.
+
+ // Check if operands of `target` are dominated by `source`.
+ for (Value operand : target->getOperands()) {
+ Operation *operandOp = operand.getDefiningOp();
+ // Operands without defining operations are block arguments. When `target`
+ // and `source` occur in the same block, these operands dominate `source`.
+ if (!operandOp)
+ continue;
+
+ // Operand's defining operation should properly dominate `source`.
+ if (!domInfo.properlyDominates(operandOp, source,
+ /*enclosingOpOk=*/false))
+ return emitSilenceableFailure(target)
+ << "operands of target should be properly dominated by source";
+ }
+
+ // Check if values used by `target` are dominated by `source`.
+ bool failed = false;
+ OpOperand *failedValue = nullptr;
+ visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
+ Operation *operandOp = operand->get().getDefiningOp();
+ if (operandOp && !domInfo.properlyDominates(operandOp, source,
+ /*enclosingOpOk=*/false)) {
+ // `operand` is not an argument of an enclosing block and the defining
+ // op of `operand` is outside `target` but does not dominate `source`.
+ failed = true;
+ failedValue = operand;
+ }
+ });
+
+ if (failed)
+ return emitSilenceableFailure(failedValue->getOwner())
+ << "values used inside regions of target should be properly "
+ "dominated by source";
+ }
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+/// Check if `target` scf.forall can be fused into `source` scf.forall.
+///
+/// This simply checks if both loops have the same bounds, steps and mapping.
+/// No attempt is made at checking that the side effects of `target` and
+/// `source` are independent of each other.
+static bool isForallWithIdenticalConfiguration(Operation *target,
+ Operation *source) {
+ auto targetOp = dyn_cast<scf::ForallOp>(target);
+ auto sourceOp = dyn_cast<scf::ForallOp>(source);
+ if (!targetOp || !sourceOp)
+ return false;
+
+ return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
+ targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
+ targetOp.getMixedStep() == sourceOp.getMixedStep() &&
+ targetOp.getMapping() == sourceOp.getMapping();
+}
+
+/// Check if `target` scf.for can be fused into `source` scf.for.
+///
+/// This simply checks if both loops have the same bounds and steps. No attempt
+/// is made at checking that the side effects of `target` and `source` are
+/// independent of each other.
+static bool isForWithIdenticalConfiguration(Operation *target,
+ Operation *source) {
+ auto targetOp = dyn_cast<scf::ForOp>(target);
+ auto sourceOp = dyn_cast<scf::ForOp>(source);
+ if (!targetOp || !sourceOp)
+ return false;
+
+ return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
+ targetOp.getUpperBound() == sourceOp.getUpperBound() &&
+ targetOp.getStep() == sourceOp.getStep();
+}
+
DiagnosedSilenceableFailure
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -464,32 +569,25 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
}
- auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
- auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
- if (!target || !source)
- return emitSilenceableFailure(target->getLoc())
- << "target or source is not a loop op";
+ Operation *target = *targetOps.begin();
+ Operation *source = *sourceOps.begin();
- // Check if loops can be fused
- Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
- if (!mlir::checkFusionStructuralLegality(target, source, diag))
- return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ // Check if the target and source are siblings.
+ DiagnosedSilenceableFailure diag = isOpSibling(target, source);
+ if (!diag.succeeded())
+ return diag;
Operation *fusedLoop;
- // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
- // and scf.parallel.
- if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
+ /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
+ if (isForWithIdenticalConfiguration(target, source)) {
fusedLoop = fuseIndependentSiblingForLoops(
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
- } else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
+ } else if (isForallWithIdenticalConfiguration(target, source)) {
fusedLoop = fuseIndependentSiblingForallLoops(
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
- } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
- fusedLoop = fuseIndependentSiblingParallelLoops(
- cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
} else
return emitSilenceableFailure(target->getLoc())
- << "unsupported loop type for fusion";
+ << "operations cannot be fused";
assert(fusedLoop && "failed to fuse operations");
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index b775f988576e3..5934d85373b03 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -16,7 +16,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
@@ -38,6 +37,24 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
return walkResult.wasInterrupted();
}
+/// Verify equal iteration spaces.
+static bool equalIterationSpaces(ParallelOp firstPloop,
+ ParallelOp secondPloop) {
+ if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+ return false;
+
+ auto matchOperands = [&](const OperandRange &lhs,
+ const OperandRange &rhs) -> bool {
+ // TODO: Extend this to support aliases and equal constants.
+ return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+ };
+ return matchOperands(firstPloop.getLowerBound(),
+ secondPloop.getLowerBound()) &&
+ matchOperands(firstPloop.getUpperBound(),
+ secondPloop.getUpperBound()) &&
+ matchOperands(firstPloop.getStep(), secondPloop.getStep());
+}
+
/// Checks if the parallel loops have mixed access to the same buffers. Returns
/// `true` if the first parallel loop writes to the same indices that the second
/// loop reads.
@@ -136,10 +153,9 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices,
llvm::function_ref<bool(Value, Value)> mayAlias) {
- Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
return !hasNestedParallelOp(firstPloop) &&
!hasNestedParallelOp(secondPloop) &&
- checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
+ equalIterationSpaces(firstPloop, secondPloop) &&
succeeded(verifyDependencies(firstPloop, secondPloop,
firstToSecondPloopIndices, mayAlias));
}
@@ -158,9 +174,61 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
mayAlias))
return;
- IRRewriter rewriter(builder);
- secondPloop = mlir::fuseIndependentSiblingParallelLoops(
- firstPloop, secondPloop, rewriter);
+ DominanceInfo dom;
+ // We are fusing first loop into second, make sure there are no users of the
+ // first loop results between loops.
+ for (Operation *user : firstPloop->getUsers())
+ if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+ return;
+
+ ValueRange inits1 = firstPloop.getInitVals();
+ ValueRange inits2 = secondPloop.getInitVals();
+
+ SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+ newInitVars.append(inits2.begin(), inits2.end());
+
+ IRRewriter b(builder);
+ b.setInsertionPoint(secondPloop);
+ auto newSecondPloop = b.create<ParallelOp>(
+ secondPloop.getLoc(), secondPloop.getLowerBound(),
+ secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+ Block *newBlock = newSecondPloop.getBody();
+ auto term1 = cast<ReduceOp>(block1->getTerminator());
+ auto term2 = cast<ReduceOp>(block2->getTerminator());
+
+ b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+ newBlock->getArguments());
+ b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+ newBlock->getArguments());
+
+ ValueRange results = newSecondPloop.getResults();
+ if (!results.empty()) {
+ b.setInsertionPointToEnd(newBlock);
+
+ ValueRange reduceArgs1 = term1.getOperands();
+ ValueRange reduceArgs2 = term2.getOperands();
+ SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+ newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+ auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+
+ for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+ term1.getReductions(), term2.getReductions()))) {
+ Block &oldRedBlock = reg.front();
+ Block &newRedBlock = newReduceOp.getReductions()[i].front();
+ b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+ newRedBlock.getArguments());
+ }
+
+ firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+ secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+ }
+ term1->erase();
+ term2->erase();
+ firstPloop.erase();
+ secondPloop.erase();
+ secondPloop = newSecondPloop;
}
void mlir::scf::naivelyFuseParallelOps(
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index abfc9a1b4d444..c0ee9d2afe91c 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@@ -1263,131 +1262,54 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
return tileLoops;
}
-//===----------------------------------------------------------------------===//
-// Fusion related helpers
-//===----------------------------------------------------------------------===//
-
-/// Check if `target` and `source` are siblings, in the context that `target`
-/// is being fused into `source`.
-///
-/// This is a simple check that just checks if both operations are in the same
-/// block and some checks to ensure that the fused IR does not violate
-/// dominance.
-static bool isOpSibling(Operation *target, Operation *source,
- Diagnostic &diag) {
- // Check if both operations are same.
- if (target == source) {
- diag << "target and source need to be different loops";
- return false;
- }
-
- // Check if both operations are in the same block.
- if (target->getBlock() != source->getBlock()) {
- diag << "target and source are not in the same block";
- return false;
- }
-
- // Check if fusion will violate dominance.
- DominanceInfo domInfo(source);
- if (target->isBeforeInBlock(source)) {
- // Since `target` is before `source`, all users of results of `target`
- // need to be dominated by `source`.
- for (Operation *user : target->getUsers()) {
- if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
- diag << "user of results of target should "
- "be properly dominated by source";
- return false;
- }
- }
- } else {
- // Since `target` is after `source`, all values used by `target` need
- // to dominate `source`.
-
- // Check if operands of `target` are dominated by `source`.
- for (Value operand : target->getOperands()) {
- Operation *operandOp = operand.getDefiningOp();
- // Operands without defining operations are block arguments. When `target`
- // and `source` occur in the same block, these operands dominate `source`.
- if (!operandOp)
- continue;
-
- // Operand's defining operation should properly dominate `source`.
- if (!domInfo.properlyDominates(operandOp, source,
- /*enclosingOpOk=*/false)) {
- diag << "operands of target should be properly dominated by source";
- return false;
- }
- }
-
- // Check if values used by `target` are dominated by `source`.
- bool failed = false;
- OpOperand *failedValue = nullptr;
- visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
- Operation *operandOp = operand->get().getDefiningOp();
- if (operandOp && !domInfo.properlyDominates(operandOp, source,
- /*enclosingOpOk=*/false)) {
- // `operand` is not an argument of an enclosing block and the defining
- // op of `operand` is outside `target` but does not dominate `source`.
- failed = true;
- failedValue = operand;
- }
- });
-
- if (failed) {
- diag << "values used inside regions of target should be properly "
- "dominated by source";
- diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation";
- return false;
- }
- }
-
- return true;
-}
-
-bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
- LoopLikeOpInterface source,
- Diagnostic &diag) {
- if (target->getName() != source->getName()) {
- diag << "target and source must be same loop type";
- return false;
- }
-
- bool iterSpaceEq =
- target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
- target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
- target.getLoopSteps() == source.getLoopSteps();
- // TODO: Decouple checks on concrete loop types and move this function
- // somewhere for general utility for `LoopLikeOpInterface`
- if (auto forAllTarget = dyn_cast<scf::ForallOp>(*target))
- iterSpaceEq = iterSpaceEq && forAllTarget.getMapping() ==
- cast<scf::ForallOp>(*source).getMapping();
- if (!iterSpaceEq) {
- diag << "target and source iteration spaces must be equal";
- return false;
- }
- return isOpSibling(target, source, diag);
-}
-
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForallOp source,
RewriterBase &rewriter) {
- scf::ForallOp fusedLoop = cast<scf::ForallOp>(createFused(
- target, source, rewriter,
- [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
- // `ForallOp` does not have yields, rather an `InParallelOp` terminator.
- return ValueRange{};
- },
- [&](RewriterBase &b, LoopLikeOpInterface source,
- LoopLikeOpInterface &target, IRMapping mapping) {
- auto sourceForall = cast<scf::ForallOp>(source);
- auto targetForall = cast<scf::ForallOp>(target);
- scf::InParallelOp fusedTerm = targetForall.getTerminator();
- b.setInsertionPointToEnd(fusedTerm.getBody());
- for (Operation &op : sourceForall.getTerminator().getYieldingOps())
- b.clone(op, mapping);
- }));
- rewriter.replaceOp(source,
- fusedLoop.getResults().take_back(source.getNumResults()));
+ unsigned numTargetOuts = target.getNumResults();
+ unsigned numSourceOuts = source.getNumResults();
+
+ // Create fused shared_outs.
+ SmallVector<Value> fusedOuts;
+ llvm::append_range(fusedOuts, target.getOutputs());
+ llvm::append_range(fusedOuts, source.getOutputs());
+
+ // Create a new scf.forall op after the source loop.
+ rewriter.setInsertionPointAfter(source);
+ scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
+ source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
+ source.getMixedStep(), fusedOuts, source.getMapping());
+
+ // Map control operands.
+ IRMapping mapping;
+ mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+ mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+
+ // Map shared outs.
+ mapping.map(target.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+ mapping.map(source.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+ // Append everything except the terminator into the fused operation.
+ rewriter.setInsertionPointToStart(fusedLoop.getBody());
+ for (Operation &op : target.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+ for (Operation &op : source.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+
+ // Fuse the old terminator in_parallel ops into the new one.
+ scf::InParallelOp targetTerm = target.getTerminator();
+ scf::InParallelOp sourceTerm = source.getTerminator();
+ scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
+ rewriter.setInsertionPointToStart(fusedTerm.getBody());
+ for (Operation &op : targetTerm.getYieldingOps())
+ rewriter.clone(op, mapping);
+ for (Operation &op : sourceTerm.getYieldingOps())
+ rewriter.clone(op, mapping);
+
+ // Replace old loops by substituting their uses by results of the fused loop.
+ rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+ rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
return fusedLoop;
}
@@ -1395,74 +1317,49 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
scf::ForOp source,
RewriterBase &rewriter) {
- scf::ForOp fusedLoop = cast<scf::ForOp>(createFused(
- target, source, rewriter,
- [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
- return source.getYieldedValues();
- },
- [&](RewriterBase &b, LoopLikeOpInterface source,
- LoopLikeOpInterface &target, IRMapping mapping) {
- auto targetFor = cast<scf::ForOp>(target);
- auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping);
- b.replaceOp(targetFor.getBody()->getTerminator(), newTerm);
- }));
- rewriter.replaceOp(source,
- fusedLoop.getResults().take_back(source.getNumResults()));
- return fusedLoop;
-}
-
-// TODO: Finish refactoring this a la the above, but likely requires additional
-// interface methods.
-scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops(
- scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) {
- OpBuilder::InsertionGuard guard(rewriter);
- Block *block1 = target.getBody();
- Block *block2 = source.getBody();
- auto term1 = cast<scf::ReduceOp>(block1->getTerminator());
- auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
-
- ValueRange inits1 = target.getInitVals();
- ValueRange inits2 = source.getInitVals();
-
- SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
- newInitVars.append(inits2.begin(), inits2.end());
-
- rewriter.setInsertionPoint(source);
- auto fusedLoop = rewriter.create<scf::ParallelOp>(
- rewriter.getFusedLoc(target.getLoc(), source.getLoc()),
- source.getLowerBound(), source.getUpperBound(), source.getStep(),
- newInitVars);
- Block *newBlock = fusedLoop.getBody();
- rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(),
- newBlock->getArguments());
- rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(),
- newBlock->getArguments());
-
- ValueRange results = fusedLoop.getResults();
- if (!results.empty()) {
- rewriter.setInsertionPointToEnd(newBlock);
-
- ValueRange reduceArgs1 = term1.getOperands();
- ValueRange reduceArgs2 = term2.getOperands();
- SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
- newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
-
- auto newReduceOp = rewriter.create<scf::ReduceOp>(
- rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs);
-
- for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
- term1.getReductions(), term2.getReductions()))) {
- Block &oldRedBlock = reg.front();
- Block &newRedBlock = newReduceOp.getReductions()[i].front();
- rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock,
- newRedBlock.begin(),
- newRedBlock.getArguments());
- }
- }
- rewriter.replaceOp(target, results.take_front(inits1.size()));
- rewriter.replaceOp(source, results.take_back(inits2.size()));
- rewriter.eraseOp(term1);
- rewriter.eraseOp(term2);
+ unsigned numTargetOuts = target.getNumResults();
+ unsigned numSourceOuts = source.getNumResults();
+
+ // Create fused init_args, with target's init_args before source's init_args.
+ SmallVector<Value> fusedInitArgs;
+ llvm::append_range(fusedInitArgs, target.getInitArgs());
+ llvm::append_range(fusedInitArgs, source.getInitArgs());
+
+ // Create a new scf.for op after the source loop (with scf.yield terminator
+ // (without arguments) only in case its init_args is empty).
+ rewriter.setInsertionPointAfter(source);
+ scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+ source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+ source.getStep(), fusedInitArgs);
+
+ // Map original induction variables and operands to those of the fused loop.
+ IRMapping mapping;
+ mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+ mapping.map(target.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+ mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+ mapping.map(source.getRegionIterArgs(),
+ fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+ // Merge target's body into the new (fused) for loop and then source's body.
+ rewriter.setInsertionPointToStart(fusedLoop.getBody());
+ for (Operation &op : target.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+ for (Operation &op : source.getBody()->without_terminator())
+ rewriter.clone(op, mapping);
+
+ // Build fused yield results by appropriately mapping original yield operands.
+ SmallVector<Value> yieldResults;
+ for (Value operand : target.getBody()->getTerminator()->getOperands())
+ yieldResults.push_back(mapping.lookupOrDefault(operand));
+ for (Value operand : source.getBody()->getTerminator()->getOperands())
+ yieldResults.push_back(mapping.lookupOrDefault(operand));
+ if (!yieldResults.empty())
+ rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
+
+ // Replace old loops by substituting their uses by results of the fused loop.
+ rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+ rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
return fusedLoop;
}
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 5a119a7cf2659..1e0e87b64e811 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -8,8 +8,6 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/DenseSet.h"
@@ -115,60 +113,3 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
return success();
}
-
-LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
- LoopLikeOpInterface source,
- RewriterBase &rewriter,
- NewYieldValuesFn newYieldValuesFn,
- FuseTerminatorFn fuseTerminatorFn) {
- auto targetIterArgs = target.getRegionIterArgs();
- std::optional<SmallVector<Value>> targetInductionVar =
- target.getLoopInductionVars();
- SmallVector<Value> targetYieldOperands(target.getYieldedValues());
- auto sourceIterArgs = source.getRegionIterArgs();
- std::optional<SmallVector<Value>> sourceInductionVar =
- *source.getLoopInductionVars();
- SmallVector<Value> sourceYieldOperands(source.getYieldedValues());
- auto sourceRegion = source.getLoopRegions().front();
-
- FailureOr<LoopLikeOpInterface> maybeFusedLoop =
- target.replaceWithAdditionalYields(rewriter, source.getInits(),
- /*replaceInitOperandUsesInLoop=*/false,
- newYieldValuesFn);
- if (failed(maybeFusedLoop))
- llvm_unreachable("failed to replace loop");
- LoopLikeOpInterface fusedLoop = *maybeFusedLoop;
- // Since the target op is rewritten at the original's location, we move it to
- // the soure op's location.
- rewriter.moveOpBefore(fusedLoop, source);
-
- // Map control operands.
- IRMapping mapping;
- std::optional<SmallVector<Value>> fusedInductionVar =
- fusedLoop.getLoopInductionVars();
- if (fusedInductionVar) {
- if (!targetInductionVar || !sourceInductionVar)
- llvm_unreachable(
- "expected target and source loops to have induction vars");
- mapping.map(*targetInductionVar, *fusedInductionVar);
- mapping.map(*sourceInductionVar, *fusedInductionVar);
- }
- mapping.map(targetIterArgs,
- fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size()));
- mapping.map(targetYieldOperands,
- fusedLoop.getYieldedValues().take_front(targetIterArgs.size()));
- mapping.map(sourceIterArgs,
- fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size()));
- mapping.map(sourceYieldOperands,
- fusedLoop.getYieldedValues().take_back(sourceIterArgs.size()));
- // Append everything except the terminator into the fused operation.
- rewriter.setInsertionPoint(
- fusedLoop.getLoopRegions().front()->front().getTerminator());
- for (Operation &op : sourceRegion->front().without_terminator())
- rewriter.clone(op, mapping);
-
- // TODO: Replace with corresponding interface method if added
- fuseTerminatorFn(rewriter, source, fusedLoop, mapping);
-
- return fusedLoop;
-}
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index f8246b74a5744..54dd2bdf953ca 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -47,169 +47,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @fuse_two_parallel
-// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
-func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
-// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
-// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
-// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
-// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
- %c2 = arith.constant 2 : index
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c1fp = arith.constant 1.0 : f32
-// CHECK: [[SUM:%.*]] = memref.alloc()
- %sum = memref.alloc() : memref<2x2xf32>
-// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
-// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
-// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
-// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
-// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
-// CHECK-NOT: scf.parallel
-// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
-// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
-// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
-// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
-// CHECK: scf.reduce
-// CHECK: }
- scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
- %sum_elem = arith.addf %B_elem, %c1fp : f32
- memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
- scf.reduce
- }
- scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
- %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
- %product_elem = arith.mulf %sum_elem, %A_elem : f32
- memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
- scf.reduce
- }
-// CHECK: memref.dealloc [[SUM]]
- memref.dealloc %sum : memref<2x2xf32>
- return
-}
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @fuse_two_parallel_reverse
-// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
-func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
-// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
-// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
-// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
-// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
- %c2 = arith.constant 2 : index
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c1fp = arith.constant 1.0 : f32
-// CHECK: [[SUM:%.*]] = memref.alloc()
- %sum = memref.alloc() : memref<2x2xf32>
-// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
-// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
-// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
-// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
-// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
-// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
-// CHECK-NOT: scf.parallel
-// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
-// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
-// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
-// CHECK: scf.reduce
-// CHECK: }
- scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
- %sum_elem = arith.addf %B_elem, %c1fp : f32
- memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
- scf.reduce
- }
- scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
- %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
- %product_elem = arith.mulf %sum_elem, %A_elem : f32
- memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
- scf.reduce
- }
-// CHECK: memref.dealloc [[SUM]]
- memref.dealloc %sum : memref<2x2xf32>
- return
-}
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @fuse_reductions_two
-// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
-func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
-// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
-// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
-// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
-// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
-// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
-// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
-// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
-// CHECK: scf.reduce.return %[[R]] : f32
-// CHECK: }
-// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
-// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
-// CHECK: scf.reduce.return %[[R]] : f32
-// CHECK: }
-// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
- %c2 = arith.constant 2 : index
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %init1 = arith.constant 1.0 : f32
- %init2 = arith.constant 2.0 : f32
- %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
- %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
- scf.reduce(%A_elem : f32) {
- ^bb0(%lhs: f32, %rhs: f32):
- %1 = arith.addf %lhs, %rhs : f32
- scf.reduce.return %1 : f32
- }
- }
- %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
- %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
- scf.reduce(%B_elem : f32) {
- ^bb0(%lhs: f32, %rhs: f32):
- %1 = arith.mulf %lhs, %rhs : f32
- scf.reduce.return %1 : f32
- }
- }
- return %res1, %res2 : f32, f32
-}
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
// CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
@@ -371,62 +208,6 @@ module attributes {transform.with_named_sequence} {
}
}
-
-// -----
-
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 32)
-#map = affine_map<(d0) -> (d0 * 32)>
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-module {
- // CHECK: func.func @loop_sibling_fusion(%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}
- func.func @loop_sibling_fusion(%arg0: tensor<128xf32>, %arg1: tensor<128x128xf16>, %arg2: tensor<128x64xf32>, %arg3: tensor<128x128xf32>) -> (tensor<128xf32>, tensor<128x128xf16>) {
- // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<128x128xf16>
- // CHECK-NEXT: %[[RESULTS:.*]]:2 = scf.forall (%[[I:.*]]) in (4) shared_outs(%[[S1:.*]] = %[[ARG0]], %[[S2:.*]] = %[[ARG1]]) -> (tensor<128xf32>, tensor<128x128xf16>) {
- // CHECK-NEXT: %[[IDX:.*]] = affine.apply #[[$MAP]](%[[I]])
- // CHECK-NEXT: %[[SLICE0:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
- // CHECK-NEXT: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
- // CHECK-NEXT: %[[SLICE2:.*]] = tensor.extract_slice %[[EMPTY]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
- // CHECK-NEXT: %[[GENERIC:.*]] = linalg.generic {{.*}} ins(%[[SLICE1]] : {{.*}}) outs(%[[SLICE2]] : {{.*}})
- // CHECK: scf.forall.in_parallel {
- // CHECK-NEXT: tensor.parallel_insert_slice %[[SLICE0]] into %[[S1]][%[[IDX]]] [32] [1] : tensor<32xf32> into tensor<128xf32>
- // CHECK-NEXT: tensor.parallel_insert_slice %[[GENERIC]] into %[[S2]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
- // CHECK-NEXT: }
- // CHECK-NEXT: } {mapping = [#gpu.warp<linear_dim_0>]}
- // CHECK-NEXT: return %[[RESULTS]]#0, %[[RESULTS]]#1
- %0 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg0) -> (tensor<128xf32>) {
- %3 = affine.apply #map(%arg4)
- %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [32] [1] : tensor<32xf32> into tensor<128xf32>
- }
- } {mapping = [#gpu.warp<linear_dim_0>]}
- %1 = tensor.empty() : tensor<128x128xf16>
- %2 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg1) -> (tensor<128x128xf16>) {
- %3 = affine.apply #map(%arg4)
- %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
- %extracted_slice_0 = tensor.extract_slice %1[%3, 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
- %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<32x128xf32>) outs(%extracted_slice_0 : tensor<32x128xf16>) {
- ^bb0(%in: f32, %out: f16):
- %5 = arith.truncf %in : f32 to f16
- linalg.yield %5 : f16
- } -> tensor<32x128xf16>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %4 into %arg5[%3, 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
- }
- } {mapping = [#gpu.warp<linear_dim_0>]}
- return %0, %2 : tensor<128xf32>, tensor<128x128xf16>
- }
-}
-
-module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%root: !transform.any_op) {
- %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
- %loop1, %loop2 = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %loop3 = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
// -----
func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
@@ -501,9 +282,8 @@ func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>,
%6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
scf.yield %6 : tensor<128xf32>
}
- // expected-error @below {{values used inside regions of target should be properly dominated by source}}
%dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
- // expected-note @below {{see operation}}
+ // expected-error @below {{values used inside regions of target should be properly dominated by source}}
%dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
%dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
%dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
@@ -548,74 +328,6 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-
-// -----
-
-func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
- %c2 = arith.constant 2 : index
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c1fp = arith.constant 1.0 : f32
- %sum = memref.alloc() : memref<2x2xf32>
- // expected-error @below {{target and source iteration spaces must be equal}}
- scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
- %B_elem = memref.load %B[%i, %c0] : memref<2x2xf32>
- %sum_elem = arith.addf %B_elem, %c1fp : f32
- memref.store %sum_elem, %sum[%i, %c0] : memref<2x2xf32>
- scf.reduce
- }
- scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
- %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
- %product_elem = arith.mulf %sum_elem, %A_elem : f32
- memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
- scf.reduce
- }
- memref.dealloc %sum : memref<2x2xf32>
- return
-}
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-func.func @non_matching_loop_types_err(%A: memref<2xf32>, %B: memref<2xf32>) {
- %c2 = arith.constant 2 : index
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c1fp = arith.constant 1.0 : f32
- %sum = memref.alloc() : memref<2xf32>
- // expected-error @below {{target and source must be same loop type}}
- scf.for %i = %c0 to %c2 step %c1 {
- %B_elem = memref.load %B[%i] : memref<2xf32>
- %sum_elem = arith.addf %B_elem, %c1fp : f32
- memref.store %sum_elem, %sum[%i] : memref<2xf32>
- }
- scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
- %sum_elem = memref.load %sum[%i] : memref<2xf32>
- %A_elem = memref.load %A[%i] : memref<2xf32>
- %product_elem = arith.mulf %sum_elem, %A_elem : f32
- memref.store %product_elem, %B[%i] : memref<2xf32>
- scf.reduce
- }
- memref.dealloc %sum : memref<2xf32>
- return
-}
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %fused = transform.loop.fuse_sibling %0 into %1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
// -----
// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
More information about the Mlir-commits mailing list