[Mlir-commits] [mlir] b1a735b - [mlir][math] expand-math pass assumes the static shaped type (#128299)

llvmlistbot at llvm.org
Tue Feb 25 17:30:18 PST 2025


Author: Kai Sasaki
Date: 2025-02-26T10:30:14+09:00
New Revision: b1a735b45dcc194ad9be08d057bc853ad1c1467b

URL: https://github.com/llvm/llvm-project/commit/b1a735b45dcc194ad9be08d057bc853ad1c1467b
DIFF: https://github.com/llvm/llvm-project/commit/b1a735b45dcc194ad9be08d057bc853ad1c1467b.diff

LOG: [mlir][math] expand-math pass assumes the static shaped type (#128299)

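In the `expand-math` pass, the conversion of `math.ceil` assumes a
statically shaped input type, because it needs to create 0 and 1 constant
values whose type matches the op type. The pattern now returns failure for
shaped types without a static shape (dynamic or unranked), so such ops are
left unchanged.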
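The guard's decision can be modeled outside MLIR as a minimal sketch. `TypeModel`, `kDynamicDim`, `hasStaticShape` and `ceilExpansionApplies` are hypothetical names for illustration only, not the real `mlir::ShapedType` API. The logic mirrors the diff: a non-shaped (scalar) type is expanded, a ranked type with no dynamic dimension is expanded, and dynamic or unranked shapes make the pattern bail out.

```cpp
#include <cassert>   // used by the example checks
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical marker for a dynamic dimension (`?`) in this model only;
// it need not match the value MLIR uses for ShapedType::kDynamic.
constexpr int64_t kDynamicDim = -1;

// Hypothetical stand-in for an MLIR type; not the real mlir::Type API.
struct TypeModel {
  bool isShaped;  // true for tensor/vector types, false for scalars like f32
  std::optional<std::vector<int64_t>> shape;  // std::nullopt means unranked
};

// Modeled on ShapedType::hasStaticShape(): ranked, with no dynamic dimension.
bool hasStaticShape(const TypeModel &t) {
  if (!t.shape) return false;  // unranked, e.g. tensor<*xf32>
  for (int64_t dim : *t.shape)
    if (dim == kDynamicDim) return false;  // e.g. tensor<?xf32>
  return true;
}

// Modeled on the new guard in convertCeilOp: non-shaped (scalar) types and
// statically shaped types are expanded; everything else stays math.ceil.
bool ceilExpansionApplies(const TypeModel &t) {
  return !t.isShaped || hasStaticShape(t);
}
```

With this model, `f32` and `tensor<5x8xf32>` are expanded, while `tensor<?xf32>` and `tensor<*xf32>` are not, matching the two new test cases.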

Fixes https://github.com/llvm/llvm-project/issues/128275

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 23356d752146d..bb592c667549c 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
 //      if (x > y) then incr = 1 else incr = 0
 //      y = y + incr   <= replace this op with the ceilf op.
 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+  // Creating constants assumes the static shaped type.
+  auto shapedType = dyn_cast<ShapedType>(op.getType());
+  if (shapedType && !shapedType.hasStaticShape())
+    return failure();
+
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1fdfb854325b4..946a411e4cc4b 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -761,3 +761,29 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>)  {
   %float_result = math.rsqrt %float : tensor<5x8xf32>
   return %float_result : tensor<5x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL:    func.func @non_static_shape_ceil_op
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<?xf32>)
+// CHECK-SAME:     -> tensor<?xf32>
+// CHECK:          %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor<?xf32>
+// CHECK:          return %[[CEIL]] : tensor<?xf32>
+
+func.func @non_static_shape_ceil_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
+  %a = math.ceil %arg : tensor<?xf32>
+  return %a: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:    func.func @unranked_ceil_op
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<*xf32>)
+// CHECK-SAME:     -> tensor<*xf32>
+// CHECK:          %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor<*xf32>
+// CHECK:          return %[[CEIL]] : tensor<*xf32>
+
+func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
+  %a = math.ceil %arg : tensor<*xf32>
+  return %a: tensor<*xf32>
+}
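The comment above `convertCeilOp` describes the expansion as fptosi, sitofp, a compare, and adding a 0 or 1 increment. Below is a scalar sketch of that recipe in plain C++, assuming `float` inputs. The early return for magnitudes of at least 2^23 (and for NaN/inf) is an assumption of this sketch: such floats are already integral, and converting NaN/inf to an integer is undefined behavior in C++. It is not a claim about how the MLIR pattern handles those cases. `expandedCeil` is a hypothetical name.

```cpp
#include <cassert>   // used by the example checks
#include <cmath>
#include <cstdint>

// Scalar sketch of: y = fptosi(x); y = sitofp(y);
//                   incr = (x > y) ? 1 : 0; y = y + incr.
float expandedCeil(float x) {
  const float kTwoPow23 = 8388608.0f;  // 2^23: float has a 23-bit mantissa
  // Assumption of this sketch: large values, inf and NaN pass through.
  if (!(std::fabs(x) < kTwoPow23)) return x;
  // fptosi followed by sitofp: truncate toward zero.
  float y = static_cast<float>(static_cast<int64_t>(x));
  // The 0 and 1 constants the MLIR pattern must create in the op type
  // (for a tensor this means a constant of that tensor type).
  const float zero = 0.0f, one = 1.0f;
  float incr = (x > y) ? one : zero;
  return y + incr;
}
```

For a scalar these constants are trivial. In the MLIR pattern they must have the operand's type, which is why dynamic and unranked tensors are now rejected. This sketch does not preserve the sign of zero for inputs in (-1, 0): it returns `+0.0f` where `std::ceil` returns `-0.0f`, although the two compare equal with `==`.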
