[Mlir-commits] [mlir] b1a735b - [mlir][math] expand-math pass assumes the static shaped type (#128299)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 25 17:30:18 PST 2025
Author: Kai Sasaki
Date: 2025-02-26T10:30:14+09:00
New Revision: b1a735b45dcc194ad9be08d057bc853ad1c1467b
URL: https://github.com/llvm/llvm-project/commit/b1a735b45dcc194ad9be08d057bc853ad1c1467b
DIFF: https://github.com/llvm/llvm-project/commit/b1a735b45dcc194ad9be08d057bc853ad1c1467b.diff
LOG: [mlir][math] expand-math pass assumes the static shaped type (#128299)
In the `expand-math` pass, the conversion of the ceil op assumes a
statically shaped input type, because it needs to create constant 0 and 1
values whose type matches the op type.
Fixes https://github.com/llvm/llvm-project/issues/128275
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 23356d752146d..bb592c667549c 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
// if (x > y) then incr = 1 else incr = 0
// y = y + incr <= replace this op with the ceilf op.
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+ // Creating constants assumes the static shaped type.
+ auto shapedType = dyn_cast<ShapedType>(op.getType());
+ if (shapedType && !shapedType.hasStaticShape())
+ return failure();
+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1fdfb854325b4..946a411e4cc4b 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -761,3 +761,29 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) {
%float_result = math.rsqrt %float : tensor<5x8xf32>
return %float_result : tensor<5x8xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @non_static_shape_ceil_op
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>)
+// CHECK-SAME: -> tensor<?xf32>
+// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor<?xf32>
+// CHECK: return %[[CEIL]] : tensor<?xf32>
+
+func.func @non_static_shape_ceil_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
+ %a = math.ceil %arg : tensor<?xf32>
+ return %a: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unranked_ceil_op
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+// CHECK-SAME: -> tensor<*xf32>
+// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor<*xf32>
+// CHECK: return %[[CEIL]] : tensor<*xf32>
+
+func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
+ %a = math.ceil %arg : tensor<*xf32>
+ return %a: tensor<*xf32>
+}
More information about the Mlir-commits
mailing list