[Mlir-commits] [mlir] 7042fcc - [MLIR][Arith][Resubmit] add fastMathAttr on arith::extf and arith::truncf (#95346)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 14 16:42:33 PDT 2024


Author: Ivy Zhang
Date: 2024-06-15T07:42:29+08:00
New Revision: 7042fcc6389c6c103d501b6f39988eafed0d9b5b

URL: https://github.com/llvm/llvm-project/commit/7042fcc6389c6c103d501b6f39988eafed0d9b5b
DIFF: https://github.com/llvm/llvm-project/commit/7042fcc6389c6c103d501b6f39988eafed0d9b5b.diff

LOG: [MLIR][Arith][Resubmit] add fastMathAttr on arith::extf and arith::truncf (#95346)

Add a `fastMathAttr` on `arith::extf` and `arith::truncf`. If these two
ops are inserted by promotion passes (like legalize-to-f32 /
emulate-unsupported-floats), they are labeled with
`FastMathFlags::contract`, denoting that a `truncf` immediately followed
by an `extf` back to the original type can then be eliminated by the
canonicalizer.

The elimination can help improve performance, but may introduce some
numerical differences.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
    mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir
    mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 06fbdb7f2c4cb..6e24b607cebcf 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1199,7 +1199,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
 // ExtFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
+def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
   let summary = "cast from floating-point to wider floating-point";
   let description = [{
     Cast a floating-point value to a larger floating-point-typed value.
@@ -1208,6 +1208,13 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
   }];
   let hasVerifier = 1;
   let hasFolder = 1;
+
+  let arguments = (ins FloatLike:$in,
+                       OptionalAttr<Arith_FastMathAttr>:$fastmath);
+  let results = (outs FloatLike:$out);
+
+  let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($in) `to` type($out) }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1246,9 +1253,11 @@ def Arith_TruncFOp :
     Arith_Op<"truncf",
       [Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
        DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+       DeclareOpInterfaceMethods<ArithFastMathInterface>,
        DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins FloatLike:$in,
-                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+                   OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
     Results<(outs FloatLike:$out)> {
   let summary = "cast from floating-point to narrower floating-point";
   let description = [{
@@ -1267,7 +1276,9 @@ def Arith_TruncFOp :
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
+  let assemblyFormat = [{ $in ($roundingmode^)?
+                          (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($in) `to` type($out) }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2f6647a2a27b1..f4781fcb546a3 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1390,6 +1390,22 @@ LogicalResult arith::ExtSIOp::verify() {
 /// Fold extension of float constants when there is no information loss due the
 /// difference in fp semantics.
 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
+  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
+    if (truncFOp.getOperand().getType() == getType()) {
+      arith::FastMathFlags truncFMF =
+          truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
+      bool isTruncContract =
+          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
+      arith::FastMathFlags extFMF =
+          getFastmath().value_or(arith::FastMathFlags::none);
+      bool isExtContract =
+          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
+      if (isTruncContract && isExtContract) {
+        return truncFOp.getOperand();
+      }
+    }
+  }
+
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(

diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 4a50da3513f99..8e1cb474feee7 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
   SmallVector<Value> newResults(expandedOp->getResults());
   for (auto [res, oldType, newType] : llvm::zip_equal(
            MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
-    if (oldType != newType)
-      res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+    if (oldType != newType) {
+      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+      truncFOp.setFastmath(arith::FastMathFlags::contract);
+      res = truncFOp.getResult();
+    }
   }
   rewriter.replaceOp(op, newResults);
 }
@@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
   });
   converter.addTargetMaterialization(
       [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        return b.create<arith::ExtFOp>(loc, target, input);
+        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+        extFOp.setFastmath(arith::FastMathFlags::contract);
+        return extFOp;
       });
 }
 

diff  --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 5998133b7eab8..3d99f3033cf56 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter(
   });
   typeConverter.addTargetMaterialization(
       [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        return b.create<arith::ExtFOp>(loc, target, input);
+        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+        extFOp.setFastmath(arith::FastMathFlags::contract);
+        return extFOp;
       });
 }
 
@@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
   SmallVector<Value> results = (*legalized)->getResults();
   for (auto [result, newType, origType] : llvm::zip_equal(
            results, (*legalized)->getResultTypes(), op->getResultTypes())) {
-    if (newType != origType)
-      result = rewriter.create<arith::TruncFOp>(loc, origType, result);
+    if (newType != origType) {
+      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+      truncFOp.setFastmath(arith::FastMathFlags::contract);
+      result = truncFOp.getResult();
+    }
   }
   rewriter.replaceOp(op, results);
   return success();

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e4f95bb0545a2..4fe7cfb689be8 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3031,6 +3031,143 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
   return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
 }
 
+// CHECK-LABEL: @sequences_fastmath_contract
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
+  %0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
+  %1 = math.absf %0 : f32
+  %2 = arith.truncf %1 fastmath<contract> : f32 to bf16
+  %3 = arith.extf %2 fastmath<contract> : bf16 to f32
+  %4 = math.sin %3 : f32
+  %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
+  return %5 : bf16
+}
+
+// CHECK-LABEL: @sequences_no_fastmath
+// CHECK-SAME: ([[ARG0:%.+]]: bf16)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : bf16
+func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
+  %0 = arith.extf %arg0 : bf16 to f32
+  %1 = math.absf %0 : f32
+  %2 = arith.truncf %1 : f32 to bf16
+  %3 = arith.extf %2 : bf16 to f32
+  %4 = math.sin %3 : f32
+  %5 = arith.truncf %4 : f32 to bf16
+  return %5 : bf16
+}
+
+// CHECK-LABEL: @eliminate_cast_to_f16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
+  %1 = arith.extf %0 fastmath<contract> : f16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @eliminate_cast_to_bf16
+// CHECK: return [[arg0:%.+]] : f32
+func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 {
+  %0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
+  %1 = arith.extf %0 fastmath<contract> : bf16 to f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @bf16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %1 = math.absf %0 : vector<32x32x32xf32>
+  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %4 = math.sin %3 : vector<32x32x32xf32>
+  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+  return %5 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @f16_sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
+func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
+  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
+  %1 = math.absf %0 : vector<32x32x32xf32>
+  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
+  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
+  %4 = math.sin %3 : vector<32x32x32xf32>
+  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
+  return %5 : vector<32x32x32xf16>
+}
+
+// CHECK-LABEL: @bf16_branch_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
+func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %1 = math.absf %0 : vector<32x32x32xf32>
+  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %4 = math.sin %3 : vector<32x32x32xf32>
+  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+  %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %7 = math.cos %3 : vector<32x32x32xf32>
+  %8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+  %9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %10 = arith.addf %6, %9 : vector<32x32x32xf32>
+  %11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+  return %11 : vector<32x32x32xbf16>
+}
+
+// CHECK-LABEL: @bf16_fma
+// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
+// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
+// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
+// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
+// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
+// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
+func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
+  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %1 = math.absf %0 : vector<32x32x32xf32>
+  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %4 = math.sin %3 : vector<32x32x32xf32>
+  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+  %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
+  %8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
+  %9 = arith.addf %8, %6 : vector<32x32x32xf32>
+  %10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
+  return %10 : vector<32x32x32xbf16>
+}
+
 {-#
   dialect_resources: {
     builtin: {

diff  --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index a69ef131d8d47..99790cc45d490 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
 // CHECK-LABEL: @basic_expansion
 // CHECK-SAME: [[X:%.+]]: bf16
 // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
-// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to f32
 // CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
-// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
 // CHECK: return [[Y]]
   %c = arith.constant 1.0 : bf16
   %y = arith.addf %x, %c : bf16
@@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
 func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
 // CHECK-LABEL: @chained
 // CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
-// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
-// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath<contract> : bf16 to f32
 // CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
-// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
-// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
+// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath<contract> : f32 to bf16
+// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath<contract> : bf16 to f32
 // CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
-// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
-// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
+// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath<contract> : f32 to bf16
+// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath<contract> : bf16 to f32
 // CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
 // CHECK: return [[RES]]
   %p = arith.addf %x, %y : bf16
@@ -41,12 +41,12 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
 func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
 // CHECK-LABEL: @memops
 // CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
-// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
+// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath<contract> : f8E4M3FNUZ to f32
 // CHECK: memref.store [[V]]
 // CHECK: [[W:%.+]] = memref.load
-// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
+// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath<contract> : f8E4M3FNUZ to f32
 // CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
-// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
+// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath<contract> : f32 to f8E4M3FNUZ
 // CHECK: memref.store [[X]]
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -63,9 +63,9 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
 func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
 // CHECK-LABEL: @vectors
 // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
-// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : vector<4xf8E4M3FNUZ> to vector<4xf32>
 // CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
-// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
+// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath<contract> : vector<4xf32> to vector<4xf8E4M3FNUZ>
 // CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
 // CHECK: return [[RET]]
   %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>


        


More information about the Mlir-commits mailing list