[Mlir-commits] [mlir] [mlir][math] expand-math pass assumes the static shaped type (PR #128299)

Kai Sasaki llvmlistbot at llvm.org
Sat Feb 22 21:38:29 PST 2025


https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/128299

>From 3624e3afa1d2630d51e149fdd3f7e87afdbdc0e7 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Sat, 22 Feb 2025 17:00:07 +0900
Subject: [PATCH 1/2] [mlir][math] expand-math pass assumes the static shaped
 type

In the process of the `expand-math` pass, the conversion of the ceil op assumes
a statically shaped type as input, since it needs to create 0 and 1 constant
values whose type is aligned with the op type.

Fixes https://github.com/llvm/llvm-project/issues/128275
---
 .../Math/Transforms/ExpandPatterns.cpp        |  5 +++++
 mlir/test/Dialect/Math/expand-math.mlir       | 22 +++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 23356d752146d..67e8dbba989b7 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
 //      if (x > y) then incr = 1 else incr = 0
 //      y = y + incr   <= replace this op with the ceilf op.
 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+  // Creating constants assumes the statis shaped type.
+  auto shapedType = dyn_cast<ShapedType>(op.getType());
+  if (shapedType && !shapedType.hasStaticShape())
+    return failure();
+
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1fdfb854325b4..4e249ec510afa 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -761,3 +761,25 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>)  {
   %float_result = math.rsqrt %float : tensor<5x8xf32>
   return %float_result : tensor<5x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL  func.func @non_static_shape_ceil_op
+// CHECK:     %[[IDX:.*]] = index.constant 0
+// CHECK:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32>
+// CHECK:     %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor<?xf32>
+// CHECK:     %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor<?xf32>
+// CHECK:     %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor<?xf32>
+// CHECK:     vector.print %[[DIM]] : index
+// CHECK:         return
+
+func.func @non_static_shape_ceil_op() {
+  %idx0 = index.constant 0
+  %cst_90 = arith.constant 1.000000e+00 : f32
+  %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32>
+  %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor<?xf32>
+  %112 = math.ceil %cast_93 : tensor<?xf32>
+  %dim_233 = tensor.dim %112, %idx0 : tensor<?xf32>
+  vector.print %dim_233 : index
+  return
+}

>From e01b8836c435c9fbe894d8815dc3b04aac81d7cc Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Sun, 23 Feb 2025 14:38:15 +0900
Subject: [PATCH 2/2] Post review follow-up

---
 .../Dialect/ControlFlow/IR/ControlFlowOps.cpp |  8 ++++++
 .../Math/Transforms/ExpandPatterns.cpp        |  2 +-
 mlir/test/Dialect/Math/expand-math.mlir       | 27 +++++++------------
 3 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index edd7f607f24f4..3bda649c0a255 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -159,9 +159,17 @@ simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
   if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
     return failure();
 
+  op->emitWarning() << "111";
+  op->emitWarning() << op->getUses().empty();
+  if (!op->getUses().empty()) {
+    return failure();
+  }
+
+  op->emitWarning() << "222";
   // Merge the successor into the current block and erase the branch.
   SmallVector<Value> brOperands(op.getOperands());
   rewriter.eraseOp(op);
+  llvm::errs() << "333\n";
   rewriter.mergeBlocks(succ, opParent, brOperands);
   return success();
 }
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 67e8dbba989b7..bb592c667549c 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,7 +222,7 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
 //      if (x > y) then incr = 1 else incr = 0
 //      y = y + incr   <= replace this op with the ceilf op.
 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
-  // Creating constants assumes the statis shaped type.
+  // Creating constants assumes the static shaped type.
   auto shapedType = dyn_cast<ShapedType>(op.getType());
   if (shapedType && !shapedType.hasStaticShape())
     return failure();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 4e249ec510afa..56d562ad0b3fe 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -764,22 +764,13 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>)  {
 
 // -----
 
-// CHECK-LABEL  func.func @non_static_shape_ceil_op
-// CHECK:     %[[IDX:.*]] = index.constant 0
-// CHECK:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32>
-// CHECK:     %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor<?xf32>
-// CHECK:     %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor<?xf32>
-// CHECK:     %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor<?xf32>
-// CHECK:     vector.print %[[DIM]] : index
-// CHECK:         return
-
-func.func @non_static_shape_ceil_op() {
-  %idx0 = index.constant 0
-  %cst_90 = arith.constant 1.000000e+00 : f32
-  %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32>
-  %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor<?xf32>
-  %112 = math.ceil %cast_93 : tensor<?xf32>
-  %dim_233 = tensor.dim %112, %idx0 : tensor<?xf32>
-  vector.print %dim_233 : index
-  return
+// CHECK-LABEL:    func.func @non_static_shape_ceil_op
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<?xf32>)
+// CHECK-SAME:     -> tensor<?xf32>
+// CHECK:          %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor<?xf32>
+// CHECK:          return %[[CEIL]] : tensor<?xf32>
+
+func.func @non_static_shape_ceil_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
+  %a = math.ceil %arg : tensor<?xf32>
+  return %a: tensor<?xf32>
 }



More information about the Mlir-commits mailing list