[Mlir-commits] [mlir] [mlir] [linalg] Check for dim shape to decide unit dim for each operand in dropUnitDims pass. (PR #91673)

Sayan Saha llvmlistbot at llvm.org
Fri May 10 05:20:26 PDT 2024


================
@@ -1087,3 +1087,47 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 //       CHECK:   } : tensor<383x128xf32> to tensor<384x128xf32>
 //       CHECK:   tensor.expand_shape %[[PADDED]]
 //  CHECK-SAME:     {{\[}}[0, 1], [2]] output_shape [1, 384, 128] : tensor<384x128xf32> into tensor<1x384x128xf32>
+
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
+
+// CHECK-LABEL: func @drop_unit_dim_corresponding_to_dynamic_dim
+// CHECK-SAME:                    %[[ARG0:.*]]: tensor<1x?x?x1xf32>,
+// CHECK-SAME:                    %[[ARG1:.*]]: index) -> tensor<?x1x61x1xf32> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<1.000000e+00> : tensor<f32>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] : tensor<1x?x?x1xf32> into tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
+// CHECK:           %[[VAL_5:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[ARG1]], %[[VAL_1]]]
+// CHECK:           %[[VAL_6:.*]] = tensor.empty(%[[VAL_5]]) : tensor<?x61xf32>
+// CHECK:           %[[VAL_7:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[VAL_3]], %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>, tensor<f32>, tensor<?x61xf32>) outs(%[[VAL_6]] : tensor<?x61xf32>) {
----------------
sahas3 wrote:

Thanks for the review. If I understand correctly, you are suggesting that the verifier should reject the `linalg.generic` op because `d0` is inferred to have different sizes for different operands (a static `1` from the input, dynamic from the `outs` operand)?
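To make the question concrete, here is a minimal Python model of such a stricter check. It is a hypothetical illustration, not the actual linalg verifier. Indexing maps are written as lists of strings, `None` stands for a dynamic `?` extent, and the maps and shapes are copied from the generalized convolution quoted further down in this comment.

```python
import re
from typing import Dict, List, Optional, Sequence

# Hypothetical model: an indexing map is a list of result expressions written as
# strings ("d0", "d1 + d4", ...); a shape uses None for a dynamic (`?`) extent.
DYN = None
BARE_DIM = re.compile(r"d\d+")


def inferred_sizes(
    maps: Sequence[Sequence[str]], shapes: Sequence[Sequence[Optional[int]]]
) -> Dict[str, List[Optional[int]]]:
    """For each loop dim used as a bare map result, collect the extent each operand implies."""
    sizes: Dict[str, List[Optional[int]]] = {}
    for results, shape in zip(maps, shapes):
        for expr, extent in zip(results, shape):
            # Compound results such as "d1 + d4" do not pin a single loop dim; skip them.
            if BARE_DIM.fullmatch(expr):
                sizes.setdefault(expr, []).append(extent)
    return sizes


def inconsistent_dims(maps, shapes) -> List[str]:
    """Loop dims whose implied extents disagree; `?` vs. a static size counts as a mismatch."""
    return sorted(d for d, ext in inferred_sizes(maps, shapes).items() if len(set(ext)) > 1)


# Indexing maps and operand types of the generalized convolution (input, filter, output).
CONV_MAPS = [
    ["d0", "d1 + d4", "d2 + d5", "d6"],
    ["d4", "d5", "d6", "d3"],
    ["d0", "d1", "d2", "d3"],
]
CONV_SHAPES = [[1, DYN, DYN, 1], [1, 1, 1, 1], [DYN, 1, 61, 1]]

print(inconsistent_dims(CONV_MAPS, CONV_SHAPES))  # ['d0']
```

Refining the output type to `tensor<1x1x61x1xf32>` makes the same call return an empty list. Compound results such as `d1 + d4` are deliberately ignored, since they do not pin a single loop dimension.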

The op is produced by the following flow (for brevity, I am only showing small snippets with the relevant ops):

```
%inserted_slice_37 = tensor.insert_slice %expanded_18 into %15[%c0_25, %c0_28, %c0_30, %c0_31] [1, 1, 61, 1] [1, 1, 1, 1] : tensor<1x1x61x1xf32> into tensor<?x?x?x?xf32>
%18 = linalg.conv_2d_nhwc_hwcf ins(%inserted_slice_37, %cst_2 : tensor<?x?x?x?xf32>, tensor<1x1x1x1xf32>) outs(%17 : tensor<?x1x61x1xf32>) -> tensor<?x1x61x1xf32>
%reduced = linalg.reduce ins(%cst_7 : tensor<3xi32>) outs(%cst_6 : tensor<i32>) dimensions = [0]
  (%in: i32, %init: i32) {
    %21 = arith.muli %in, %init : i32
    linalg.yield %21 : i32
  }
```

This becomes the following after `--canonicalize`:

```
%inserted_slice_13 = tensor.insert_slice %expanded_10 into %7[0, 0, 0, 0] [1, 1, 61, 1] [1, 1, 1, 1] : tensor<1x1x61x1xf32> into tensor<1x?x?x1xf32>
%10 = linalg.conv_2d_nhwc_hwcf ins(%inserted_slice_13, %cst_2 : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%9 : tensor<?x1x61x1xf32>) -> tensor<?x1x61x1xf32>
%reduced = linalg.reduce ins(%cst_6 : tensor<3xi32>) outs(%cst_5 : tensor<i32>) dimensions = [0]
  (%in: i32, %init: i32) {
    %13 = arith.muli %in, %init : i32
    linalg.yield %13 : i32
  }
```

which, after `--linalg-generalize-named-ops`, becomes:

```
%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%inserted_slice_13, %cst_2 : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%7 : tensor<?x1x61x1xf32>) {
  ^bb0(%in: f32, %in_16: f32, %out: f32):
    %12 = arith.mulf %in, %in_16 : f32
    %13 = arith.addf %out, %12 : f32
    linalg.yield %13 : f32
  } -> tensor<?x1x61x1xf32>
```

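This produces the `linalg.generic` op used in the repro. Note that canonicalization refined the `ins` operand type from `tensor<?x?x?x?xf32>` to `tensor<1x?x?x1xf32>` but left the `outs` type as `tensor<?x1x61x1xf32>`, so `d0` is a static `1` for the input and dynamic for the output. So I think that, in addition to the verifier change, the canonicalizer should be enhanced to update the type of the `outs` operand as well, as it already does for the `ins` operand.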
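Separately from the canonicalizer question, the change in this PR makes `dropUnitDims` look at each operand's own shape before treating a dimension as unit. Below is a minimal Python model of that idea. It is hypothetical, not the pass's actual code, and the rule that a loop dimension's size comes from the first operand using it as a bare result is an assumption made for illustration.

```python
import re
from typing import Dict, List, Optional, Set

DYN = None  # stands in for a dynamic (`?`) extent
BARE_DIM = re.compile(r"d\d+")


def first_static_loop_sizes(maps, shapes) -> Dict[str, Optional[int]]:
    """Assumed rule: a loop dim's size is taken from the first operand using it as a bare result."""
    sizes: Dict[str, Optional[int]] = {}
    for results, shape in zip(maps, shapes):
        for expr, extent in zip(results, shape):
            if BARE_DIM.fullmatch(expr) and expr not in sizes:
                sizes[expr] = extent
    return sizes


def droppable_positions(maps, shapes, check_operand_shape: bool) -> List[List[int]]:
    """Per operand, the result positions treated as unit dims and therefore dropped."""
    unit_loops: Set[str] = {
        d for d, s in first_static_loop_sizes(maps, shapes).items() if s == 1
    }
    result = []
    for results, shape in zip(maps, shapes):
        positions = []
        for pos, (expr, extent) in enumerate(zip(results, shape)):
            if expr not in unit_loops:
                continue  # compound expression or a non-unit loop dim
            # Per-operand check: drop only if this operand's own extent is statically 1.
            if check_operand_shape and extent != 1:
                continue
            positions.append(pos)
        result.append(positions)
    return result


# Generalized convolution operands: input, filter, output.
CONV_MAPS = [
    ["d0", "d1 + d4", "d2 + d5", "d6"],
    ["d4", "d5", "d6", "d3"],
    ["d0", "d1", "d2", "d3"],
]
CONV_SHAPES = [[1, DYN, DYN, 1], [1, 1, 1, 1], [DYN, 1, 61, 1]]

print(droppable_positions(CONV_MAPS, CONV_SHAPES, check_operand_shape=False))
# [[0, 3], [0, 1, 2, 3], [0, 1, 3]]
print(droppable_positions(CONV_MAPS, CONV_SHAPES, check_operand_shape=True))
# [[0, 3], [0, 1, 2, 3], [1, 3]]
```

With the per-operand check, the output keeps dims 0 and 2 (`tensor<?x61xf32>`) and the input keeps dims 1 and 2 (`tensor<?x?xf32>`), which lines up with the shapes in the CHECK lines of the new test. Without it, the dynamic dim 0 of the output would be dropped as if it were 1.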

https://github.com/llvm/llvm-project/pull/91673

