[Mlir-commits] [mlir] 43e1fc5 - [mlir][tosa] Fold tosa.reshape with splat values
Rob Suderman
llvmlistbot at llvm.org
Mon Aug 29 17:20:13 PDT 2022
Author: Rob Suderman
Date: 2022-08-29T17:18:03-07:00
New Revision: 43e1fc58dd1f4524ea518f920dbe91924927ce4d
URL: https://github.com/llvm/llvm-project/commit/43e1fc58dd1f4524ea518f920dbe91924927ce4d
DIFF: https://github.com/llvm/llvm-project/commit/43e1fc58dd1f4524ea518f920dbe91924927ce4d.diff
LOG: [mlir][tosa] Fold tosa.reshape with splat values
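Reshaping a splat constant changes only its shape, not its value, so a
tosa.reshape of a splat whose result type has a static shape is now folded
into a splat constant of the result type.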
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D132760
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/constant-op-fold.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 6af5b3d1017f..6d27d1e7f404 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -717,9 +717,18 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();
- if (!inputTy || !outputTy || inputTy != outputTy)
+ if (!inputTy || !outputTy)
return {};
- return getInput1();
+
+ if (inputTy == outputTy)
+ return getInput1();
+
+ auto operand = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+ if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
+ return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+ }
+
+ return {};
}
OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index f5dad4698dc7..af24d590bfaf 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -398,6 +398,16 @@ func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
// -----
+func.func @reshape_splat() -> tensor<6x5x4xi32> {
+ // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<42> : tensor<6x5x4xi32>}
+ %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
+ %reshape = "tosa.reshape"(%splat) { new_shape = [6, 5, 4] } : (tensor<4x5x6xi32>) -> tensor<6x5x4xi32>
+ // CHECK: return %[[SPLAT]]
+ return %reshape : tensor<6x5x4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @slice_splat
func.func @slice_splat() -> tensor<1x1x1xi32> {
// CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}
More information about the Mlir-commits
mailing list