[Mlir-commits] [mlir] [mlir][math] expand-math pass assumes the static shaped type (PR #128299)

Kai Sasaki llvmlistbot at llvm.org
Sat Feb 22 00:02:57 PST 2025


https://github.com/Lewuathe created https://github.com/llvm/llvm-project/pull/128299

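In the `expand-math` pass, the conversion of the ceil op assumes a statically shaped input type, because it needs to create the constants 0 and 1 with a type that matches the op type. With this patch, the pattern returns `failure()` for shaped types that do not have a static shape, so such ops are left unexpanded.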

Fixes https://github.com/llvm/llvm-project/issues/128275
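
The following is a minimal, self-contained C++ model of why the guard is needed. It assumes only what the description states: the ceil expansion materializes 0 and 1 constants of the op's type. `ModelType`, `splatElementCount`, and `canExpandCeil` are hypothetical names used for illustration and are not MLIR APIs. A splat constant of a shaped type needs a known element count, and a dynamic (`?`) dimension leaves that count undefined. The pattern should therefore decline such types while still handling scalars and static shapes.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for an MLIR type: a scalar has no shape; a shaped
// type lists its dimensions, with std::nullopt marking a dynamic ('?') dim.
struct ModelType {
  bool isShaped = false;
  std::vector<std::optional<int64_t>> dims;
};

// Number of elements a splat constant of this type would hold, or
// std::nullopt when any dimension is dynamic (the count is unknown).
std::optional<int64_t> splatElementCount(const ModelType &t) {
  if (!t.isShaped)
    return 1; // a scalar constant is a single value
  int64_t count = 1;
  for (const auto &d : t.dims) {
    if (!d)
      return std::nullopt; // dynamic dim: cannot size the constant
    count *= *d;
  }
  return count;
}

// Mirrors the guard added in convertCeilOp: bail out on shaped types
// without a static shape; scalars and static shapes are expanded.
bool canExpandCeil(const ModelType &t) {
  return !t.isShaped || splatElementCount(t).has_value();
}

// Usage: tensor<2xf32> is expandable, tensor<?xf32> is not.
void exampleChecks() {
  assert(canExpandCeil(ModelType{true, {2}}));
  assert(!canExpandCeil(ModelType{true, {std::nullopt}}));
}
```

In this model, `tensor<2xf32>` corresponds to `ModelType{true, {2}}` and is expanded. `tensor<?xf32>` corresponds to `ModelType{true, {std::nullopt}}` and is left alone, which matches the behavior the new test checks.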

From 3624e3afa1d2630d51e149fdd3f7e87afdbdc0e7 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Sat, 22 Feb 2025 17:00:07 +0900
Subject: [PATCH] [mlir][math] expand-math pass assumes the static shaped type

In the `expand-math` pass, the conversion of the ceil op assumes a
statically shaped input type, because it needs to create the constants
0 and 1 with a type that matches the op type.

Fixes https://github.com/llvm/llvm-project/issues/128275
---
 .../Math/Transforms/ExpandPatterns.cpp        |  5 +++++
 mlir/test/Dialect/Math/expand-math.mlir       | 22 +++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 23356d752146d..67e8dbba989b7 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
 //      if (x > y) then incr = 1 else incr = 0
 //      y = y + incr   <= replace this op with the ceilf op.
 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+  // Creating the 0 and 1 constants requires a statically shaped type.
+  auto shapedType = dyn_cast<ShapedType>(op.getType());
+  if (shapedType && !shapedType.hasStaticShape())
+    return failure();
+
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1fdfb854325b4..4e249ec510afa 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -761,3 +761,25 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>)  {
   %float_result = math.rsqrt %float : tensor<5x8xf32>
   return %float_result : tensor<5x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @non_static_shape_ceil_op
+// CHECK:     %[[IDX:.*]] = index.constant 0
+// CHECK:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32>
+// CHECK:     %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor<?xf32>
+// CHECK:     %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor<?xf32>
+// CHECK:     %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor<?xf32>
+// CHECK:     vector.print %[[DIM]] : index
+// CHECK:         return
+
+func.func @non_static_shape_ceil_op() {
+  %idx0 = index.constant 0
+  %cst_90 = arith.constant 1.000000e+00 : f32
+  %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32>
+  %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor<?xf32>
+  %112 = math.ceil %cast_93 : tensor<?xf32>
+  %dim_233 = tensor.dim %112, %idx0 : tensor<?xf32>
+  vector.print %dim_233 : index
+  return
+}
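
The comment above `convertCeilOp` describes the scalar expansion that the new guard now skips for dynamically shaped tensors. Below is a rough C++ sketch of that arithmetic. It assumes that `y` starts as `x` truncated toward zero. The comment lines that define `y` fall outside this hunk, so `std::trunc` is only a stand-in. `expandedCeil` and `checkCeilSamples` are hypothetical helpers and are not part of the patch.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar model of the expansion sketched in the comment above
// convertCeilOp. Assumption: y starts as x truncated toward zero; std::trunc
// stands in for whatever the pattern actually emits for that step.
float expandedCeil(float x) {
  float y = std::trunc(x);
  // "if (x > y) then incr = 1 else incr = 0": these 1 and 0 are the
  // constants the pattern must build with the op's type.
  float incr = (x > y) ? 1.0f : 0.0f;
  // "y = y + incr": the result that replaces the ceil op.
  return y + incr;
}

// Usage: compare against std::ceil on a few representative inputs.
void checkCeilSamples() {
  const float samples[] = {1.5f, -1.5f, 2.0f, -0.25f, 0.0f};
  for (float x : samples)
    assert(expandedCeil(x) == std::ceil(x));
}
```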
