[Mlir-commits] [mlir] [mlir][math] Rsqrt math expand pass expects static shaped operand (PR #129006)
Kai Sasaki
llvmlistbot at llvm.org
Wed Feb 26 22:02:47 PST 2025
https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/129006
>From d01daf2396a5affe8e7e2dc6be421958f4e714cf Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Thu, 27 Feb 2025 14:54:25 +0900
Subject: [PATCH] [mlir][math] Rsqrt math expand pass expects static shaped
operand
Similar to the issue reported in
https://github.com/llvm/llvm-project/pull/128299/files, the ExpandMath
pattern for rsqrt expects statically shaped operands. Otherwise, it
crashes with an assertion failure. With this patch, the pattern fails
to match dynamically shaped and unranked operands instead.
---
.../Math/Transforms/ExpandPatterns.cpp | 5 ++++
mlir/test/Dialect/Math/expand-math.mlir | 26 +++++++++++++++++++
2 files changed, 31 insertions(+)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bb592c667549c..7b5350ca26b60 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -646,6 +646,11 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
auto operand = op.getOperand();
auto operandTy = operand.getType();
+ // A shaped operand must have a static shape to create the float constant.
+ auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
+ if (shapedOperandType && !shapedOperandType.hasStaticShape())
+ return failure();
+
auto eTy = getElementTypeOrSelf(operandTy);
if (!isa<FloatType>(eTy))
return failure();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 946a411e4cc4b..1420acaa40d35 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -787,3 +787,29 @@ func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
%a = math.ceil %arg : tensor<*xf32>
return %a: tensor<*xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @non_static_shape_rsqrt_op
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>)
+// CHECK-SAME: -> tensor<?xf32>
+// CHECK: %[[RSQRT:.*]] = math.rsqrt %[[ARG]] : tensor<?xf32>
+// CHECK: return %[[RSQRT]] : tensor<?xf32>
+
+func.func @non_static_shape_rsqrt_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
+ %a = math.rsqrt %arg : tensor<?xf32>
+ return %a: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unranked_rsqrt_op
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+// CHECK-SAME: -> tensor<*xf32>
+// CHECK: %[[RSQRT:.*]] = math.rsqrt %[[ARG]] : tensor<*xf32>
+// CHECK: return %[[RSQRT]] : tensor<*xf32>
+
+func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
+ %a = math.rsqrt %arg : tensor<*xf32>
+ return %a: tensor<*xf32>
+}
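The following standalone C++ sketch models why the guard is needed and what the new tests expect. It is not MLIR code: `ModelType`, `kDynamicDim`, `hasStaticShape`, `canExpandRsqrt`, `splatElementCount`, and `expandRsqrt` are hypothetical names, and `-1` is only a stand-in sentinel for a `?` dimension. It assumes, based on the upstream ExpandPatterns.cpp rather than on this patch, that the pattern rewrites `math.rsqrt(x)` as `1.0 / math.sqrt(x)` and materializes `1.0` as a splat constant of the operand type, and that building that splat asserts on a non-static shape.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for an MLIR type (not the real mlir::Type API).
struct ModelType {
  bool isShaped = false;                     // tensor/vector vs. scalar
  std::optional<std::vector<int64_t>> shape; // nullopt => unranked (tensor<*x...>)
  bool isFloatElement = true;                // element type is a float type
};

// Hypothetical sentinel for a dynamic ('?') dimension.
constexpr int64_t kDynamicDim = -1;

// Ranked and every dimension known: the property the new guard checks.
bool hasStaticShape(const ModelType &t) {
  if (!t.shape)
    return false;
  for (int64_t d : *t.shape)
    if (d == kDynamicDim)
      return false;
  return true;
}

// Same precondition order as the patched convertRsqrtOp:
// shape guard first, then the existing float element check.
bool canExpandRsqrt(const ModelType &t) {
  if (t.isShaped && !hasStaticShape(t))
    return false; // new guard: bail out instead of asserting later
  if (!t.isFloatElement)
    return false;
  return true;
}

// Models building the "1.0" constant: a scalar needs one value, a shaped
// type needs a splat, which is only possible for a static shape (the
// assertion that crashed before the patch).
std::size_t splatElementCount(const ModelType &t) {
  if (!t.isShaped)
    return 1;
  assert(hasStaticShape(t) && "splat constant requires a static shape");
  std::size_t n = 1;
  for (int64_t d : *t.shape)
    n *= static_cast<std::size_t>(d);
  return n;
}

// Returns nullopt when the pattern fails (math.rsqrt is left untouched, as
// the new CHECK lines expect); otherwise computes 1.0 / sqrt(x) per element.
std::optional<std::vector<float>> expandRsqrt(const ModelType &t,
                                              const std::vector<float> &xs) {
  if (!canExpandRsqrt(t))
    return std::nullopt;
  std::vector<float> ones(splatElementCount(t), 1.0f);
  std::vector<float> out(ones.size());
  for (std::size_t i = 0; i < ones.size(); ++i)
    out[i] = ones[i] / std::sqrt(xs.at(i)); // divf(1.0, sqrt(x))
  return out;
}
```

For example, `expandRsqrt` on a static `tensor<2xf32>` model with inputs `{4.0f, 16.0f}` yields `{0.5f, 0.25f}`, while the `tensor<?xf32>` and `tensor<*xf32>` models return `std::nullopt` instead of reaching the assertion. Checking the shape before creating any constant keeps the rewrite side-effect free when it bails out, which matters because a pattern that returns `failure()` should not have modified the IR.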