[Mlir-commits] [mlir] [MLIR][MathDialect] fix fp32 promotion crash when encountering scf.if (PR #104451)
Ivy Zhang
llvmlistbot at llvm.org
Thu Aug 15 07:56:25 PDT 2024
https://github.com/crazydemo created https://github.com/llvm/llvm-project/pull/104451
This PR fixes a corner case in which promotion has to be performed inside an `scf.if` block.
The current main branch crashes in this case because the `scf.if` op itself is also selected for promotion.
>From 5d1bd57149bd7a0f0c2686ac48afdb8f86802fe3 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 15 Aug 2024 22:24:20 +0800
Subject: [PATCH] make IfOp legal in math pass
---
mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp | 2 ++
mlir/test/Dialect/Math/legalize-to-f32.mlir | 13 +++++++++++++
2 files changed, 15 insertions(+)
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf560..27f364f5cbb055 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
namespace mlir::math {
#define GEN_PASS_DEF_MATHLEGALIZETOF32
@@ -69,6 +70,7 @@ void mlir::math::populateLegalizeToF32ConversionTarget(
[&typeConverter](Operation *op) -> bool {
return typeConverter.isLegal(op);
});
+ target.addLegalOp<scf::IfOp>();
target.addLegalOp<FmaOp>();
target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
}
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index ae6ae7c5bc4b44..26a226d336ab43 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -83,3 +83,16 @@ func.func @sequences(%arg0: f16) -> f16 {
%1 = math.sin %0 : f16
return %1 : f16
}
+
+// CHECK-LABEL: @promote_in_if_block
+func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
+  %0 = scf.if %arg2 -> bf16 {
+    // CHECK: [[EXTF0:%.+]] = arith.extf %arg0 : bf16 to f32
+    // CHECK: math.absf [[EXTF0]] : f32
+    %1 = math.absf %arg0 : bf16
+    scf.yield %1 : bf16
+  } else {
+    scf.yield %arg1 : bf16
+  }
+  return %0 : bf16
+}
More information about the Mlir-commits
mailing list