[Mlir-commits] [mlir] [mlir][Linalg] Bugfix in decompose generic by unfolding permutation (PR #126737)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 11 06:30:14 PST 2025


https://github.com/gdehame created https://github.com/llvm/llvm-project/pull/126737

The pattern was returning success() by default, which made the greedy pattern application act as if the IR had been modified even though nothing had changed; this can prevent it from converging for no legitimate reason.

The patch makes the rewrite pattern return failure() by default and success() if and only if the IR changed.

An example of the unexpected behavior can be seen by running `mlir-opt input.mlir --linalg-specialize-generic-ops` with `input.mlir` as follows:
```
#map = affine_map<(d0) -> (d0)>
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1: tensor<8xi32>, tensor<8xi32>) outs(%0: tensor<8xi32>) {
    ^bb0(%in: i32, %in_0: i32, %out: i32):
      %2 = arith.addi %in, %in_0: i32
      linalg.yield %2: i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}
```

If we add the `--debug` option, we see that `DecomposeProjectedPermutation` returns success() even though the IR doesn't change, so the pattern rewrite doesn't converge.

Output of `mlir-opt input.mlir --linalg-specialize-generic-ops --debug`:
```
Args: mlir-opt input.mlir --linalg-specialize-generic-ops --debug 
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FloatType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemRefLayoutAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BlobAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TypedAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ElementsAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DistinctAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ResourceBlobManagerDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineBinaryOpExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineConstantExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineDimExprStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::AffineMapStorage)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::IntegerSetStorage)
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CastOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestructurableTypeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::LLVMTranslationDialectInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneRegion<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroResults<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroSuccessors<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoRegionArguments<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NoTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SingleBlock<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OpInvariants<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AffineScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsIsolatedFromAbove<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SymbolTable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpAsmOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionKindInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasOnlyGraphRegion<Empty>)
Load new dialect in Context func
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolUserOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchTerminatorOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectInlinerInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConvertToLLVMPatternInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::bufferization::BufferizableOpInterface)
Load new dialect in Context cf
Load new dialect in Context arith
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::ArithFastMathInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::VectorUnrollOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferIntRangeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::ArithIntegerOverflowFlagsInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::ArithRoundingModeInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SelectLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::bufferization::BufferDeallocationOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ValueBoundsOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::bufferization::BufferViewFlowOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::mesh::ShardingInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AutomaticAllocationScope<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CallableOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::FunctionOpInterface::Trait<Empty>)
Load new dialect in Context tensor
Load new dialect in Context affine
Load new dialect in Context ub
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ub::PoisonAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineDmaStartOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineMapAccessInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineDmaWaitOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::LoopLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineReadOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::affine::AffineWriteOpInterface)
Load new dialect in Context complex
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ReifyRankedShapedTypeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedDimOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OffsetSizeAndStrideOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestinationStyleOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::tensor::RelayoutOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::transform::FindPayloadReplacementOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubsetOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubsetInsertionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SubsetExtractionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TilingInterface)
Load new dialect in Context linalg
Load new dialect in Context math
Load new dialect in Context memref
ImplicitTypeIDRegistry::lookupOrInsert(mlir::CopyOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PromotableMemOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestructurableAccessorOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PromotableAllocationOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestructurableAllocationOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ViewLikeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::bufferization::AllocationOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RuntimeVerifiableOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::AggregatedOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::LinalgOp)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ContractionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::ConvolutionOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::FillOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::PartialReductionOpInterface)
Load new dialect in Context scf
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ParallelCombiningOpInterface)
Ignoring repeated interface registration
Ignoring repeated interface registration
Ignoring repeated interface registration
Load new dialect in Context index
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ZeroRegions<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneResult<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::RankedTensorType>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::VariadicOperands<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ConditionallySpeculatable::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AlwaysSpeculatableImplTrait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemoryEffectOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ReifyRankedShapedTypeOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::VariadicResults<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SingleBlockImplicitTerminator<mlir::linalg::YieldOp>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::AttrSizedOperandSegments<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasRecursiveMemoryEffects<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DestinationStyleOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::linalg::LinalgOp::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::Type>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::NOperands<2>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsCommutative<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferIntRangeInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::arith::ArithIntegerOverflowFlagsInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::SameOperandsAndResultType<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::VectorUnrollOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::Elementwise<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::Scalarizable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::Vectorizable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::Tensorizable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferTypeOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::RegionBranchTerminatorOpInterface::Trait<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ReturnLike<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::IsTerminator<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::HasParent<mlir::func::FuncOp>::Impl<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::MemRefsNormalizable<Empty>)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::OpToOpPassAdaptor)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DialectFoldInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::ConstantLike<Empty>)

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.return'(0x5ecbe5e16090) {
  "func.return"(%1) : (tensor<8xi32>) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.yield'(0x5ecbe5e17630) {
  "linalg.yield"(%2) : (i32) -> ()

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x5ecbe5d9b560) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
  %0 = tensor.empty() : tensor<8xi32>
  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %2 = arith.addi %in, %in_0 : i32
    linalg.yield %2 : i32
  } -> tensor<8xi32>
  return %1 : tensor<8xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x5ecbe5e170b0) {
  %2 = "arith.addi"(%arg2, %arg3) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'func.func'(0x5ecbe5dc0ca0) {
} -> failure : pattern failed to match
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'tensor.empty'(0x5ecbe5dffa60) {
  %0 = "tensor.empty"() : () -> tensor<8xi32>

} -> failure : pattern failed to match
//===-------------------------------------------===//
The pattern rewrite did not converge after scanning 10 times
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)

```

>From 6bb4f94ecc5c6ac3849197c8650fada7ad61c537 Mon Sep 17 00:00:00 2001
From: gdehame <gabrieldehame at gmail.com>
Date: Tue, 11 Feb 2025 15:24:47 +0100
Subject: [PATCH] [mlir][Linalg] Bugfix in decompose generic by unfolding
 permutation

The pattern was returning success() by default, which made the greedy pattern application act as if the IR had been modified even though nothing had changed; this can prevent the rewrite from converging for no legitimate reason.

The patch makes the rewrite pattern return failure() by default, and success() if and only if the IR was actually modified.
---
 .../Transforms/DecomposeGenericByUnfoldingPermutation.cpp      | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index 83c4b5bdf109765..281a248681792ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -237,8 +237,9 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
 
     newOp.getRegion().takeBody(op->getRegion(0));
     rewriter.replaceOp(op, newOp->getResults());
+    return success();
   }
-  return success();
+  return failure();
 }
 
 } // namespace



More information about the Mlir-commits mailing list