[Mlir-commits] [mlir] [mlir][math] expand-math pass assumes the static shaped type (PR #128299)
Kai Sasaki
llvmlistbot at llvm.org
Sat Feb 22 21:38:29 PST 2025
https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/128299
>From 3624e3afa1d2630d51e149fdd3f7e87afdbdc0e7 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Sat, 22 Feb 2025 17:00:07 +0900
Subject: [PATCH 1/2] [mlir][math] expand-math pass assumes the static shaped
type
During the `expand-math` pass, the conversion of the ceil op assumes a
statically shaped type as input, because it needs to create the constants
0 and 1 with a type that matches the op type.
Fixes https://github.com/llvm/llvm-project/issues/128275
---
.../Math/Transforms/ExpandPatterns.cpp | 5 +++++
mlir/test/Dialect/Math/expand-math.mlir | 22 +++++++++++++++++++
2 files changed, 27 insertions(+)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 23356d752146d..67e8dbba989b7 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
// if (x > y) then incr = 1 else incr = 0
// y = y + incr <= replace this op with the ceilf op.
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+ // Creating constants assumes the statis shaped type.
+ auto shapedType = dyn_cast<ShapedType>(op.getType());
+ if (shapedType && !shapedType.hasStaticShape())
+ return failure();
+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1fdfb854325b4..4e249ec510afa 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -761,3 +761,25 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) {
%float_result = math.rsqrt %float : tensor<5x8xf32>
return %float_result : tensor<5x8xf32>
}
+
+// -----
+
+// CHECK-LABEL func.func @non_static_shape_ceil_op
+// CHECK: %[[IDX:.*]] = index.constant 0
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor<?xf32>
+// CHECK: %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor<?xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor<?xf32>
+// CHECK: vector.print %[[DIM]] : index
+// CHECK: return
+
+func.func @non_static_shape_ceil_op() {
+ %idx0 = index.constant 0
+ %cst_90 = arith.constant 1.000000e+00 : f32
+ %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32>
+ %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor<?xf32>
+ %112 = math.ceil %cast_93 : tensor<?xf32>
+ %dim_233 = tensor.dim %112, %idx0 : tensor<?xf32>
+ vector.print %dim_233 : index
+ return
+}
>From e01b8836c435c9fbe894d8815dc3b04aac81d7cc Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Sun, 23 Feb 2025 14:38:15 +0900
Subject: [PATCH 2/2] Post review follow-up
---
.../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 8 ++++++
.../Math/Transforms/ExpandPatterns.cpp | 2 +-
mlir/test/Dialect/Math/expand-math.mlir | 27 +++++++------------
3 files changed, 18 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index edd7f607f24f4..3bda649c0a255 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -159,9 +159,17 @@ simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
return failure();
+ op->emitWarning() << "111";
+ op->emitWarning() << op->getUses().empty();
+ if (!op->getUses().empty()) {
+ return failure();
+ }
+
+ op->emitWarning() << "222";
// Merge the successor into the current block and erase the branch.
SmallVector<Value> brOperands(op.getOperands());
rewriter.eraseOp(op);
+ llvm::errs() << "333\n";
rewriter.mergeBlocks(succ, opParent, brOperands);
return success();
}
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 67e8dbba989b7..bb592c667549c 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,7 +222,7 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
// if (x > y) then incr = 1 else incr = 0
// y = y + incr <= replace this op with the ceilf op.
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
- // Creating constants assumes the statis shaped type.
+ // Creating constants assumes the static shaped type.
auto shapedType = dyn_cast<ShapedType>(op.getType());
if (shapedType && !shapedType.hasStaticShape())
return failure();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 4e249ec510afa..56d562ad0b3fe 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -764,22 +764,13 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) {
// -----
-// CHECK-LABEL func.func @non_static_shape_ceil_op
-// CHECK: %[[IDX:.*]] = index.constant 0
-// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32>
-// CHECK: %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor<?xf32>
-// CHECK: %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor<?xf32>
-// CHECK: %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor<?xf32>
-// CHECK: vector.print %[[DIM]] : index
-// CHECK: return
-
-func.func @non_static_shape_ceil_op() {
- %idx0 = index.constant 0
- %cst_90 = arith.constant 1.000000e+00 : f32
- %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32>
- %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor<?xf32>
- %112 = math.ceil %cast_93 : tensor<?xf32>
- %dim_233 = tensor.dim %112, %idx0 : tensor<?xf32>
- vector.print %dim_233 : index
- return
+// CHECK-LABEL: func.func @non_static_shape_ceil_op
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>)
+// CHECK-SAME: -> tensor<?xf32>
+// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor<?xf32>
+// CHECK: return %[[CEIL]] : tensor<?xf32>
+
+func.func @non_static_shape_ceil_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
+ %a = math.ceil %arg : tensor<?xf32>
+ return %a: tensor<?xf32>
}
More information about the Mlir-commits
mailing list