[Mlir-commits] [mlir] d94426d - [mlir] Math: add algebraic simplification patterns to math transforms
Eugene Zhulenev
llvmlistbot at llvm.org
Tue Jul 27 09:22:40 PDT 2021
Author: Eugene Zhulenev
Date: 2021-07-27T09:22:33-07:00
New Revision: d94426d22a25559f25fd86276d7e9aefbd9d05ab
URL: https://github.com/llvm/llvm-project/commit/d94426d22a25559f25fd86276d7e9aefbd9d05ab
DIFF: https://github.com/llvm/llvm-project/commit/d94426d22a25559f25fd86276d7e9aefbd9d05ab.diff
LOG: [mlir] Math: add algebraic simplification patterns to math transforms
Reviewed By: bkramer
Differential Revision: https://reviews.llvm.org/D106822
Added:
mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
mlir/test/Dialect/Math/algebraic-simplification.mlir
mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp
Modified:
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Math/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 10635667a5fcc..4378f177fa0a1 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -15,6 +15,8 @@ class RewritePatternSet;
void populateExpandTanhPattern(RewritePatternSet &patterns);
+/// Adds patterns that rewrite `math` dialect ops using algebraic identities
+/// (e.g. strength reduction of `math.powf` with special constant exponents).
+void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
+
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
new file mode 100644
index 0000000000000..2614fc7cf2f73
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -0,0 +1,112 @@
+//===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites based on the basic rules of algebra
+// (Commutativity, associativity, etc...) and strength reductions for math
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <climits>
+
+using namespace mlir;
+
+//----------------------------------------------------------------------------//
+// PowFOp strength reduction.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Rewrite pattern that replaces `math.powf` whose exponent is a constant
+/// (scalar, or splat vector) equal to one of a small set of special values
+/// with cheaper arithmetic. See `matchAndRewrite` for the exact rewrites.
+struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
+  // Reuse the base-class constructors (context, benefit, ...).
+  using OpRewritePattern<math::PowFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::PowFOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Strength-reduces `math.powf(x, c)` when `c` is a constant scalar, or a
+/// constant splat, equal to one of these special values:
+///
+///   pow(x,  1.0) -> x
+///   pow(x,  2.0) -> x * x
+///   pow(x,  3.0) -> x * (x * x)
+///   pow(x, -1.0) -> 1.0 / x        (scalar and vector results only)
+///   pow(x,  0.5) -> sqrt(x)
+///
+/// Returns failure() without creating any op for every other exponent.
+LogicalResult
+PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
+                                       PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value x = op.lhs();
+
+  FloatAttr scalarExponent;
+  DenseFPElementsAttr vectorExponent;
+
+  bool isScalar = matchPattern(op.rhs(), m_Constant(&scalarExponent));
+  bool isVector = matchPattern(op.rhs(), m_Constant(&vectorExponent));
+
+  // Returns true if exponent is a constant equal to `value`. Constant
+  // elements attributes only match when they are splats.
+  auto isExponentValue = [&](double value) -> bool {
+    if (isScalar)
+      return scalarExponent.getValue().isExactlyValue(value);
+
+    if (isVector && vectorExponent.isSplat())
+      return vectorExponent.getSplatValue<FloatAttr>()
+          .getValue()
+          .isExactlyValue(value);
+
+    return false;
+  };
+
+  // Maybe broadcasts scalar value into vector type compatible with `op`.
+  auto bcast = [&](Value value) -> Value {
+    if (auto vec = op.getType().dyn_cast<VectorType>())
+      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+    return value;
+  };
+
+  // Replace `pow(x, 1.0)` with `x`.
+  if (isExponentValue(1.0)) {
+    rewriter.replaceOp(op, x);
+    return success();
+  }
+
+  // Replace `pow(x, 2.0)` with `x * x`.
+  if (isExponentValue(2.0)) {
+    rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, x}));
+    return success();
+  }
+
+  // Replace `pow(x, 3.0)` with `x * x * x`.
+  if (isExponentValue(3.0)) {
+    Value square = rewriter.create<MulFOp>(loc, ValueRange({x, x}));
+    rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square}));
+    return success();
+  }
+
+  // Replace `pow(x, -1.0)` with `1.0 / x`.
+  if (isExponentValue(-1.0)) {
+    // `bcast` can only splat the scalar `1.0` into a vector. For any other
+    // shaped result type (e.g. a tensor) the `divf` operands would have
+    // mismatched types, so bail out before creating any op.
+    Type type = op.getType();
+    if (type.isa<ShapedType>() && !type.isa<VectorType>())
+      return failure();
+
+    Value one = rewriter.create<ConstantOp>(
+        loc, rewriter.getFloatAttr(getElementTypeOrSelf(type), 1.0));
+    rewriter.replaceOpWithNewOp<DivFOp>(op, ValueRange({bcast(one), x}));
+    return success();
+  }
+
+  // Replace `pow(x, 0.5)` with `sqrt(x)`.
+  // NOTE(review): per IEEE-754 / C Annex F, pow(-0.0, 0.5) is +0.0 while
+  // sqrt(-0.0) is -0.0, and pow(-inf, 0.5) is +inf while sqrt(-inf) is NaN.
+  // Gate this on fast-math flags if those cases must be preserved.
+  if (isExponentValue(0.5)) {
+    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
+    return success();
+  }
+
+  return failure();
+}
+
+//----------------------------------------------------------------------------//
+
+/// Registers every algebraic simplification pattern defined in this file
+/// (currently only `PowFStrengthReduction`) into `patterns`.
+void mlir::populateMathAlgebraicSimplificationPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<PowFStrengthReduction>(ctx);
+}
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 1eea7f3b61562..6eece6e0f7b3b 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMathTransforms
+ AlgebraicSimplification.cpp
ExpandTanh.cpp
PolynomialApproximation.cpp
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
new file mode 100644
index 0000000000000..cb39bb7cd7f56
--- /dev/null
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s --dump-input=always
+
+// Exponent 1.0 (scalar and splat vector): each powf is replaced by its base.
+// CHECK-LABEL: @pow_noop
+func @pow_noop(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK: return %arg0, %arg1
+ %c = constant 1.0 : f32
+ %v = constant dense <1.0> : vector<4xf32>
+ %0 = math.powf %arg0, %c : f32
+ %1 = math.powf %arg1, %v : vector<4xf32>
+ return %0, %1 : f32, vector<4xf32>
+}
+
+// Exponent 2.0: each powf becomes a single mulf of the base with itself.
+// CHECK-LABEL: @pow_square
+func @pow_square(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK: %[[SCALAR:.*]] = mulf %arg0, %arg0
+ // CHECK: %[[VECTOR:.*]] = mulf %arg1, %arg1
+ // CHECK: return %[[SCALAR]], %[[VECTOR]]
+ %c = constant 2.0 : f32
+ %v = constant dense <2.0> : vector<4xf32>
+ %0 = math.powf %arg0, %c : f32
+ %1 = math.powf %arg1, %v : vector<4xf32>
+ return %0, %1 : f32, vector<4xf32>
+}
+
+// Exponent 3.0: each powf becomes x * (x * x), i.e. two mulf ops.
+// CHECK-LABEL: @pow_cube
+func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK: %[[TMP_S:.*]] = mulf %arg0, %arg0
+ // CHECK: %[[SCALAR:.*]] = mulf %arg0, %[[TMP_S]]
+ // CHECK: %[[TMP_V:.*]] = mulf %arg1, %arg1
+ // CHECK: %[[VECTOR:.*]] = mulf %arg1, %[[TMP_V]]
+ // CHECK: return %[[SCALAR]], %[[VECTOR]]
+ %c = constant 3.0 : f32
+ %v = constant dense <3.0> : vector<4xf32>
+ %0 = math.powf %arg0, %c : f32
+ %1 = math.powf %arg1, %v : vector<4xf32>
+ return %0, %1 : f32, vector<4xf32>
+}
+
+// Exponent -1.0: each powf becomes divf(1.0, x); in the vector case the
+// numerator is expected as a splat dense constant.
+// CHECK-LABEL: @pow_recip
+func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK: %[[CST_S:.*]] = constant 1.0{{.*}} : f32
+ // CHECK: %[[CST_V:.*]] = constant dense<1.0{{.*}}> : vector<4xf32>
+ // CHECK: %[[SCALAR:.*]] = divf %[[CST_S]], %arg0
+ // CHECK: %[[VECTOR:.*]] = divf %[[CST_V]], %arg1
+ // CHECK: return %[[SCALAR]], %[[VECTOR]]
+ %c = constant -1.0 : f32
+ %v = constant dense <-1.0> : vector<4xf32>
+ %0 = math.powf %arg0, %c : f32
+ %1 = math.powf %arg1, %v : vector<4xf32>
+ return %0, %1 : f32, vector<4xf32>
+}
diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
index 600183145833d..64cae2f77c5e3 100644
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMathTestPasses
+ TestAlgebraicSimplification.cpp
TestExpandTanh.cpp
TestPolynomialApproximation.cpp
diff --git a/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp
new file mode 100644
index 0000000000000..b73f1b2b9483f
--- /dev/null
+++ b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp
@@ -0,0 +1,50 @@
+//===- TestAlgebraicSimplification.cpp - Test algebraic simplification ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for algebraic simplification patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that greedily applies the math algebraic simplification
+/// patterns to each function it runs on.
+struct TestMathAlgebraicSimplificationPass
+    : public PassWrapper<TestMathAlgebraicSimplificationPass, FunctionPass> {
+  void runOnFunction() override;
+
+  // The rewrites can create `vector` ops (e.g. vector.broadcast) as well as
+  // `math` ops, so both dialects are declared as dependencies.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect, math::MathDialect>();
+  }
+
+  /// Command-line flag used to invoke this pass from mlir-opt.
+  StringRef getArgument() const final {
+    return "test-math-algebraic-simplification";
+  }
+  StringRef getDescription() const final {
+    return "Test math algebraic simplification";
+  }
+};
+} // end anonymous namespace
+
+/// Collects the algebraic simplification patterns and applies them, together
+/// with folding, to the current function.
+void TestMathAlgebraicSimplificationPass::runOnFunction() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  populateMathAlgebraicSimplificationPatterns(patterns);
+  // The LogicalResult from the greedy driver is deliberately discarded.
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+namespace test {
+/// Registers the `test-math-algebraic-simplification` pass via
+/// PassRegistration so mlir-opt can invoke it.
+void registerTestMathAlgebraicSimplificationPass() {
+  PassRegistration<TestMathAlgebraicSimplificationPass> registration;
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index c30575b39fbe8..c5be1daec8304 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -92,6 +92,7 @@ void registerTestLivenessPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
+void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
@@ -173,6 +174,7 @@ void registerTestPasses() {
test::registerTestLoopFusion();
test::registerTestLoopMappingPass();
test::registerTestLoopUnrollingPass();
+ test::registerTestMathAlgebraicSimplificationPass();
test::registerTestMathPolynomialApproximationPass();
test::registerTestMemRefDependenceCheck();
test::registerTestMemRefStrideCalculation();
More information about the Mlir-commits
mailing list