[Mlir-commits] [mlir] [MLIR][MathDialect] fix fp32 promotion crash when encounters scf.if (PR #104451)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 15 07:56:57 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Ivy Zhang (crazydemo)

<details>
<summary>Changes</summary>

This PR fixes a corner case in which promotion has to happen inside an `scf.if` block.
On the current main branch the pass crashes in this case, because the `scf.if` op itself is also promoted.
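Below is a deliberately simplified, standalone C++ model of the legality decision that makes the failure mode concrete. `ToyOp`, `ToyTarget`, `wouldPromote`, and `makeTarget` are hypothetical names, not MLIR APIs. Following the description above, the model assumes that any op not explicitly marked legal and carrying an `f16`/`bf16` type is handed to the promotion rewrite. Before the fix, an `scf.if` with a `bf16` result is a promotion candidate. Adding it to the explicitly legal set, which is the role `addLegalOp<scf::IfOp>()` plays in the patch, removes it from consideration. The `math.absf` in its body is still promoted.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR operation: only its name and the types it
// touches matter for this illustration. This is not an MLIR class.
struct ToyOp {
  std::string name;                // e.g. "math.absf", "scf.if"
  std::vector<std::string> types;  // operand and result types, e.g. {"bf16"}
};

// True for the narrow float types that the pass widens to f32.
bool isNarrowFloat(const std::string &type) {
  return type == "f16" || type == "bf16";
}

// Hypothetical stand-in for a ConversionTarget. `explicitlyLegal` plays the
// role of target.addLegalOp<...>(); every other op is judged by its types
// (a simplifying assumption of this model).
struct ToyTarget {
  std::set<std::string> explicitlyLegal;

  // Returns true if the generic "promote to f32" rewrite would be attempted.
  bool wouldPromote(const ToyOp &op) const {
    if (explicitlyLegal.count(op.name) != 0)
      return false;  // marked legal: left untouched
    for (const std::string &type : op.types)
      if (isNarrowFloat(type))
        return true;  // carries f16/bf16: candidate for promotion
    return false;
  }
};

// Builds the toy target as it was before (fixed == false) or after
// (fixed == true) this PR.
ToyTarget makeTarget(bool fixed) {
  ToyTarget target;
  // Mirrors addLegalOp<FmaOp>() and addLegalOp<arith::ExtFOp, arith::TruncFOp>().
  target.explicitlyLegal = {"math.fma", "arith.extf", "arith.truncf"};
  if (fixed)
    target.explicitlyLegal.insert("scf.if");  // mirrors addLegalOp<scf::IfOp>()
  return target;
}
```

In the real pass, `ConversionTarget` and the dialect-conversion driver make this decision. The toy shows only which ops become rewrite candidates. It does not model why rewriting a region-holding op such as `scf.if` crashes.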

---
Full diff: https://github.com/llvm/llvm-project/pull/104451.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (+2) 
- (modified) mlir/test/Dialect/Math/legalize-to-f32.mlir (+13) 


``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf560..27f364f5cbb055 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 
 namespace mlir::math {
 #define GEN_PASS_DEF_MATHLEGALIZETOF32
@@ -69,6 +70,7 @@ void mlir::math::populateLegalizeToF32ConversionTarget(
       [&typeConverter](Operation *op) -> bool {
         return typeConverter.isLegal(op);
       });
+  target.addLegalOp<scf::IfOp>();
   target.addLegalOp<FmaOp>();
   target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
 }
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index ae6ae7c5bc4b44..26a226d336ab43 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -83,3 +83,16 @@ func.func @sequences(%arg0: f16) -> f16 {
   %1 = math.sin %0 : f16
   return %1 : f16
 }
+
+// CHECK-LABEL: @promote_in_if_block
+func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
+  // CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+  %0 = scf.if %arg2 -> bf16 {
+    %1 = math.absf %arg0 : bf16
+    // CHECK: [[TRUNCF0:%.+]] = arith.truncf
+    scf.yield %1 : bf16
+  } else {
+    scf.yield %arg1 : bf16
+  }
+  return %0 : bf16
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/104451

