[Mlir-commits] [mlir] 55f2547 - [mlir][math] Rsqrt math expand pass expects static shaped operand (#129006)

llvmlistbot at llvm.org
Thu Feb 27 20:37:09 PST 2025


Author: Kai Sasaki
Date: 2025-02-28T13:37:06+09:00
New Revision: 55f254726ee1a83a40c14cfc39306071044cc68c

URL: https://github.com/llvm/llvm-project/commit/55f254726ee1a83a40c14cfc39306071044cc68c
DIFF: https://github.com/llvm/llvm-project/commit/55f254726ee1a83a40c14cfc39306071044cc68c.diff

LOG: [mlir][math] Rsqrt math expand pass expects static shaped operand (#129006)

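Similar to the issue reported in

https://github.com/llvm/llvm-project/pull/128299#pullrequestreview-2636142506,
the ExpandMath pattern for rsqrt expects statically shaped operands, because
it has to create a float constant of the operand's type. Given a dynamically
shaped or unranked operand, it crashed on an assertion violation. The pattern
now returns failure for such operands and leaves math.rsqrt unchanged.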

See: https://github.com/llvm/llvm-project/pull/128299
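A minimal standalone C++ sketch of how the new guard classifies operand types. It deliberately does not use the MLIR API: `TypeModel`, `kDynamic`, `hasStaticShape`, and `shouldExpandRsqrt` are hypothetical names that only model the check added to convertRsqrtOp. Scalars and statically shaped types pass; ranked types with a dynamic dimension (tensor<?xf32>) and unranked types (tensor<*xf32>) are rejected, which corresponds to the pattern returning failure.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for an MLIR type; this is NOT the MLIR API.
struct TypeModel {
  bool isShaped = false;                      // tensor/vector-like vs. scalar
  std::optional<std::vector<int64_t>> shape;  // std::nullopt models an unranked type
  bool isFloatElement = true;                 // element type is a float type
};

// Models a '?' dimension such as the one in tensor<?xf32>.
constexpr int64_t kDynamic = -1;

// Models a static-shape query: ranked and no dynamic dimension.
bool hasStaticShape(const TypeModel &t) {
  if (!t.shape)
    return false;  // unranked types are never static
  for (int64_t d : *t.shape)
    if (d == kDynamic)
      return false;
  return true;
}

// Mirrors the order of checks in the patched pattern: reject shaped
// operands without a static shape first, then non-float element types.
bool shouldExpandRsqrt(const TypeModel &t) {
  if (t.isShaped && !hasStaticShape(t))
    return false;
  return t.isFloatElement;
}
```

For example, `shouldExpandRsqrt(TypeModel{true, std::vector<int64_t>{4}, true})` models tensor<4xf32> and yields true, while the same call with `std::vector<int64_t>{kDynamic}` or `std::nullopt` as the shape yields false, matching the two new test cases below.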

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bb592c667549c..7b5350ca26b60 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -646,6 +646,11 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
 
   auto operand = op.getOperand();
   auto operandTy = operand.getType();
+  // Operand type must be a statically shaped type to create a float constant.
+  auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
+  if (shapedOperandType && !shapedOperandType.hasStaticShape())
+    return failure();
+
   auto eTy = getElementTypeOrSelf(operandTy);
   if (!isa<FloatType>(eTy))
     return failure();

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 946a411e4cc4b..1420acaa40d35 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -787,3 +787,29 @@ func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
   %a = math.ceil %arg : tensor<*xf32>
   return %a: tensor<*xf32>
 }
+
+// -----
+
+// CHECK-LABEL:    func.func @non_static_shape_rsqrt_op
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<?xf32>)
+// CHECK-SAME:     -> tensor<?xf32>
+// CHECK:          %[[RSQRT:.*]] = math.rsqrt %[[ARG]] : tensor<?xf32>
+// CHECK:          return %[[RSQRT]] : tensor<?xf32>
+
+func.func @non_static_shape_rsqrt_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
+  %a = math.rsqrt %arg : tensor<?xf32>
+  return %a: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:    func.func @unranked_rsqrt_op
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<*xf32>)
+// CHECK-SAME:     -> tensor<*xf32>
+// CHECK:          %[[RSQRT:.*]] = math.rsqrt %[[ARG]] : tensor<*xf32>
+// CHECK:          return %[[RSQRT]] : tensor<*xf32>
+
+func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
+  %a = math.rsqrt %arg : tensor<*xf32>
+  return %a: tensor<*xf32>
+}
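To illustrate why the static shape matters, here is a hedged C++ analogy of the expansion, assuming that convertRsqrtOp rewrites rsqrt(x) as 1.0 / sqrt(x) with the 1.0 materialized as a constant of the operand's type (the patch comment only states that a float constant must be created). `splatConstant` and `expandedRsqrt` are hypothetical helpers, and `std::array<float, N>` stands in for a statically shaped tensor<Nxf32>: the splat of 1.0 can only be built because N is known at compile time, which has no counterpart for tensor<?xf32> or tensor<*xf32>.

```cpp
#include <array>
#include <cmath>
#include <cstddef>

// Analogy for a dense splat constant: every element holds `value`,
// and the element count N must be known up front.
template <std::size_t N>
std::array<float, N> splatConstant(float value) {
  std::array<float, N> c{};
  c.fill(value);
  return c;
}

// Assumed shape of the expansion: rsqrt(x) -> 1.0 / sqrt(x), elementwise.
template <std::size_t N>
std::array<float, N> expandedRsqrt(const std::array<float, N> &x) {
  const std::array<float, N> one = splatConstant<N>(1.0f);
  std::array<float, N> result{};
  for (std::size_t i = 0; i < N; ++i)
    result[i] = one[i] / std::sqrt(x[i]);
  return result;
}
```

With inputs {1, 4, 16} the sketch returns {1, 0.5, 0.25}; every value is exact in float because the square roots are exact.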
