[Mlir-commits] [mlir] [MLIR][MathDialect] fix fp32 promotion crash when encounters scf.if (PR #104451)
llvmlistbot at llvm.org
Thu Aug 15 07:56:57 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math
@llvm/pr-subscribers-mlir
Author: Ivy Zhang (crazydemo)
<details>
<summary>Changes</summary>
This PR fixes a corner case in which promotion has to happen inside an `scf.if` block.
The current main branch crashes in this case because the pass also tries to promote the `scf.if` op itself. The fix marks `scf::IfOp` as legal in the conversion target.
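
For illustration, here is a rough sketch of the IR the pass is expected to produce for the new `@promote_in_if_block` test once `scf.if` is left alone. It assumes the pass wraps only the `math.absf` in an `arith.extf` / `arith.truncf` pair, as the test's CHECK lines suggest. The SSA value names are illustrative and this is not verbatim output from `mlir-opt`.

```mlir
// Sketch only: value names are made up, not actual pass output.
func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
  // scf.if keeps its bf16 result type because it is now a legal op.
  %0 = scf.if %arg2 -> (bf16) {
    // Only the math op inside the region is promoted to f32.
    %1 = arith.extf %arg0 : bf16 to f32
    %2 = math.absf %1 : f32
    %3 = arith.truncf %2 : f32 to bf16
    scf.yield %3 : bf16
  } else {
    scf.yield %arg1 : bf16
  }
  return %0 : bf16
}
```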
---
Full diff: https://github.com/llvm/llvm-project/pull/104451.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (+2)
- (modified) mlir/test/Dialect/Math/legalize-to-f32.mlir (+13)
``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf560..27f364f5cbb055 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
namespace mlir::math {
#define GEN_PASS_DEF_MATHLEGALIZETOF32
@@ -69,6 +70,7 @@ void mlir::math::populateLegalizeToF32ConversionTarget(
[&typeConverter](Operation *op) -> bool {
return typeConverter.isLegal(op);
});
+ target.addLegalOp<scf::IfOp>();
target.addLegalOp<FmaOp>();
target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
}
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index ae6ae7c5bc4b44..26a226d336ab43 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -83,3 +83,16 @@ func.func @sequences(%arg0: f16) -> f16 {
%1 = math.sin %0 : f16
return %1 : f16
}
+
+// CHECK-LABEL: @promote_in_if_block
+func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
+ // CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+ %0 = scf.if %arg2 -> bf16 {
+ %1 = math.absf %arg0 : bf16
+ // CHECK: [[TRUNCF0:%.+]] = arith.truncf
+ scf.yield %1 : bf16
+ } else {
+ scf.yield %arg1 : bf16
+ }
+ return %0 : bf16
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/104451