[Mlir-commits] [mlir] 43e1fc5 - [mlir][tosa] Fold tosa.reshape with splat values

Rob Suderman llvmlistbot at llvm.org
Mon Aug 29 17:20:13 PDT 2022


Author: Rob Suderman
Date: 2022-08-29T17:18:03-07:00
New Revision: 43e1fc58dd1f4524ea518f920dbe91924927ce4d

URL: https://github.com/llvm/llvm-project/commit/43e1fc58dd1f4524ea518f920dbe91924927ce4d
DIFF: https://github.com/llvm/llvm-project/commit/43e1fc58dd1f4524ea518f920dbe91924927ce4d.diff

LOG: [mlir][tosa] Fold tosa.reshape with splat values

Folding reshapes of splats is trivial and should be canonicalized
away.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D132760

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 6af5b3d1017f..6d27d1e7f404 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -717,9 +717,18 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
 
-  if (!inputTy || !outputTy || inputTy != outputTy)
+  if (!inputTy || !outputTy)
     return {};
-  return getInput1();
+
+  if (inputTy == outputTy)
+    return getInput1();
+
+  auto operand = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
+    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+  }
+
+  return {};
 }
 
 OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index f5dad4698dc7..af24d590bfaf 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -398,6 +398,16 @@ func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
 
 // -----
 
+func.func @reshape_splat() -> tensor<6x5x4xi32> {
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<42> : tensor<6x5x4xi32>}
+  %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
+  %reshape = "tosa.reshape"(%splat) { new_shape = [6, 5, 4] } : (tensor<4x5x6xi32>) -> tensor<6x5x4xi32>
+  // CHECK: return %[[SPLAT]]
+  return %reshape : tensor<6x5x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_splat
 func.func @slice_splat() -> tensor<1x1x1xi32> {
   // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}


        


More information about the Mlir-commits mailing list