[Mlir-commits] [mlir] [mlir][math] expand-math pass assumes the static shaped type (PR #128299)
Kai Sasaki
llvmlistbot at llvm.org
Sat Feb 22 00:02:57 PST 2025
https://github.com/Lewuathe created https://github.com/llvm/llvm-project/pull/128299
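In the `expand-math` pass, the conversion of `math.ceil` assumes a statically shaped input type, because it needs to create constant 0 and 1 values whose type matches the op type. With this change, the pattern returns `failure()` for dynamically shaped operands and leaves the op unchanged.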
Fixes https://github.com/llvm/llvm-project/issues/128275
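The static-shape requirement described above can be modeled outside MLIR. This is a hypothetical, self-contained sketch, not MLIR code: `TypeModel`, the `kDynamic` sentinel, `hasStaticShape`, and `canExpandCeil` are illustrative stand-ins for `ShapedType` and its `hasStaticShape()` method. Scalars and fully static shapes such as `tensor<5x8xf32>` pass, while any `?` dimension, as in `tensor<?xf32>`, makes the pattern bail out.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical sentinel for a dynamic dimension ("?" in tensor<?xf32>).
// MLIR uses its own sentinel value; -1 is only an illustrative choice.
constexpr int64_t kDynamic = -1;

// Hypothetical stand-in for an MLIR type: scalars have no shape, while
// shaped types (tensors, vectors) carry a list of dimension sizes.
struct TypeModel {
  std::optional<std::vector<int64_t>> shape;  // std::nullopt => scalar
};

// True when every dimension size is known at compile time.
bool hasStaticShape(const std::vector<int64_t>& dims) {
  for (int64_t d : dims)
    if (d == kDynamic) return false;
  return true;
}

// Mirrors the early exit added to convertCeilOp: scalars and statically
// shaped types can be expanded; dynamically shaped ones are left alone.
bool canExpandCeil(const TypeModel& t) {
  if (t.shape && !hasStaticShape(*t.shape))
    return false;  // corresponds to `return failure();` in the patch
  return true;
}
```

Returning `false` here plays the role of `failure()`: the rewrite is simply not applied, so the original `math.ceil` survives, as the new test expects.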
From 3624e3afa1d2630d51e149fdd3f7e87afdbdc0e7 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Sat, 22 Feb 2025 17:00:07 +0900
Subject: [PATCH] [mlir][math] expand-math pass assumes the static shaped type
In the `expand-math` pass, the conversion of `math.ceil` assumes a
statically shaped input type, because it needs to create constant 0
and 1 values whose type matches the op type.
Fixes https://github.com/llvm/llvm-project/issues/128275
---
.../Math/Transforms/ExpandPatterns.cpp | 5 +++++
mlir/test/Dialect/Math/expand-math.mlir | 22 +++++++++++++++++++
2 files changed, 27 insertions(+)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 23356d752146d..67e8dbba989b7 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
// if (x > y) then incr = 1 else incr = 0
// y = y + incr <= replace this op with the ceilf op.
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+ // Creating constants assumes a statically shaped type.
+ auto shapedType = dyn_cast<ShapedType>(op.getType());
+ if (shapedType && !shapedType.hasStaticShape())
+ return failure();
+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1fdfb854325b4..4e249ec510afa 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -761,3 +761,25 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) {
%float_result = math.rsqrt %float : tensor<5x8xf32>
return %float_result : tensor<5x8xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @non_static_shape_ceil_op
+// CHECK: %[[IDX:.*]] = index.constant 0
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor<?xf32>
+// CHECK: %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor<?xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor<?xf32>
+// CHECK: vector.print %[[DIM]] : index
+// CHECK: return
+
+func.func @non_static_shape_ceil_op() {
+ %idx0 = index.constant 0
+ %cst_90 = arith.constant 1.000000e+00 : f32
+ %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32>
+ %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor<?xf32>
+ %112 = math.ceil %cast_93 : tensor<?xf32>
+ %dim_233 = tensor.dim %112, %idx0 : tensor<?xf32>
+ vector.print %dim_233 : index
+ return
+}
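For reference, the expansion that `convertCeilOp` performs when its guard passes can be modeled on a single `float`. This hypothetical sketch assumes, based on the partially visible comment in the hunk, that `y` starts as `x` truncated toward zero; `expandedCeil` and `expandedCeilAll` are illustrative names, and edge cases such as signed zeros, NaN, or very large magnitudes are not examined here. The `0` and `1` chosen between in the select step are the constants that, for a tensor operand, must be created with the operand's type, which is why a dynamic shape such as `tensor<?xf32>` is rejected.

```cpp
#include <cmath>
#include <vector>

// Hypothetical scalar model of the math.ceil expansion described by the
// comment in convertCeilOp (not MLIR code).
float expandedCeil(float x) {
  float y = std::trunc(x);             // assumed first step: y = trunc(x)
  float incr = (x > y) ? 1.0f : 0.0f;  // if (x > y) then incr = 1 else incr = 0
  return y + incr;                     // y = y + incr
}

// Applies the scalar model elementwise, the way the rewritten ops act on
// every element of a tensor such as tensor<2xf32>.
std::vector<float> expandedCeilAll(const std::vector<float>& xs) {
  std::vector<float> out;
  out.reserve(xs.size());
  for (float x : xs) out.push_back(expandedCeil(x));
  return out;
}
```

For negative inputs truncation already rounds up, so `x > y` is false and no increment is added; for positive non-integers truncation rounds down, and adding `1` yields the ceiling.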