[Mlir-commits] [mlir] f941908 - Revert "[MLIR][Arith] add fastMathAttr on arith::extf and arith::truncf" (#95344)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 12 20:23:24 PDT 2024


Author: Ivy Zhang
Date: 2024-06-13T11:23:20+08:00
New Revision: f941908d77e0a009351e5d5d3f01c704b5ff2ff7

URL: https://github.com/llvm/llvm-project/commit/f941908d77e0a009351e5d5d3f01c704b5ff2ff7
DIFF: https://github.com/llvm/llvm-project/commit/f941908d77e0a009351e5d5d3f01c704b5ff2ff7.diff

LOG: Revert "[MLIR][Arith] add fastMathAttr on arith::extf and arith::truncf" (#95344)

Reverts llvm/llvm-project#93443

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
    mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
    mlir/test/Dialect/Arith/canonicalize.mlir
    mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c4471f9bc5af2..06fbdb7f2c4cb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1199,7 +1199,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
 // ExtFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
   let summary = "cast from floating-point to wider floating-point";
   let description = [{
     Cast a floating-point value to a larger floating-point-typed value.
@@ -1208,13 +1208,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
   }];
   let hasVerifier = 1;
   let hasFolder = 1;
-
-  let arguments = (ins FloatLike:$in, DefaultValuedAttr<
-                         Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
-  let results = (outs FloatLike:$out);
-
-  let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
-                          attr-dict `:` type($in) `to` type($out) }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1253,11 +1246,8 @@ def Arith_TruncFOp :
     Arith_Op<"truncf",
       [Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
        DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
-       DeclareOpInterfaceMethods<ArithFastMathInterface>,
        DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins FloatLike:$in,
-                   DefaultValuedAttr<
-                      Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
                    OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
     Results<(outs FloatLike:$out)> {
   let summary = "cast from floating-point to narrower floating-point";
@@ -1277,9 +1267,7 @@ def Arith_TruncFOp :
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  let assemblyFormat = [{ $in ($roundingmode^)?
-                          (`fastmath` `` $fastmath^)?
-                          attr-dict `:` type($in) `to` type($out) }];
+  let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 291f6e5424ba5..2f6647a2a27b1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1390,20 +1390,6 @@ LogicalResult arith::ExtSIOp::verify() {
 /// Fold extension of float constants when there is no information loss due the
 /// difference in fp semantics.
 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
-  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
-    if (truncFOp.getOperand().getType() == getType()) {
-      arith::FastMathFlags truncFMF = truncFOp.getFastmath();
-      bool isTruncContract =
-          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
-      arith::FastMathFlags extFMF = getFastmath();
-      bool isExtContract =
-          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
-      if (isTruncContract && isExtContract) {
-        return truncFOp.getOperand();
-      }
-    }
-  }
-
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(

diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 8e1cb474feee7..4a50da3513f99 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,11 +94,8 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
   SmallVector<Value> newResults(expandedOp->getResults());
   for (auto [res, oldType, newType] : llvm::zip_equal(
            MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
-    if (oldType != newType) {
-      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
-      truncFOp.setFastmath(arith::FastMathFlags::contract);
-      res = truncFOp.getResult();
-    }
+    if (oldType != newType)
+      res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
   }
   rewriter.replaceOp(op, newResults);
 }
@@ -117,9 +114,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
   });
   converter.addTargetMaterialization(
       [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
+        return b.create<arith::ExtFOp>(loc, target, input);
       });
 }
 

diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf56..5998133b7eab8 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,9 +57,7 @@ void mlir::math::populateLegalizeToF32TypeConverter(
   });
   typeConverter.addTargetMaterialization(
       [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
+        return b.create<arith::ExtFOp>(loc, target, input);
       });
 }
 
@@ -86,11 +84,8 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
   SmallVector<Value> results = (*legalized)->getResults();
   for (auto [result, newType, origType] : llvm::zip_equal(
            results, (*legalized)->getResultTypes(), op->getResultTypes())) {
-    if (newType != origType) {
-      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
-      truncFOp.setFastmath(arith::FastMathFlags::contract);
-      result = truncFOp.getResult();
-    }
+    if (newType != origType)
+      result = rewriter.create<arith::TruncFOp>(loc, origType, result);
   }
   rewriter.replaceOp(op, results);
   return success();

diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index cacdd801871fb..56ae930e6d627 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -162,11 +162,11 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fpext
 func.func @fpext(%arg0 : f16, %arg1 : f32) {
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f32
+// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f32
   %0 = arith.extf %arg0: f16 to f32
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f64
+// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f64
   %1 = arith.extf %arg0: f16 to f64
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f64
+// CHECK-NEXT: = llvm.fpext {{.*}} : f32 to f64
   %2 = arith.extf %arg1: f32 to f64
   return
 }
@@ -174,11 +174,11 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fpext
 func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf32>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf32>
   %0 = arith.extf %arg0: vector<2xf16> to vector<2xf32>
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf64>
   %1 = arith.extf %arg0: vector<2xf16> to vector<2xf64>
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf32> to vector<2xf64>
   %2 = arith.extf %arg1: vector<2xf32> to vector<2xf64>
   return
 }
@@ -268,11 +268,11 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc
 func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f16
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f32 to f16
   %0 = arith.truncf %arg0: f32 to f16
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f16
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f16
   %1 = arith.truncf %arg1: f64 to f16
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f32
   %2 = arith.truncf %arg1: f64 to f32
   return
 }
@@ -280,26 +280,26 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc
 func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf32> to vector<2xf16>
   %0 = arith.truncf %arg0: vector<2xf32> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf16>
   %1 = arith.truncf %arg1: vector<2xf64> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf32>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf32>
   %2 = arith.truncf %arg1: vector<2xf64> to vector<2xf32>
   return
 }
 
 // CHECK-LABEL: experimental_constrained_fptrunc
 func.func @experimental_constrained_fptrunc(%arg0 : f64) {
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
   %0 = arith.truncf %arg0 to_nearest_even : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
   %1 = arith.truncf %arg0 downward : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
   %2 = arith.truncf %arg0 upward : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
   %3 = arith.truncf %arg0 toward_zero : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
   %4 = arith.truncf %arg0 to_nearest_away : f64 to f32
   return
 }

diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4fe7cfb689be8..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3031,143 +3031,6 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
   return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
 }
 
-// CHECK-LABEL: @sequences_fastmath_contract
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
-  %0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
-  %1 = math.absf %0 : f32
-  %2 = arith.truncf %1 fastmath<contract> : f32 to bf16
-  %3 = arith.extf %2 fastmath<contract> : bf16 to f32
-  %4 = math.sin %3 : f32
-  %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
-  return %5 : bf16
-}
-
-// CHECK-LABEL: @sequences_no_fastmath
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
-// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
-  %0 = arith.extf %arg0 : bf16 to f32
-  %1 = math.absf %0 : f32
-  %2 = arith.truncf %1 : f32 to bf16
-  %3 = arith.extf %2 : bf16 to f32
-  %4 = math.sin %3 : f32
-  %5 = arith.truncf %4 : f32 to bf16
-  return %5 : bf16
-}
-
-// CHECK-LABEL: @eliminate_cast_to_f16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 {
-  %0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
-  %1 = arith.extf %0 fastmath<contract> : f16 to f32
-  return %1 : f32
-}
-
-// CHECK-LABEL: @eliminate_cast_to_bf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 {
-  %0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
-  %1 = arith.extf %0 fastmath<contract> : bf16 to f32
-  return %1 : f32
-}
-
-// CHECK-LABEL: @bf16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
-  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %1 = math.absf %0 : vector<32x32x32xf32>
-  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %4 = math.sin %3 : vector<32x32x32xf32>
-  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  return %5 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @f16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
-func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
-  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
-  %1 = math.absf %0 : vector<32x32x32xf32>
-  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
-  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
-  %4 = math.sin %3 : vector<32x32x32xf32>
-  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
-  return %5 : vector<32x32x32xf16>
-}
-
-// CHECK-LABEL: @bf16_branch_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
-  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %1 = math.absf %0 : vector<32x32x32xf32>
-  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %4 = math.sin %3 : vector<32x32x32xf32>
-  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %7 = math.cos %3 : vector<32x32x32xf32>
-  %8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %10 = arith.addf %6, %9 : vector<32x32x32xf32>
-  %11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  return %11 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @bf16_fma
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
-// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
-func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
-  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %1 = math.absf %0 : vector<32x32x32xf32>
-  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %4 = math.sin %3 : vector<32x32x32xf32>
-  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
-  %8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %9 = arith.addf %8, %6 : vector<32x32x32xf32>
-  %10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  return %10 : vector<32x32x32xbf16>
-}
-
 {-#
   dialect_resources: {
     builtin: {

diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index 99790cc45d490..a69ef131d8d47 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
 // CHECK-LABEL: @basic_expansion
 // CHECK-SAME: [[X:%.+]]: bf16
 // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
 // CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
-// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
 // CHECK: return [[Y]]
   %c = arith.constant 1.0 : bf16
   %y = arith.addf %x, %c : bf16
@@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
 func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
 // CHECK-LABEL: @chained
 // CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
 // CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
-// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath<contract> : f32 to bf16
-// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath<contract> : bf16 to f32
+// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
+// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
 // CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
-// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath<contract> : f32 to bf16
-// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath<contract> : bf16 to f32
+// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
+// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
 // CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
 // CHECK: return [[RES]]
   %p = arith.addf %x, %y : bf16
@@ -41,12 +41,12 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
 func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
 // CHECK-LABEL: @memops
 // CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
-// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath<contract> : f8E4M3FNUZ to f32
+// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
 // CHECK: memref.store [[V]]
 // CHECK: [[W:%.+]] = memref.load
-// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath<contract> : f8E4M3FNUZ to f32
+// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
 // CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
-// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath<contract> : f32 to f8E4M3FNUZ
+// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
 // CHECK: memref.store [[X]]
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -63,9 +63,9 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
 func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
 // CHECK-LABEL: @vectors
 // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
-// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
 // CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
-// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath<contract> : vector<4xf32> to vector<4xf8E4M3FNUZ>
+// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
 // CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
 // CHECK: return [[RET]]
   %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>


        


More information about the Mlir-commits mailing list