[all-commits] [llvm/llvm-project] 96e9b6: Revert "[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse)."

Han-Chung Wang via All-commits all-commits at lists.llvm.org
Tue Apr 5 15:06:12 PDT 2022


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: 96e9b6c9dc60946f08399def879a19395bc98107
      https://github.com/llvm/llvm-project/commit/96e9b6c9dc60946f08399def879a19395bc98107
  Author: Hanhan Wang <hanchung at google.com>
  Date:   2022-04-05 (Tue, 05 Apr 2022)

  Changed paths:
    M mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    M mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    M mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    M mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    M mlir/test/Dialect/MemRef/canonicalize.mlir
    M mlir/test/Dialect/Tensor/canonicalize.mlir

  Log Message:
  -----------
  Revert "[mlir] Rewrite canonicalization of collapse(expand) and expand(collapse)."

This reverts commit 64f659bee67b5a024defeb3cd2ecf65e1ad8c0a7.

With that commit applied, canonicalization generates an invalid tensor.expand_shape op. To reproduce, run the following command on the input below (saved as a.mlir):

$ mlir-opt -canonicalize a.mlir

```
// Reduced repro for D123161 (revert of 64f659bee67b). Running
// `mlir-opt -canonicalize` on this function with that commit applied
// produces an invalid tensor.expand_shape op.
//
// The function builds two zero-filled 8x1 tensors and inserts the 0-d
// collapsed value of %0 and %2 into them. It reshapes operands to 4-D for
// linalg.mmt4d, then reshapes and slices the 1x1x8x1 result back down to
// tensor<1x1xf32>.
// NOTE(review): the report does not say which expand_shape becomes invalid.
// Presumably it comes from an expand(collapse) / collapse(expand) pair that
// the reverted pattern sees once the identity linalg.generic copies are
// removed (e.g. %24 -> %26 -> %27). Confirm against the original failure
// before relying on this.
func @foo(%0: tensor<1x1xf32>, %1: tensor<1x1xf32>, %2: tensor<1x1xf32>) -> tensor<1x1xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // Zero-filled 8x1 tensor with the 0-d value of %0 inserted at [0, 0].
  // The collapse_shape with an empty reassociation reduces 1x1 to rank 0.
  %3 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %5 = tensor.collapse_shape %0 [] : tensor<1x1xf32> into tensor<f32>
  %6 = tensor.insert_slice %5 into %4[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  // Same construction for %2.
  %7 = linalg.init_tensor [8, 1] : tensor<8x1xf32>
  %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8x1xf32>) -> tensor<8x1xf32>
  %9 = tensor.collapse_shape %2 [] : tensor<1x1xf32> into tensor<f32>
  %10 = tensor.insert_slice %9 into %8[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<8x1xf32>
  // Flatten %6 to 8 elements and copy it through a linalg.generic whose body
  // yields its input unchanged. Then expand it to 1x1x8x1, the first input
  // of linalg.mmt4d.
  %11 = tensor.collapse_shape %6 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %12 = linalg.init_tensor [8] : tensor<8xf32>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%11 : tensor<8xf32>) outs(%12 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %14 = tensor.expand_shape %13 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  // Collapse %1 to 0-d, copy it the same way, and expand it to 1x1x1x1, the
  // second input of linalg.mmt4d.
  %15 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
  %16 = linalg.init_tensor [] : tensor<f32>
  %17 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%15 : tensor<f32>) outs(%16 : tensor<f32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<f32>
  %18 = tensor.expand_shape %17 [] : tensor<f32> into tensor<1x1x1x1xf32>
  // Same flatten / copy / expand sequence for %10. The result is the outs
  // (init) operand of linalg.mmt4d.
  %19 = tensor.collapse_shape %10 [[0, 1]] : tensor<8x1xf32> into tensor<8xf32>
  %20 = linalg.init_tensor [8] : tensor<8xf32>
  %21 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%19 : tensor<8xf32>) outs(%20 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %22 = tensor.expand_shape %21 [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x8x1xf32>
  %23 = linalg.mmt4d {comment = "f32*f32->f32, aarch64, matrix*vector"} ins(%14, %18 : tensor<1x1x8x1xf32>, tensor<1x1x1x1xf32>) outs(%22 : tensor<1x1x8x1xf32>) -> tensor<1x1x8x1xf32>
  // Collapse the mmt4d result to 8 elements, copy it, and expand it to 8x1.
  // Then take the [0, 0] element as a 0-d tensor and expand that to the 1x1
  // return type.
  %24 = tensor.collapse_shape %23 [[0, 1, 2, 3]] : tensor<1x1x8x1xf32> into tensor<8xf32>
  %25 = linalg.init_tensor [8] : tensor<8xf32>
  %26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<8xf32>) outs(%25 : tensor<8xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    linalg.yield %arg3 : f32
  } -> tensor<8xf32>
  %27 = tensor.expand_shape %26 [[0, 1]] : tensor<8xf32> into tensor<8x1xf32>
  %28 = tensor.extract_slice %27[0, 0] [1, 1] [1, 1] : tensor<8x1xf32> to tensor<f32>
  %29 = tensor.expand_shape %28 [] : tensor<f32> into tensor<1x1xf32>
  return %29 : tensor<1x1xf32>
}
```

Differential Revision: https://reviews.llvm.org/D123161




More information about the All-commits mailing list