[Mlir-commits] [mlir] [mlir][Math] Add pass to legalize math functions to f32-or-higher (PR #78361)
Krzysztof Drewniak
llvmlistbot at llvm.org
Wed Jan 17 11:56:47 PST 2024
https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/78361
>From 52703fb5d38dcf3fa6ef7a16e43d9d8520f6d699 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Tue, 16 Jan 2024 22:48:50 +0000
Subject: [PATCH 1/3] [mlir][Math] Add pass to legalize math functions to
f32-or-higher
Since most of the operations in the `math` dialect don't have
low-precision implementations, add the -math-legalize-to-f32 pass that
goes through and brackets low-precision math functions (like
`math.sin %0 : f16`) with `arith.extf` and `arith.truncf`. This
preserves the original semantics of the math operation but allows
lowering to proceed.
Versions of this lowering are already implicitly present in some
passes, like ConvertGPUToROCDL. However, because those are implicit
rewrites, they hide the floating-point extension and truncation,
preventing anyone from writing passes that operate on those implicit
extf/truncf pairs.
Exposing this legalization explicitly is needed to allow lowering
8-bit floats on AMD GPUs, as the implementation of extf and truncf on
that platform requires the complex logic found in ArithToAMDGPU, which
runs before the GPU to ROCDL lowering.
---
.../mlir/Dialect/Math/Transforms/Passes.h | 10 ++
.../mlir/Dialect/Math/Transforms/Passes.td | 17 +++
.../Dialect/Math/Transforms/CMakeLists.txt | 1 +
.../Dialect/Math/Transforms/LegalizeToF32.cpp | 118 ++++++++++++++++++
mlir/test/Dialect/Math/legalize-to-f32.mlir | 83 ++++++++++++
5 files changed, 229 insertions(+)
create mode 100644 mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
create mode 100644 mlir/test/Dialect/Math/legalize-to-f32.mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 9e6759ef229d6f4..010dde5ea73847d 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -16,12 +16,15 @@ namespace math {
#define GEN_PASS_DECL
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
+#define GEN_PASS_DECL_MATHLEGALIZETOF32
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace math
+class ConversionTarget;
class RewritePatternSet;
+class TypeConverter;
void populateExpandCtlzPattern(RewritePatternSet &patterns);
void populateExpandTanPattern(RewritePatternSet &patterns);
@@ -48,6 +51,13 @@ void populateMathPolynomialApproximationPatterns(
void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
+namespace math {
+/// Adds conversions to `typeConverter` that widen floating-point types
+/// narrower than f32 (including as the element type of shaped types) to f32,
+/// plus an `arith.extf` target materialization for converted operands.
+void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
+/// Marks math-dialect ops legal only when `typeConverter` accepts their types;
+/// `math.fma`, `arith.extf` and `arith.truncf` are always legal.
+/// `typeConverter` is captured by reference and must outlive `target`.
+void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
+ TypeConverter &typeConverter);
+/// Adds the pattern that rebuilds ops rejected by `typeConverter` on the
+/// converted types and truncates changed results back to their original types.
+void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
+ TypeConverter &typeConverter);
+} // namespace math
} // namespace mlir
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index d81a92b0371e319..e870e714bfda588 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -19,4 +19,21 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
let dependentDialects = ["math::MathDialect"];
}
+// Pass driver for math::populateLegalizeToF32{TypeConverter,ConversionTarget,
+// Patterns}; the generated C++ base class is MathLegalizeToF32Base.
+def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
+ let summary = "Legalize floating-point math ops on low-precision floats";
+ let description = [{
+ On many targets, the math functions are not implemented for floating-point
+ types less precise than IEEE single-precision (aka f32), such as half-floats,
+ bfloat16, or 8-bit floats.
+
+ This pass explicitly legalizes these math functions by inserting
+ `arith.extf` and `arith.truncf` pairs around said op, which preserves
+ the original semantics while enabling lowering.
+
+ As an exception, this pass does not legalize `math.fma`, because
+ that is an operation frequently implemented at low precisions.
+ }];
+ let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2d446b453edc914..2a5b4fbcb52712e 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
+ LegalizeToF32.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
new file mode 100644
index 000000000000000..d281790e877152b
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -0,0 +1,118 @@
+//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements legalizing math operations on small floating-point
+// types through arith.extf and arith.truncf.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHLEGALIZETOF32
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+namespace {
+/// Generic conversion pattern: re-creates an op whose types the type converter
+/// rejects using the converted (f32) types, then truncates every result whose
+/// type changed back to its original type.
+class LegalizeToF32RewritePattern final : public ConversionPattern {
+public:
+  LegalizeToF32RewritePattern(TypeConverter &typeConverter, MLIRContext *ctx)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, /*benefit=*/1,
+                          ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Driver for the `-math-legalize-to-f32` pass.
+class LegalizeToF32Pass final
+    : public mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Populates `typeConverter` so that scalar floats narrower than 32 bits, and
+/// shaped types (e.g. vectors) whose element type is such a float, convert to
+/// their f32 counterparts. Every other type converts to itself. Converted
+/// operands are materialized by widening the original value with `arith.extf`.
+void mlir::math::populateLegalizeToF32TypeConverter(
+    TypeConverter &typeConverter) {
+  // Fallback: any type not claimed by a more specific conversion below is
+  // already legal. (The most recently added conversions are tried first.)
+  typeConverter.addConversion(
+      [](Type type) -> std::optional<Type> { return type; });
+  typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
+    if (type.getWidth() < 32)
+      return Float32Type::get(type.getContext());
+    // f32 and wider are left alone; defer to the identity conversion.
+    return std::nullopt;
+  });
+  typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
+    // Only widen element types narrower than f32. Without the width check,
+    // e.g. vector<2xf64> would be "converted" to vector<2xf32>, yielding an
+    // invalid narrowing arith.extf and widening arith.truncf around the op.
+    auto elemTy = dyn_cast<FloatType>(type.getElementType());
+    if (elemTy && elemTy.getWidth() < 32)
+      return type.clone(Float32Type::get(type.getContext()));
+    return std::nullopt;
+  });
+  // Operands whose type was converted are produced by extending the original
+  // (narrow) value to the converted type.
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
+        return b.create<arith::ExtFOp>(loc, target, input);
+      });
+}
+
+/// Makes math-dialect ops legal only when every operand/result type is
+/// accepted by `typeConverter`. `math.fma` is always legal, as are the
+/// `arith.extf` / `arith.truncf` ops the legalization pattern inserts.
+/// `typeConverter` is captured by reference and must outlive `target`.
+void mlir::math::populateLegalizeToF32ConversionTarget(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  auto hasLegalTypes = [&typeConverter](Operation *op) -> bool {
+    return typeConverter.isLegal(op);
+  };
+  target.addDynamicallyLegalDialect<MathDialect>(hasLegalTypes);
+  // fma is frequently implemented at low precision, so it is never widened.
+  target.addLegalOp<FmaOp, arith::ExtFOp, arith::TruncFOp>();
+}
+
+/// Rebuilds `op` on the type-converted `operands`, with result types converted
+/// by the type converter, then truncates each result whose type changed back
+/// to its original type so that existing users see unchanged types.
+LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  const TypeConverter *typeConverter = getTypeConverter();
+  Location loc = op->getLoc();
+  if (typeConverter->isLegal(op))
+    return rewriter.notifyMatchFailure(loc, "op already legal");
+
+  SmallVector<Type> wideResultTypes;
+  if (failed(typeConverter->convertTypes(op->getResultTypes(),
+                                         wideResultTypes)))
+    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+
+  // Same op name and attributes (e.g. fastmath flags), converted types.
+  OperationState state(loc, op->getName());
+  state.addOperands(operands);
+  state.addTypes(wideResultTypes);
+  state.addAttributes(op->getAttrs());
+  Operation *wideOp = rewriter.create(state);
+
+  SmallVector<Value> replacements;
+  replacements.reserve(wideOp->getNumResults());
+  for (auto [wideResult, wideType, origType] : llvm::zip_equal(
+           wideOp->getResults(), wideResultTypes, op->getResultTypes())) {
+    if (wideType == origType) {
+      replacements.push_back(wideResult);
+      continue;
+    }
+    // Narrow the widened result back to what the op's users expect.
+    replacements.push_back(
+        rewriter.create<arith::TruncFOp>(loc, origType, wideResult)
+            .getResult());
+  }
+  rewriter.replaceOp(op, replacements);
+  return success();
+}
+
+/// Registers the generic widening pattern, driven by `typeConverter`, in
+/// `patterns`.
+void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
+                                               TypeConverter &typeConverter) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<LegalizeToF32RewritePattern>(typeConverter, ctx);
+}
+
+/// Runs the f32 legalization over the pass's anchor op as a partial
+/// conversion, signalling pass failure if the conversion fails.
+void LegalizeToF32Pass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+
+  TypeConverter typeConverter;
+  math::populateLegalizeToF32TypeConverter(typeConverter);
+
+  ConversionTarget target(*ctx);
+  math::populateLegalizeToF32ConversionTarget(target, typeConverter);
+
+  RewritePatternSet patterns(ctx);
+  math::populateLegalizeToF32Patterns(patterns, typeConverter);
+
+  if (failed(applyPartialConversion(getOperation(), target,
+                                    std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
new file mode 100644
index 000000000000000..3f648c9379955b9
--- /dev/null
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s
+
+// Scalar f16 op: the operand is widened with extf, the op runs on f32, and
+// the result is narrowed back with truncf.
+// CHECK-LABEL: @sin
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+func.func @sin(%arg0: f16) -> f16 {
+ // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+ // CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+ // CHECK: return [[TRUNCF]] : f16
+ %0 = math.sin %arg0 : f16
+ return %0 : f16
+}
+
+// Only the floating-point operand is widened; the i32 exponent passes
+// through unchanged.
+// CHECK-LABEL: @fpowi
+// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: i32)
+func.func @fpowi(%arg0: f16, %arg1: i32) -> f16 {
+ // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+ // CHECK: [[FPOWI:%.+]] = math.fpowi [[EXTF]], [[ARG1]]
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf [[FPOWI]]
+ // CHECK: return [[TRUNCF]] : f16
+ %0 = math.fpowi %arg0, %arg1 : f16, i32
+ return %0 : f16
+}
+
+// math.fma is always legal, so it stays on f16.
+// CHECK-LABEL: @fma
+// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: f16, [[ARG2:%.+]]: f16)
+// CHECK: [[FMA:%.+]] = math.fma [[ARG0]], [[ARG1]], [[ARG2]]
+// CHECK: return [[FMA]] : f16
+func.func @fma(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
+ %0 = math.fma %arg0, %arg1, %arg2 : f16
+ return %0 : f16
+}
+
+// f32 ops are already legal and are left untouched.
+// CHECK-LABEL: @absf_f32
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f32
+func.func @absf_f32(%arg0: f32) -> f32 {
+ %0 = math.absf %arg0 : f32
+ return %0 : f32
+}
+
+// Types wider than f32 are also left untouched (never narrowed).
+// CHECK-LABEL: @absf_f64
+// CHECK-SAME: ([[ARG0:%.+]]: f64)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f64
+func.func @absf_f64(%arg0: f64) -> f64 {
+ %0 = math.absf %arg0 : f64
+ return %0 : f64
+}
+
+// Vector operands with a low-precision element type get the same
+// extf / op / truncf bracketing.
+// CHECK-LABEL: @sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<2xbf16>
+func.func @sin_vector(%arg0: vector<2xbf16>) -> vector<2xbf16> {
+ %0 = math.sin %arg0 : vector<2xbf16>
+ return %0 : vector<2xbf16>
+}
+
+// Attributes such as fastmath flags are carried over to the widened op.
+// CHECK-LABEL: @fastmath
+// CHECK: math.sin %{{.+}} fastmath<nsz>
+func.func @fastmath(%arg0: f16) -> f16 {
+ %0 = math.sin %arg0 fastmath<nsz> : f16
+ return %0 : f16
+}
+
+// Each op in a chain is bracketed independently; this pass does not fold
+// the intermediate truncf/extf pair away.
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF1]] : f16
+func.func @sequences(%arg0: f16) -> f16 {
+ %0 = math.absf %arg0 : f16
+ %1 = math.sin %0 : f16
+ return %1 : f16
+}
>From c807c262cce2ef5ceb15d9d9d16d26b8894c7906 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 17 Jan 2024 18:09:48 +0000
Subject: [PATCH 2/3] Mutate the operation in place instead of futzing with
OperationState
---
.../Dialect/Math/Transforms/LegalizeToF32.cpp | 33 +++++++++++--------
1 file changed, 20 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152b..ba9759cc4a53ae9 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -78,22 +78,29 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
const TypeConverter *converter = getTypeConverter();
if (converter->isLegal(op))
return rewriter.notifyMatchFailure(loc, "op already legal");
- OperationState newOp(loc, op->getName());
- newOp.addOperands(operands);
-
SmallVector<Type> newResultTypes;
if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
- newOp.addTypes(newResultTypes);
- newOp.addAttributes(op->getAttrs());
- Operation *legalized = rewriter.create(newOp);
- SmallVector<Value> results = legalized->getResults();
- for (auto [result, newType, origType] :
- llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
- if (newType != origType)
- result = rewriter.create<arith::TruncFOp>(loc, origType, result);
- }
- rewriter.replaceOp(op, results);
+
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ // Truncations will be created that need to come after the math op, not
+ // before.
+ rewriter.setInsertionPointAfter(op);
+ rewriter.updateRootInPlace(op, [&]() {
+ op->setOperands(operands);
+ for (auto [result, newType] :
+ llvm::zip_equal(op->getResults(), newResultTypes)) {
+ Type oldType = result.getType();
+ if (oldType == newType)
+ continue;
+ result.setType(newType);
+ auto truncOp = rewriter.create<arith::TruncFOp>(loc, oldType, result);
+ // Intintionally don't tell the rewriter we're doing this to prevent
+ // spurious attempts to legalize the consumer, which can lead to things
+ // like running `func.return` through the pattern.
+ result.replaceAllUsesExcept(truncOp.getResult(), truncOp);
+ }
+ });
return success();
}
>From b5ab5063228de1051c4205afd9236a71c6b98f1e Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 17 Jan 2024 19:56:09 +0000
Subject: [PATCH 3/3] Revert "Mutate the operation in place instead of futzing
with OperationState"
This reverts commit c807c262cce2ef5ceb15d9d9d16d26b8894c7906.
It turns out we should keep the original approach (building a replacement op rather than mutating the op in place).
---
.../Dialect/Math/Transforms/LegalizeToF32.cpp | 33 ++++++++-----------
1 file changed, 13 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index ba9759cc4a53ae9..d281790e877152b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -78,29 +78,22 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
const TypeConverter *converter = getTypeConverter();
if (converter->isLegal(op))
return rewriter.notifyMatchFailure(loc, "op already legal");
+ OperationState newOp(loc, op->getName());
+ newOp.addOperands(operands);
+
SmallVector<Type> newResultTypes;
if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
-
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- // Truncations will be created that need to come after the math op, not
- // before.
- rewriter.setInsertionPointAfter(op);
- rewriter.updateRootInPlace(op, [&]() {
- op->setOperands(operands);
- for (auto [result, newType] :
- llvm::zip_equal(op->getResults(), newResultTypes)) {
- Type oldType = result.getType();
- if (oldType == newType)
- continue;
- result.setType(newType);
- auto truncOp = rewriter.create<arith::TruncFOp>(loc, oldType, result);
- // Intintionally don't tell the rewriter we're doing this to prevent
- // spurious attempts to legalize the consumer, which can lead to things
- // like running `func.return` through the pattern.
- result.replaceAllUsesExcept(truncOp.getResult(), truncOp);
- }
- });
+ newOp.addTypes(newResultTypes);
+ newOp.addAttributes(op->getAttrs());
+ Operation *legalized = rewriter.create(newOp);
+ SmallVector<Value> results = legalized->getResults();
+ for (auto [result, newType, origType] :
+ llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+ if (newType != origType)
+ result = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ }
+ rewriter.replaceOp(op, results);
return success();
}
More information about the Mlir-commits
mailing list