[Mlir-commits] [mlir] [mlir][tensor] Fix PadOp::getConstantPaddingValue (PR #121205)

Longsheng Mou llvmlistbot at llvm.org
Fri Dec 27 03:43:25 PST 2024


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/121205

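In this context, a 'constant' padding value is one defined outside the PadOp block. Therefore, the 'defined inside the block' check must be performed before the 'is a constant' check: otherwise a constant (e.g. `arith.constant`) defined inside the block is returned as the padding value, which causes a crash. Fixes #120947.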

>From 524bc2f479f849b5ead9d12dbaa8828558571fd4 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 26 Dec 2024 11:32:05 +0800
Subject: [PATCH] [mlir][tensor] Fix PadOp::getConstantPaddingValue

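In this context, a 'constant' padding value is one defined outside the PadOp block. Therefore, the 'defined inside the block' check must be performed before the 'is a constant' check: otherwise a constant (e.g. `arith.constant`) defined inside the block is returned as the padding value, which causes a crash.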
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  6 ++---
 .../TensorToLinalg/tensor-ops-to-linalg.mlir  | 25 +++++++++++++++++++
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e9a..be846a11dcef18 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3638,12 +3638,12 @@ Value PadOp::getConstantPaddingValue() {
   if (!yieldOp)
     return {};
   Value padValue = yieldOp.getValue();
-  // Check if yield value is a constant.
-  if (matchPattern(padValue, m_Constant()))
-    return padValue;
   // Check if yield value is defined inside the PadOp block.
   if (padValue.getParentBlock() == &getRegion().front())
     return {};
+  // Check if yield value is a constant.
+  if (matchPattern(padValue, m_Constant()))
+    return padValue;
   // Else: Yield value defined outside of the PadOp block.
   return padValue;
 }
diff --git a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
index a0a676edceb745..c4a5916aaf9a61 100644
--- a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
+++ b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
@@ -19,6 +19,8 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
   return %0 : tensor<1x32x32x1xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func @generalize_pad_tensor_dynamic_shape(
 // CHECK-SAME:                                              %[[IN:.*]]: tensor<4x?x2x?xf32>,
 // CHECK-SAME:                                              %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> {
@@ -44,3 +46,26 @@ func.func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1
   } : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32>
   return %out : tensor<4x?x?x?xf32>
 }
+
+// -----
+
+// Ensure that the constant value inside the PadOp block does not cause a crash.
+
+// CHECK-LABEL:   func.func @generalize_pad_tensor_constant_inside(
+// CHECK-SAME:                                                     %[[VAL_0:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
+// CHECK:           %[[VAL_1:.*]] = tensor.generate  {
+// CHECK:           ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index):
+// CHECK:             %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:             tensor.yield %[[VAL_6]] : f32
+// CHECK:           } : tensor<1x32x32x1xf32>
+// CHECK:           %[[VAL_7:.*]] = tensor.insert_slice %[[VAL_0]] into %[[VAL_8:.*]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32>
+// CHECK:           return %[[VAL_7]] : tensor<1x32x32x1xf32>
+// CHECK:         }
+func.func @generalize_pad_tensor_constant_inside(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
+  %0 = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]  {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    %cst = arith.constant 0.000000e+00 : f32
+    tensor.yield %cst : f32
+  } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
+  return %0 : tensor<1x32x32x1xf32>
+}



More information about the Mlir-commits mailing list