[Mlir-commits] [mlir] 10b56e0 - [mlir][Arith] Add pass for emulating unsupported float ops (#1079)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Jul 11 13:32:39 PDT 2023


Author: Krzysztof Drewniak
Date: 2023-07-11T20:32:35Z
New Revision: 10b56e0210bf615519570f561acbeb916db032f4

URL: https://github.com/llvm/llvm-project/commit/10b56e0210bf615519570f561acbeb916db032f4
DIFF: https://github.com/llvm/llvm-project/commit/10b56e0210bf615519570f561acbeb916db032f4.diff

LOG: [mlir][Arith] Add pass for emulating unsupported float ops (#1079)

To complement the bf16 expansion and truncation patterns added to
ExpandOps, define a pass that replaces, for any arithmetic operation
op,
%y = arith.op %v0, %v1, ... : T
with
%e0 = arith.extf %v0 : T to U
%e1 = arith.extf %v1 : T to U
...
%y.exp = arith.op %e0, %e1, ... : U
%y = arith.truncf %y.exp : U to T

This allows for "emulating" floating-point operations not supported on
a given target (such as bfloat operations or most arithmetic on 8-bit
floats) by extending those types to supported ones, performing the
arithmetic operation, and then truncating back to the original
type (which ensures appropriate rounding behavior).

The lowering of the extf and truncf ops introduced by this
transformation should be handled by subsequent passes.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D154539

Added: 
    mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
    mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir

Modified: 
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
    mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
    mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index de36cb48e6d024..2c2353bdc81027 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -13,6 +13,8 @@
 
 namespace mlir {
 class DataFlowSolver;
+// Forward declarations for the dialect-conversion types used by the float
+// emulation entry points declared in this header.
+class ConversionTarget;
+class TypeConverter;
 
 namespace arith {
 
@@ -42,6 +44,21 @@ void populateArithWideIntEmulationPatterns(
 void populateArithNarrowTypeEmulationPatterns(
     NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Populate the type conversions needed to emulate the unsupported
+/// `sourceTypes` with `targetType`: every type in `sourceTypes`, and every
+/// shaped type whose element type is in `sourceTypes`, is converted to use
+/// `targetType`; all other types are left unchanged.
+void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter,
+                                                 ArrayRef<Type> sourceTypes,
+                                                 Type targetType);
+
+/// Add rewrite patterns for converting operations that use illegal float types
+/// to ones that use legal ones. Each matched operation is recreated on the
+/// converted types, and results whose type changed are truncated back to
+/// their original type.
+void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns,
+                                              TypeConverter &converter);
+
+/// Set up a dialect conversion to reject arithmetic operations on unsupported
+/// float types. The legality callbacks refer to `converter`, which must
+/// therefore outlive `target`.
+void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target,
+                                              TypeConverter &converter);
+
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 

diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 50b748435afa9b..77575b0b5a57df 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -63,6 +63,28 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   }];
 }
 
+def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
+  let summary = "Emulate operations on unsupported floats with extf/truncf";
+  let description = [{
+    Emulate arith and vector floating point operations that use float types
+    which are unsupported on a target by inserting extf/truncf pairs around all
+    such operations in order to produce arithmetic that can be performed while
+    preserving the original rounding behavior.
+
+    This pass does not attempt to reason about the operations being performed
+    to determine when type conversions can be elided.
+
+    Types are given by name; the recognized names are `f8E5M2`, `f8E4M3FN`,
+    `f8E5M2FNUZ`, `f8E4M3FNUZ`, `bf16`, `f16`, `f32`, `f64`, `f80`, and
+    `f128`. For example:
+
+    ```
+    --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32"
+    ```
+  }];
+
+  let options = [
+    ListOption<"sourceTypeStrs", "source-types", "std::string",
+      "MLIR types without arithmetic support on a given target">,
+    Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
+      "MLIR type to convert the unsupported source types to">,
+  ];
+
+  // The pass creates arith.extf/arith.truncf even when the only converted ops
+  // are vector ops, so the arith dialect must be loaded up front.
+  let dependentDialects = ["arith::ArithDialect", "vector::VectorDialect"];
+}
+
 def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let summary = "Emulate 2*N-bit integer operations using N-bit operations";
   let description = [{

diff  --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index b969389f223995..a9b86b4d99256c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArithTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  EmulateUnsupportedFloats.cpp
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
   ExpandOps.cpp

diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
new file mode 100644
index 00000000000000..e3cfe9813171bc
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -0,0 +1,184 @@
+//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This pass promotes small floats (of some unsupported types T) to a supported
+// type U by wrapping all float operations on Ts with expansion to and
+// truncation from U, then operating on U.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <optional>
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+using namespace mlir;
+
+namespace {
+/// Pass driver: turns the `source-types`/`target-type` options into float
+/// types and runs the float emulation conversion on the current operation.
+struct EmulateUnsupportedFloatsPass
+    : arith::impl::ArithEmulateUnsupportedFloatsBase<
+          EmulateUnsupportedFloatsPass> {
+  // Reuse the generated base class's constructors.
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+
+/// Catch-all conversion pattern: recreates an operation that the type
+/// converter deems illegal on the converted types and truncates any result
+/// whose type changed back to its original type.
+struct EmulateFloatPattern final : ConversionPattern {
+  EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(),
+                          /*benefit=*/1, ctx) {}
+
+  LogicalResult match(Operation *op) const override;
+  void rewrite(Operation *op, ArrayRef<Value> operands,
+               ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+/// Translate a float type name such as "bf16" or "f8E4M3FNUZ" into the
+/// corresponding builtin FloatType, or std::nullopt for an unrecognized name.
+/// Kept local because nothing else needs it yet; feel free to move it
+/// somewhere shared.
+static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
+                                               StringRef name) {
+  using MaybeFloat = std::optional<FloatType>;
+  Builder builder(ctx);
+  MaybeFloat parsed = llvm::StringSwitch<MaybeFloat>(name)
+                          .Case("f8E5M2", builder.getFloat8E5M2Type())
+                          .Case("f8E4M3FN", builder.getFloat8E4M3FNType())
+                          .Case("f8E5M2FNUZ", builder.getFloat8E5M2FNUZType())
+                          .Case("f8E4M3FNUZ", builder.getFloat8E4M3FNUZType())
+                          .Case("bf16", builder.getBF16Type())
+                          .Case("f16", builder.getF16Type())
+                          .Case("f32", builder.getF32Type())
+                          .Case("f64", builder.getF64Type())
+                          .Case("f80", builder.getF80Type())
+                          .Case("f128", builder.getF128Type())
+                          .Default(std::nullopt);
+  return parsed;
+}
+
+/// Match any operation that the type converter reports as illegal. Operations
+/// with regions, and operations whose result types cannot be converted, are
+/// rejected here because `rewrite` can handle neither.
+LogicalResult EmulateFloatPattern::match(Operation *op) const {
+  TypeConverter *converter = getTypeConverter();
+  if (converter->isLegal(op))
+    return failure();
+  // The rewrite doesn't handle cloning regions.
+  if (op->getNumRegions() != 0)
+    return failure();
+  // `rewrite` has no way to report failure, so check convertibility up front.
+  SmallVector<Type> convertedResultTypes;
+  if (failed(converter->convertTypes(op->getResultTypes(),
+                                     convertedResultTypes)))
+    return failure();
+  return success();
+}
+
+/// Recreate `op` with the operands supplied by the conversion driver and the
+/// converted result types, then truncate each result whose type changed back
+/// to its original type before replacing `op`.
+void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  TypeConverter *converter = getTypeConverter();
+  SmallVector<Type> resultTypes;
+  // Do not put this call inside assert(): under NDEBUG the assert argument is
+  // not evaluated, which would leave `resultTypes` empty and build an op with
+  // no results. `match` has already rejected ops whose types can't convert.
+  if (failed(converter->convertTypes(op->getResultTypes(), resultTypes)))
+    llvm_unreachable("type conversions shouldn't fail in this pass");
+  Operation *expandedOp =
+      rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
+                      op->getAttrs(), op->getSuccessors(), /*regions=*/{});
+  SmallVector<Value> newResults(expandedOp->getResults());
+  // Truncate converted results so the replacement values have exactly the
+  // types of the results they replace.
+  for (auto [res, oldType, newType] : llvm::zip_equal(
+           MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
+    if (oldType != newType)
+      res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
+  }
+  rewriter.replaceOp(op, newResults);
+}
+
+/// Register conversions that map each type in `sourceTypes` (and each shaped
+/// type whose element type is in `sourceTypes`) to `targetType`, leaving every
+/// other type unchanged. Values that need converting are materialized with
+/// `arith.extf`.
+void mlir::arith::populateEmulateUnsupportedFloatsConversions(
+    TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
+  // Copy `sourceTypes` into the closure: the converter outlives this call.
+  converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
+                           targetType](Type type) -> std::optional<Type> {
+    if (llvm::is_contained(sourceTypes, type))
+      return targetType;
+    if (auto shaped = type.dyn_cast<ShapedType>())
+      if (llvm::is_contained(sourceTypes, shaped.getElementType()))
+        return shaped.clone(targetType);
+    // All other types legal
+    return type;
+  });
+  converter.addTargetMaterialization(
+      [](OpBuilder &b, Type target, ValueRange input,
+         Location loc) -> std::optional<Value> {
+        // arith.extf takes exactly one operand; decline anything else so
+        // another materialization (or the driver) can handle it.
+        if (input.size() != 1)
+          return std::nullopt;
+        return b.create<arith::ExtFOp>(loc, target, input.front()).getResult();
+      });
+}
+
+/// Register the catch-all float emulation pattern, driven by `converter`.
+void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
+    RewritePatternSet &patterns, TypeConverter &converter) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<EmulateFloatPattern>(converter, context);
+}
+
+/// Make arith ops, and the vector ops that perform arithmetic, legal only when
+/// `converter` considers them legal. Every other op stays legal, as do the
+/// ops that only create, reinterpret, or move values of the unsupported types.
+/// The legality callbacks hold a reference to `converter`, so it must outlive
+/// `target`.
+void mlir::arith::populateEmulateUnsupportedFloatsLegality(
+    ConversionTarget &target, TypeConverter &converter) {
+  // Don't try to legalize functions and other ops that don't need expansion.
+  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+  target.addDynamicallyLegalDialect<arith::ArithDialect>(
+      [&converter](Operation *op) -> std::optional<bool> {
+        return converter.isLegal(op);
+      });
+  // Manually mark arithmetic-performing vector instructions.
+  target.addDynamicallyLegalOp<
+      vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
+      vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
+      [&converter](Operation *op) { return converter.isLegal(op); });
+  // arith.bitcast requires equal operand/result bit widths, so rebuilding it
+  // on the extended type would produce invalid IR; it must stay as-is.
+  target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
+                    arith::ConstantOp, vector::SplatOp>();
+}
+
+/// Pass entry point: parse the `target-type` and `source-types` options into
+/// float types, reject invalid configurations, and run a partial dialect
+/// conversion that emulates operations on the source types with the target
+/// type.
+void EmulateUnsupportedFloatsPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  Operation *op = getOperation();
+
+  std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
+  if (!maybeTargetType) {
+    emitError(UnknownLoc::get(ctx), "could not map target type '" +
+                                        targetTypeStr +
+                                        "' to a known floating-point type");
+    return signalPassFailure();
+  }
+  Type targetType = *maybeTargetType;
+
+  SmallVector<Type> sourceTypes;
+  for (StringRef sourceTypeStr : sourceTypeStrs) {
+    std::optional<FloatType> maybeSourceType =
+        parseFloatType(ctx, sourceTypeStr);
+    if (!maybeSourceType) {
+      emitError(UnknownLoc::get(ctx), "could not map source type '" +
+                                          sourceTypeStr +
+                                          "' to a known floating-point type");
+      return signalPassFailure();
+    }
+    sourceTypes.push_back(*maybeSourceType);
+  }
+  if (sourceTypes.empty()) {
+    // Use a concrete location: emitOptionalWarning() with std::nullopt
+    // reports nothing, which silently dropped this warning.
+    emitWarning(UnknownLoc::get(ctx),
+                "no source types specified, float emulation will do nothing");
+    // With no source types every type (and so every op) is already legal.
+    return;
+  }
+
+  if (llvm::is_contained(sourceTypes, targetType)) {
+    emitError(UnknownLoc::get(ctx),
+              "target type cannot be an unsupported source type");
+    return signalPassFailure();
+  }
+
+  TypeConverter converter;
+  arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
+                                                     targetType);
+  RewritePatternSet patterns(ctx);
+  arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
+  ConversionTarget target(*ctx);
+  arith::populateEmulateUnsupportedFloatsLegality(target, converter);
+
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    signalPassFailure();
+}

diff  --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
new file mode 100644
index 00000000000000..a69ef131d8d47f
--- /dev/null
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s
+
+// A bf16 addf is computed in f32 and truncated back to bf16; the bf16
+// constant operand is left as-is and extended like any other operand.
+func.func @basic_expansion(%x: bf16) -> bf16 {
+// CHECK-LABEL: @basic_expansion
+// CHECK-SAME: [[X:%.+]]: bf16
+// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
+// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
+// CHECK: return [[Y]]
+  %c = arith.constant 1.0 : bf16
+  %y = arith.addf %x, %c : bf16
+  func.return %y : bf16
+}
+
+// -----
+
+// Every intermediate bf16 result is truncated and then re-extended before its
+// next use, so no extra f32 precision is carried between operations.
+func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
+// CHECK-LABEL: @chained
+// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
+// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
+// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
+// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
+// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
+// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
+// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
+// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
+// CHECK: return [[RES]]
+  %p = arith.addf %x, %y : bf16
+  %q = arith.mulf %p, %z : bf16
+  %res = arith.cmpf ole, %p, %q : bf16
+  func.return %res : i1
+}
+
+// -----
+
+// memref.load/memref.store of f8E4M3FNUZ are not rewritten: the loaded value
+// is stored unchanged, and only the addf is emulated in f32.
+func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
+// CHECK-LABEL: @memops
+// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
+// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
+// CHECK: memref.store [[V]]
+// CHECK: [[W:%.+]] = memref.load
+// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
+// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
+// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
+// CHECK: memref.store [[X]]
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %v = memref.load %a[%c0] : memref<4xf8E4M3FNUZ>
+  memref.store %v, %b[%c0] : memref<4xf8E4M3FNUZ>
+  %w = memref.load %a[%c1] : memref<4xf8E4M3FNUZ>
+  %x = arith.addf %v, %w : f8E4M3FNUZ
+  memref.store %x, %b[%c1] : memref<4xf8E4M3FNUZ>
+  func.return
+}
+
+// -----
+
+// Vector operands are extended elementwise (vector<4xf8E4M3FNUZ> to
+// vector<4xf32>), and the pre-existing arith.extf of the result is kept.
+func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
+// CHECK-LABEL: @vectors
+// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
+// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
+// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: return [[RET]]
+  %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
+  %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
+  func.return %ret : vector<4xf32>
+}
+
+// -----
+
+// f32 is the target type, not a source type, so nothing is rewritten.
+func.func @no_expansion(%x: f32) -> f32 {
+// CHECK-LABEL: @no_expansion
+// CHECK-SAME: [[X:%.+]]: f32
+// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : f32
+// CHECK: [[Y:%.+]] = arith.addf [[X]], [[C]] : f32
+// CHECK: return [[Y]]
+  %c = arith.constant 1.0 : f32
+  %y = arith.addf %x, %c : f32
+  func.return %y : f32
+}


        


More information about the Mlir-commits mailing list