[Mlir-commits] [mlir] [mlir][tosa] Fix for incorrect canonicalization of tosa.pad (PR #98356)

Spenser Bauman llvmlistbot at llvm.org
Wed Jul 10 10:51:55 PDT 2024


https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/98356

The current fold method for tosa.pad can produce invalid IR: when the pad is a no-op, it replaces the result of the tosa.pad with its input value. If the type of the input value does not match the result type of the tosa.pad (for example, a static `tensor<10xf32>` input with a dynamic `tensor<?xf32>` result), the canonicalizer detects the change in types and asserts.

This change addresses the issue by avoiding folding when the input and result types do not match.
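The guard added by this patch can be illustrated outside of MLIR. Below is a minimal, self-contained C++ sketch, assuming a hypothetical `TensorType`/`Value`/`PadOp` model (none of these names are MLIR APIs), where `kDynamic` stands in for a `?` dimension. It mirrors the patched `PadOp::fold` logic: the pad folds to its input only when the padding is a known all-zero constant and the input and result types are identical.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical marker for a dynamic ('?') dimension in this sketch.
constexpr int64_t kDynamic = -1;

// Hypothetical stand-in for a ranked tensor type: shape plus element type.
struct TensorType {
  std::vector<int64_t> shape;
  std::string elementType;
  bool operator==(const TensorType &other) const {
    return shape == other.shape && elementType == other.elementType;
  }
};

// Hypothetical SSA value: an id and its type.
struct Value {
  int id;
  TensorType type;
};

// Hypothetical model of tosa.pad: input, constant padding (nullopt when the
// padding is not a known constant), and the declared result type.
struct PadOp {
  Value input;
  std::optional<std::vector<int32_t>> padding;
  TensorType resultType;
};

// Sketch of the patched fold: return the input only if the padding is a known
// all-zero constant AND the input type equals the result type. Otherwise the
// op is left in place, so no use ever sees a value of an unexpected type.
std::optional<Value> foldPad(const PadOp &op) {
  if (!op.padding || !(op.input.type == op.resultType))
    return std::nullopt;
  for (int32_t p : *op.padding)
    if (p != 0)
      return std::nullopt;
  return op.input;
}
```

For instance, zero padding from `tensor<10xf32>` to a `tensor<?xf32>` result is a no-op at runtime, yet `foldPad` declines to fold it because the shapes `{10}` and `{kDynamic}` differ, which is the situation exercised by the new `@pad_noop_type_mismatch_nofold` test.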

>From 96193efdd1957358f700fde15f4dbc47cde58eaa Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Wed, 10 Jul 2024 13:21:45 -0400
Subject: [PATCH] [mlir][tosa] Fix for incorrect canonicalization of tosa.pad

The current fold method for tosa.pad can produce invalid IR: when the
pad is a no-op, it replaces the result of the tosa.pad with its input
value. If the type of the input value does not match the result type
of the tosa.pad, the canonicalizer detects the change in types and
asserts.
---
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp |  2 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir           | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 8687be075ea67..866ab0d2228f7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -859,7 +859,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
   // If the pad is all zeros we can fold this operation away.
-  if (adaptor.getPadding()) {
+  if (adaptor.getPadding() && getInput1().getType() == getType()) {
     auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
     if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
       return getInput1();
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index accc792c8f2ac..3bcf58015831b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -217,6 +217,20 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 
 // -----
 
+// CHECK-LABEL: @pad_noop_type_mismatch_nofold
+func.func @pad_noop_type_mismatch_nofold(%arg0: tensor<10xf32>) -> tensor<?xf32> {
+  // CHECK: %[[PAD:.+]] = tosa.pad
+  // CHECK: return %[[PAD]]
+
+  %c0_i32 = arith.constant 0 : i32
+  %shape = tensor.from_elements %c0_i32, %c0_i32 : tensor<1x2xi32>
+
+  %0 = tosa.pad %arg0, %shape : (tensor<10xf32>, tensor<1x2xi32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @pad_determine_val_i32
 func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}