[Mlir-commits] [mlir] [MLIR][MathDialect] fix fp32 promotion crash when encounters scf.if (PR #104451)

Ivy Zhang llvmlistbot at llvm.org
Thu Aug 15 07:56:25 PDT 2024


https://github.com/crazydemo created https://github.com/llvm/llvm-project/pull/104451

This PR fixes a corner case, which needs to do promotion within a `scf.if` block.
The current main branch crashes in such case, due to `scf.if` will also be promoted.

>From 5d1bd57149bd7a0f0c2686ac48afdb8f86802fe3 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 15 Aug 2024 22:24:20 +0800
Subject: [PATCH] make IfOp legal in math pass

---
 mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp |  2 ++
 mlir/test/Dialect/Math/legalize-to-f32.mlir        | 13 +++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf560..27f364f5cbb055 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 
 namespace mlir::math {
 #define GEN_PASS_DEF_MATHLEGALIZETOF32
@@ -69,6 +70,7 @@ void mlir::math::populateLegalizeToF32ConversionTarget(
       [&typeConverter](Operation *op) -> bool {
         return typeConverter.isLegal(op);
       });
+  target.addLegalOp<scf::IfOp>();
   target.addLegalOp<FmaOp>();
   target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
 }
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index ae6ae7c5bc4b44..26a226d336ab43 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -83,3 +83,16 @@ func.func @sequences(%arg0: f16) -> f16 {
   %1 = math.sin %0 : f16
   return %1 : f16
 }
+
+// CHECK-LABEL: @promote_in_if_block
+func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
+  // CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+  %0 = scf.if %arg2 -> bf16 {
+    %1 = math.absf %arg0 : bf16
+    // CHECK: [[TRUNCF0:%.+]] = arith.truncf
+    scf.yield %1 : bf16
+  } else {
+    scf.yield %arg1 : bf16
+  }
+  return %0 : bf16
+}



More information about the Mlir-commits mailing list