[Mlir-commits] [mlir] [Mlir] decompose generic by unfolding projected permutation crash fix (PR #122449)
llvmlistbot at llvm.org
Sat Jan 11 04:39:54 PST 2025
================
@@ -159,6 +160,16 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
auto map = op.getMatchingIndexingMap(&opOperand);
if (!map.isProjectedPermutation(false))
return failure();
+
+ // If we have any inputs that aren't memref or ranked tensor types, reject
+ // the pattern.
+ if (!dyn_cast<ShapedType>(opOperand.get().getType()))
----------------
GrumpyPigSkin wrote:
Yes, it matches 0-D shapes :)
```mlir
//===-------------------------------------------===//
Processing operation : 'linalg.generic'(0x55556870a890) {
* Pattern mlir::linalg::LinalgSpecializationPattern : 'linalg.generic -> ()' {
Trying to match "mlir::linalg::LinalgSpecializationPattern"
"mlir::linalg::LinalgSpecializationPattern" result 0
} -> failure : pattern failed to match
* Pattern (anonymous namespace)::DecomposeProjectedPermutation : 'linalg.generic -> ()' {
Trying to match "(anonymous namespace)::DecomposeProjectedPermutation"
"(anonymous namespace)::DecomposeProjectedPermutation" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
func.func @test_broadcast_single_tensor() -> tensor<2x2xi32> {
%cst = arith.constant dense<2> : tensor<2x2xi32>
%cst_0 = arith.constant dense<42> : tensor<i32>
%0 = tensor.empty() : tensor<2x2xi32>
%1 = tensor.empty() : tensor<2x2xi32>
%broadcasted = linalg.broadcast ins(%cst_0 : tensor<i32>) outs(%1 : tensor<2x2xi32>) dimensions = [0, 1]
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst, %broadcasted : tensor<2x2xi32>, tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
^bb0(%in: i32, %in_1: i32, %out: i32):
%3 = arith.addi %in, %in_1 : i32
linalg.yield %3 : i32
} -> tensor<2x2xi32>
return %2 : tensor<2x2xi32>
}
} -> success : pattern matched
//===-------------------------------------------===//
//===-------------------------------------------===//
Processing operation : 'arith.addi'(0x555568783370) {
%6 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
} -> failure : pattern failed to match
```
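For context, an input along the following lines would presumably produce the dump above. This is a reconstruction from the rewritten IR, not the test file from the PR: the alias names `#map` and `#map_scalar` are made up for illustration, and the `(d0, d1) -> ()` indexing map for the 0-D operand is an assumption inferred from the `linalg.broadcast ... dimensions = [0, 1]` in the output.

```mlir
// Hypothetical reconstruction of the input; alias names are illustrative only.
#map = affine_map<(d0, d1) -> (d0, d1)>
// Assumed map for the 0-D operand: it has no results, so every loop
// dimension is a broadcast dimension.
#map_scalar = affine_map<(d0, d1) -> ()>

func.func @test_broadcast_single_tensor() -> tensor<2x2xi32> {
  %cst = arith.constant dense<2> : tensor<2x2xi32>
  %cst_0 = arith.constant dense<42> : tensor<i32>
  %0 = tensor.empty() : tensor<2x2xi32>
  %1 = linalg.generic {indexing_maps = [#map, #map_scalar, #map], iterator_types = ["parallel", "parallel"]} ins(%cst, %cst_0 : tensor<2x2xi32>, tensor<i32>) outs(%0 : tensor<2x2xi32>) {
  ^bb0(%in: i32, %in_1: i32, %out: i32):
    %2 = arith.addi %in, %in_1 : i32
    linalg.yield %2 : i32
  } -> tensor<2x2xi32>
  return %1 : tensor<2x2xi32>
}
```

Here `tensor<i32>` is a `ShapedType` of rank 0, so the new check does not reject it and the pattern rewrites the operand into a `linalg.broadcast`.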
On a side note, after `DecomposeProjectedPermutation` matches, no other pattern matches. Is this dependent on the pass being applied, or could it indicate that something has gone wrong?
https://github.com/llvm/llvm-project/pull/122449