[all-commits] [llvm/llvm-project] 38d0b2: [mlir] New canonicalization patterns for shape.sha...
Rafael Ubal via All-commits
all-commits at lists.llvm.org
Fri Jul 19 07:09:53 PDT 2024
Branch: refs/heads/main
Home: https://github.com/llvm/llvm-project
Commit: 38d0b2d174efe05504a18988299b4d78d37999b7
https://github.com/llvm/llvm-project/commit/38d0b2d174efe05504a18988299b4d78d37999b7
Author: Rafael Ubal <rubal at mathworks.com>
Date: 2024-07-19 (Fri, 19 Jul 2024)
Changed paths:
M mlir/lib/Dialect/Shape/IR/Shape.cpp
M mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
M mlir/test/Dialect/Shape/canonicalize.mlir
A mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir
M mlir/test/Dialect/Tensor/canonicalize.mlir
Log Message:
-----------
[mlir] New canonicalization patterns for shape.shape_of and tensor.reshape (#98531)
This PR includes 3 new canonicalization patterns:
- Operation `shape.shape_of`: shape of reshape
```
// Before
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %0 = shape.shape_of %reshape : tensor<*xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}

// After
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  return %arg1 : tensor<?xindex>
}
```
- Operation `tensor.reshape`: reshape of reshape
```
// Before
func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> {
  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %1 = tensor.reshape %0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

// After
func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> {
  %reshape = tensor.reshape %arg0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  return %reshape : tensor<*xf32>
}
```
- Operation `tensor.reshape`: reshape 1D to 1D
```
// Before
func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
  %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// After
func.func @fold_reshape_1d(%arg0: tensor<?xf32>, %arg1: tensor<1xindex>) -> tensor<?xf32> {
  return %arg0 : tensor<?xf32>
}
```
These three canonicalization patterns cooperate to simplify the IR
structure emerging from the lowering of certain element-wise ops with
unranked tensor inputs. See file `unranked-tensor-lowering.mlir` in the
proposed change list for a detailed example and description.
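
As a rough illustration of how the three rewrites interact, here is a
minimal sketch on a toy expression tree. It is not the code in `Shape.cpp`
or `TensorOps.cpp`: the classes `Arg`, `Reshape`, `ShapeOf` and the function
`simplify` are hypothetical names made up for this example, element types
are ignored, and the real patterns may check additional type conditions.
```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

# Hypothetical toy IR, not MLIR's API. `dims` is None for an unranked tensor
# (tensor<*x...>); otherwise it has one entry per dimension, with None
# standing for a dynamic extent (?).
Dims = Optional[Tuple[Optional[int], ...]]


@dataclass(frozen=True)
class Arg:
    name: str
    dims: Dims


@dataclass(frozen=True)
class Reshape:  # models `tensor.reshape %source(%shape)`
    source: "Value"
    shape: "Value"
    dims: Dims  # dims of the result type


@dataclass(frozen=True)
class ShapeOf:  # models `shape.shape_of %operand`
    operand: "Value"
    dims: Dims = (None,)  # result modeled as tensor<?xindex>


Value = Union[Arg, Reshape, ShapeOf]


def simplify(v: Value) -> Value:
    """Apply the three rewrites bottom-up on an expression tree."""
    if isinstance(v, ShapeOf):
        operand = simplify(v.operand)
        # Shape of reshape: the answer is the reshape's shape operand.
        if isinstance(operand, Reshape):
            return operand.shape
        return ShapeOf(operand, v.dims)
    if isinstance(v, Reshape):
        source = simplify(v.source)
        shape = simplify(v.shape)
        # Reshape of reshape: only the outermost target shape matters.
        while isinstance(source, Reshape):
            source = source.source
        # Reshape 1D to 1D with identical types (assumed condition): no-op.
        if v.dims is not None and len(v.dims) == 1 and source.dims == v.dims:
            return source
        return Reshape(source, shape, v.dims)
    return v


arg0 = Arg("arg0", None)      # tensor<*xf32>
arg1 = Arg("arg1", (None,))   # tensor<?xindex>
arg2 = Arg("arg2", (None,))   # tensor<?xindex>

# Hypothetical composite case: shape_of(reshape(reshape(%arg0, %arg1), %arg2)).
combined = simplify(ShapeOf(Reshape(Reshape(arg0, arg1, None), arg2, None)))
print(combined)
```
In this toy model, the composite expression reduces to `%arg2` alone: the
inner reshape is bypassed first, and the `shape_of` then forwards the
remaining reshape's shape operand.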
For context, this PR is meant to enable optimization of the code
generated when lowering the `quant.qcast` and `quant.dcast` ops with
unranked tensors, as proposed in
https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942
(implementation currently in progress).