[Mlir-commits] [mlir] [MLIR][MathDialect] fix fp32 promotion crash when encountering scf.if (PR #104451)

Ivy Zhang llvmlistbot at llvm.org
Mon Aug 19 05:39:17 PDT 2024


https://github.com/crazydemo updated https://github.com/llvm/llvm-project/pull/104451

>From 5d1bd57149bd7a0f0c2686ac48afdb8f86802fe3 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 15 Aug 2024 22:24:20 +0800
Subject: [PATCH 1/4] make IfOp legal in math pass

---
 mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp |  2 ++
 mlir/test/Dialect/Math/legalize-to-f32.mlir        | 13 +++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf560..27f364f5cbb055 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 
 namespace mlir::math {
 #define GEN_PASS_DEF_MATHLEGALIZETOF32
@@ -69,6 +70,7 @@ void mlir::math::populateLegalizeToF32ConversionTarget(
       [&typeConverter](Operation *op) -> bool {
         return typeConverter.isLegal(op);
       });
+  target.addLegalOp<scf::IfOp>();
   target.addLegalOp<FmaOp>();
   target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
 }
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index ae6ae7c5bc4b44..26a226d336ab43 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -83,3 +83,16 @@ func.func @sequences(%arg0: f16) -> f16 {
   %1 = math.sin %0 : f16
   return %1 : f16
 }
+
+// CHECK-LABEL: @promote_in_if_block
+func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
+  // CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+  %0 = scf.if %arg2 -> bf16 {
+    %1 = math.absf %arg0 : bf16
+    // CHECK: [[TRUNCF0:%.+]] = arith.truncf
+    scf.yield %1 : bf16
+  } else {
+    scf.yield %arg1 : bf16
+  }
+  return %0 : bf16
+}

>From 5a30187277c3cca5c3f1082e87909bfa19441941 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Thu, 15 Aug 2024 23:00:03 +0800
Subject: [PATCH 2/4] fix format

---
 mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 27f364f5cbb055..108670a58d2fc9 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -14,12 +14,12 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 
 namespace mlir::math {
 #define GEN_PASS_DEF_MATHLEGALIZETOF32

>From a39597cad26af5591244271463040703216eb668 Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Fri, 16 Aug 2024 10:13:24 +0800
Subject: [PATCH 3/4] fix ci

---
 mlir/test/Dialect/Math/legalize-to-f32.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index 26a226d336ab43..cd3862599d5e26 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -86,7 +86,7 @@ func.func @sequences(%arg0: f16) -> f16 {
 
 // CHECK-LABEL: @promote_in_if_block
 func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
-  // CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+  // CHECK: [[EXTF0:%.+]] = arith.extf
   %0 = scf.if %arg2 -> bf16 {
     %1 = math.absf %arg0 : bf16
     // CHECK: [[TRUNCF0:%.+]] = arith.truncf

>From 05925e31529c53a0bf253aa035ecbb917a22ee8b Mon Sep 17 00:00:00 2001
From: Zhang Yan <yan3.zhang at intel.com>
Date: Mon, 19 Aug 2024 20:36:23 +0800
Subject: [PATCH 4/4] fix comments

---
 mlir/lib/Conversion/MathToLibm/MathToLibm.cpp |  1 +
 .../Dialect/Math/Transforms/LegalizeToF32.cpp | 10 ++---
 .../MathToLibm/convert-to-libm.mlir           | 40 +++++++++++++++++++
 mlir/test/Dialect/Math/legalize-to-f32.mlir   |  1 +
 4 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 5b1c59d0c95e92..a2488dc600f51a 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -191,6 +191,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
   populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
   populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
   populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
+  populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
   populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
   populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
   populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 108670a58d2fc9..5aab150f53922b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -66,11 +65,10 @@ void mlir::math::populateLegalizeToF32TypeConverter(
 
 void mlir::math::populateLegalizeToF32ConversionTarget(
     ConversionTarget &target, TypeConverter &typeConverter) {
-  target.addDynamicallyLegalDialect<MathDialect>(
-      [&typeConverter](Operation *op) -> bool {
-        return typeConverter.isLegal(op);
-      });
-  target.addLegalOp<scf::IfOp>();
+  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
+    if (isa<MathDialect>(op->getDialect())) return typeConverter.isLegal(op);
+    return true;
+  });
   target.addLegalOp<FmaOp>();
   target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
 }
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index ffef12250595f0..08354dbf280c16 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -1,3 +1,4 @@
+
 // RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s
 
 // CHECK-DAG: @acos(f64) -> f64 attributes {llvm.readnone}
@@ -58,6 +59,8 @@
 // CHECK-DAG: @ceilf(f32) -> f32 attributes {llvm.readnone}
 // CHECK-DAG: @sqrt(f64) -> f64 attributes {llvm.readnone}
 // CHECK-DAG: @sqrtf(f32) -> f32 attributes {llvm.readnone}
+// CHECK-DAG: @rsqrt(f64) -> f64 attributes {llvm.readnone}
+// CHECK-DAG: @rsqrtf(f32) -> f32 attributes {llvm.readnone}
 // CHECK-DAG: @pow(f64, f64) -> f64 attributes {llvm.readnone}
 // CHECK-DAG: @powf(f32, f32) -> f32 attributes {llvm.readnone}
 
@@ -999,6 +1002,43 @@ func.func @sqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (ve
 // CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
+// CHECK-LABEL: func @rsqrt_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @rsqrt_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @rsqrtf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.rsqrt %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @rsqrt(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.rsqrt %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+func.func @rsqrt_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+  %float_result = math.rsqrt %float : vector<2xf32>
+  %double_result = math.rsqrt %double : vector<2xf64>
+  return %float_result, %double_result : vector<2xf32>, vector<2xf64>
+}
+// CHECK-LABEL:   func @rsqrt_vec_caller(
+// CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+// CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<2xf32>
+// CHECK:           %[[OUT0_F32:.*]] = call @rsqrtf(%[[IN0_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
+// CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : f32 from vector<2xf32>
+// CHECK:           %[[OUT1_F32:.*]] = call @rsqrtf(%[[IN1_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : f64 from vector<2xf64>
+// CHECK:           %[[OUT0_F64:.*]] = call @rsqrt(%[[IN0_F64]]) : (f64) -> f64
+// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
+// CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : f64 from vector<2xf64>
+// CHECK:           %[[OUT1_F64:.*]] = call @rsqrt(%[[IN1_F64]]) : (f64) -> f64
+// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:         }
+
 // CHECK-LABEL: func @powf_caller(
 // CHECK-SAME: %[[FLOATA:.*]]: f32, %[[FLOATB:.*]]: f32
 // CHECK-SAME: %[[DOUBLEA:.*]]: f64, %[[DOUBLEB:.*]]: f64
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
index cd3862599d5e26..ebb0de9d2653e2 100644
--- a/mlir/test/Dialect/Math/legalize-to-f32.mlir
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -87,6 +87,7 @@ func.func @sequences(%arg0: f16) -> f16 {
 // CHECK-LABEL: @promote_in_if_block
 func.func @promote_in_if_block(%arg0: bf16, %arg1: bf16, %arg2: i1) -> bf16 {
   // CHECK: [[EXTF0:%.+]] = arith.extf
+  // CHECK-NEXT: %[[RES:.*]] = scf.if
   %0 = scf.if %arg2 -> bf16 {
     %1 = math.absf %arg0 : bf16
     // CHECK: [[TRUNCF0:%.+]] = arith.truncf



More information about the Mlir-commits mailing list