[Mlir-commits] [mlir] [Mlir] decompose generic by unfolding projected permutation crash fix (PR #122449)

llvmlistbot at llvm.org
Sat Jan 11 04:39:54 PST 2025


================
@@ -159,6 +160,16 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     auto map = op.getMatchingIndexingMap(&opOperand);
     if (!map.isProjectedPermutation(false))
       return failure();
+
+    // If we have any inputs that aren't memref or ranked tensor types, reject
+    // the pattern.
+    if (!dyn_cast<ShapedType>(opOperand.get().getType()))
----------------
GrumpyPigSkin wrote:

Yes, it matches 0-D shapes :) Here is the debug log for a rank-0 `tensor<i32>` input:

```mlir
//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x55556870a890) {

  * Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
  } -> failure : pattern failed to match

  * Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @test_broadcast_single_tensor() -> tensor<2x2xi32> {
  %cst = arith.constant dense<2> : tensor<2x2xi32>
  %cst_0 = arith.constant dense<42> : tensor<i32>
  %0 = tensor.empty() : tensor<2x2xi32>
  %1 = tensor.empty() : tensor<2x2xi32>
  %broadcasted = linalg.broadcast ins(%cst_0 : tensor<i32>) outs(%1 : tensor<2x2xi32>) dimensions = [0, 1] 
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst, %broadcasted : tensor<2x2xi32>, tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
  ^bb0(%in: i32, %in_1: i32, %out: i32):
    %3 = arith.addi %in, %in_1 : i32
    linalg.yield %3 : i32
  } -> tensor<2x2xi32>
  return %2 : tensor<2x2xi32>
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x555568783370) {
  %6 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32

} -> failure : pattern failed to match
```
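
To spell out why the rank-0 operand gets past the new guard: `tensor<i32>` is still a ranked tensor (with an empty shape), so the `dyn_cast<ShapedType>` check succeeds, whereas a bare `i32` operand would be rejected. Below is a minimal standalone sketch of that distinction. It does not use MLIR; `ScalarType`, `RankedTensor`, `asShaped`, `guardAccepts`, and `checkExamples` are hypothetical stand-ins invented for illustration only.

```cpp
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

// Hypothetical stand-ins, NOT MLIR classes: a bare scalar such as `i32`,
// and a ranked tensor whose shape may be empty (rank 0, e.g. `tensor<i32>`).
struct ScalarType {};
struct RankedTensor {
  std::vector<int64_t> shape;
};
using Type = std::variant<ScalarType, RankedTensor>;

// Plays the role of `dyn_cast<ShapedType>(type)`: a non-null result means
// the type is shaped, regardless of its rank.
const RankedTensor *asShaped(const Type &type) {
  return std::get_if<RankedTensor>(&type);
}

// Mirrors the guard in the diff: only shaped operands get past it.
bool guardAccepts(const Type &type) { return asShaped(type) != nullptr; }

// Exercises the three operand kinds relevant to this thread.
void checkExamples() {
  Type zeroD = RankedTensor{{}};     // models tensor<i32>
  Type twoD = RankedTensor{{2, 2}};  // models tensor<2x2xi32>
  Type scalar = ScalarType{};        // models a bare i32 operand

  assert(guardAccepts(zeroD));             // rank 0 is still shaped
  assert(asShaped(zeroD)->shape.empty());  // and its shape is empty
  assert(guardAccepts(twoD));
  assert(!guardAccepts(scalar));           // the case the guard rejects
}
```

Calling `checkExamples()` from a `main` function runs the checks. In the real pattern the accepted set is wider than this toy model, since anything implementing `ShapedType` passes the cast.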

On a side note, once `DecomposeProjectedPermutation` has matched, no other pattern matches afterwards. Is this dependent on the pass being applied, or could it indicate that something has gone wrong?

https://github.com/llvm/llvm-project/pull/122449

