[Mlir-commits] [mlir] [MLIR][MathDialect] fix fp32 promotion crash when encountering scf.if (PR #104451)

Ivy Zhang llvmlistbot at llvm.org
Thu Aug 15 19:16:02 PDT 2024


https://github.com/crazydemo updated https://github.com/llvm/llvm-project/pull/104451

>From 5d1bd57149bd7a0f0c2686ac48afdb8f86802fe3 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 15 Aug 2024 22:24:20 +0800
Subject: [PATCH 1/3] make scf::IfOp legal in math legalize-to-f32 pass

---
 mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp |  2 ++
 mlir/test/Dialect/Math/legalize-to-f32.mlir        | 13 +++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf560..27f364f5cbb055 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 
 namespace mlir::math {
 #define GEN_PASS_DEF_MATHLEGALIZETOF32
@@ -69,6 +70,7 @@ void mlir::math::populateLegalizeToF32ConversionTarget(
       [&typeConverter](Operation *op) -> bool {
         return typeConverter.isLegal(op);
       });
+  target.addLegalOp<scf::IfOp>();
   target.addLegalOp<FmaOp>();
   target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
 }
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index ae6ae7c5bc4b44..26a226d336ab43 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -83,3 +83,16 @@ func.func @sequences(%arg0: f16) -> f16 {
   %1 = math.sin %0 : f16
   return %1 : f16
 }
+
+// CHECK-LABEL: @promote_in_if_block
+func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
+  // CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+  %0 = scf.if %arg2 -> bf16 {
+    %1 = math.absf %arg0 : bf16
+    // CHECK: [[TRUNCF0:%.+]] = arith.truncf
+    scf.yield %1 : bf16
+  } else {
+    scf.yield %arg1 : bf16
+  }
+  return %0 : bf16
+}

>From 5a30187277c3cca5c3f1082e87909bfa19441941 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 15 Aug 2024 23:00:03 +0800
Subject: [PATCH 2/3] fix format: keep SCF.h include in sorted order

---
 mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 27f364f5cbb055..108670a58d2fc9 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -14,12 +14,12 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 
 namespace mlir::math {
 #define GEN_PASS_DEF_MATHLEGALIZETOF32

>From a39597cad26af5591244271463040703216eb668 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 16 Aug 2024 10:13:24 +0800
Subject: [PATCH 3/3] fix CI: drop undefined [[ARG0]] FileCheck variable from test

---
 mlir/test/Dialect/Math/legalize-to-f32.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index 26a226d336ab43..cd3862599d5e26 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -86,7 +86,7 @@ func.func @sequences(%arg0: f16) -> f16 {
 
 // CHECK-LABEL: @promote_in_if_block
 func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
-  // CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+  // CHECK: [[EXTF0:%.+]] = arith.extf
   %0 = scf.if %arg2 -> bf16 {
     %1 = math.absf %arg0 : bf16
     // CHECK: [[TRUNCF0:%.+]] = arith.truncf



More information about the Mlir-commits mailing list