[Mlir-commits] [mlir] d94426d - [mlir] Math: add algebraic simplification patterns to math transforms

Eugene Zhulenev llvmlistbot at llvm.org
Tue Jul 27 09:22:40 PDT 2021


Author: Eugene Zhulenev
Date: 2021-07-27T09:22:33-07:00
New Revision: d94426d22a25559f25fd86276d7e9aefbd9d05ab

URL: https://github.com/llvm/llvm-project/commit/d94426d22a25559f25fd86276d7e9aefbd9d05ab
DIFF: https://github.com/llvm/llvm-project/commit/d94426d22a25559f25fd86276d7e9aefbd9d05ab.diff

LOG: [mlir] Math: add algebraic simplification patterns to math transforms

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D106822

Added: 
    mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
    mlir/test/Dialect/Math/algebraic-simplification.mlir
    mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp

Modified: 
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Math/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 10635667a5fcc..4378f177fa0a1 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -15,6 +15,8 @@ class RewritePatternSet;
 
 void populateExpandTanhPattern(RewritePatternSet &patterns);
 
+/// Adds rewrite patterns to `patterns` that simplify math dialect operations
+/// using algebraic identities and strength reduction (for example, rewriting
+/// `math.powf` with a small constant exponent into multiplications).
+void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
+
 void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns);
 
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
new file mode 100644
index 0000000000000..2614fc7cf2f73
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -0,0 +1,112 @@
+//===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites based on the basic rules of algebra
+// (Commutativity, associativity, etc...) and strength reductions for math
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <climits>
+
+using namespace mlir;
+
+//----------------------------------------------------------------------------//
+// PowFOp strength reduction.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Strength-reduces `math.powf` whose exponent is a known constant into
+/// cheaper operations; `matchAndRewrite` lists the supported exponents.
+struct PowFStrengthReduction : OpRewritePattern<math::PowFOp> {
+  // Reuse the base-class constructors (context, benefit, ...).
+  using OpRewritePattern<math::PowFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::PowFOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Rewrites `math.powf(x, c)` into cheaper operations when `c` is a constant
+/// (a scalar float constant, or a splat dense constant) equal to 1.0, 2.0,
+/// 3.0, -1.0 or 0.5. Any other exponent leaves the op untouched and the
+/// pattern reports failure.
+LogicalResult
+PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
+                                       PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value x = op.lhs();
+
+  FloatAttr scalarExponent;
+  DenseFPElementsAttr vectorExponent;
+
+  bool isScalar = matchPattern(op.rhs(), m_Constant(&scalarExponent));
+  bool isVector = matchPattern(op.rhs(), m_Constant(&vectorExponent));
+
+  // Returns true if the exponent is a constant equal to `value`. Dense
+  // constants only qualify when all elements are identical (a splat).
+  auto isExponentValue = [&](double value) -> bool {
+    if (isScalar)
+      return scalarExponent.getValue().isExactlyValue(value);
+
+    if (isVector && vectorExponent.isSplat())
+      return vectorExponent.getSplatValue<FloatAttr>()
+          .getValue()
+          .isExactlyValue(value);
+
+    return false;
+  };
+
+  // Maybe broadcasts scalar value into vector type compatible with `op`.
+  auto bcast = [&](Value value) -> Value {
+    if (auto vec = op.getType().dyn_cast<VectorType>())
+      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+    return value;
+  };
+
+  // Replace `pow(x, 1.0)` with `x`.
+  if (isExponentValue(1.0)) {
+    rewriter.replaceOp(op, x);
+    return success();
+  }
+
+  // Replace `pow(x, 2.0)` with `x * x`.
+  if (isExponentValue(2.0)) {
+    rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, x}));
+    return success();
+  }
+
+  // Replace `pow(x, 3.0)` with `x * x * x`.
+  if (isExponentValue(3.0)) {
+    Value square = rewriter.create<MulFOp>(loc, ValueRange({x, x}));
+    rewriter.replaceOpWithNewOp<MulFOp>(op, ValueRange({x, square}));
+    return success();
+  }
+
+  // Replace `pow(x, -1.0)` with `1.0 / x`.
+  if (isExponentValue(-1.0)) {
+    // `bcast` can only splat the scalar 1.0 into a vector. For tensor-typed
+    // ops the `divf` operands would have mismatched types, so bail out.
+    if (op.getType().isa<TensorType>())
+      return failure();
+    Value one = rewriter.create<ConstantOp>(
+        loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
+    rewriter.replaceOpWithNewOp<DivFOp>(op, ValueRange({bcast(one), x}));
+    return success();
+  }
+
+  // Replace `pow(x, 0.5)` with `sqrt(x)`.
+  // NOTE: not bit-exact for every input under C/IEEE-754 `pow` semantics:
+  // pow(-0.0, 0.5) is +0.0 while sqrt(-0.0) is -0.0, and pow(-inf, 0.5) is
+  // +inf while sqrt(-inf) is NaN.
+  if (isExponentValue(0.5)) {
+    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
+    return success();
+  }
+
+  return failure();
+}
+
+//----------------------------------------------------------------------------//
+
+/// Registers the math algebraic simplification rewrites (currently only the
+/// `math.powf` strength reduction) with `patterns`.
+void mlir::populateMathAlgebraicSimplificationPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<PowFStrengthReduction>(context);
+}

diff  --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 1eea7f3b61562..6eece6e0f7b3b 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMathTransforms
+  AlgebraicSimplification.cpp
   ExpandTanh.cpp
   PolynomialApproximation.cpp
 

diff  --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
new file mode 100644
index 0000000000000..cb39bb7cd7f56
--- /dev/null
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s --dump-input=always
+
+// CHECK-LABEL: @pow_noop
+// An exponent of 1.0 (scalar or splat vector) yields the base unchanged.
+func @pow_noop(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: return %arg0, %arg1
+  %exp_s = constant 1.0 : f32
+  %exp_v = constant dense<1.0> : vector<4xf32>
+  %res_s = math.powf %arg0, %exp_s : f32
+  %res_v = math.powf %arg1, %exp_v : vector<4xf32>
+  return %res_s, %res_v : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_square
+// An exponent of 2.0 becomes a single self-multiplication.
+func @pow_square(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = mulf %arg0, %arg0
+  // CHECK: %[[VECTOR:.*]] = mulf %arg1, %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %exp_s = constant 2.0 : f32
+  %exp_v = constant dense<2.0> : vector<4xf32>
+  %res_s = math.powf %arg0, %exp_s : f32
+  %res_v = math.powf %arg1, %exp_v : vector<4xf32>
+  return %res_s, %res_v : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_cube
+// An exponent of 3.0 becomes a square followed by one more multiplication.
+func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[TMP_S:.*]] = mulf %arg0, %arg0
+  // CHECK: %[[SCALAR:.*]] = mulf %arg0, %[[TMP_S]]
+  // CHECK: %[[TMP_V:.*]] = mulf %arg1, %arg1
+  // CHECK: %[[VECTOR:.*]] = mulf %arg1, %[[TMP_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %exp_s = constant 3.0 : f32
+  %exp_v = constant dense<3.0> : vector<4xf32>
+  %res_s = math.powf %arg0, %exp_s : f32
+  %res_v = math.powf %arg1, %exp_v : vector<4xf32>
+  return %res_s, %res_v : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_recip
+// An exponent of -1.0 becomes a division of a (broadcast) 1.0 constant.
+func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[CST_S:.*]] = constant 1.0{{.*}} : f32
+  // CHECK: %[[CST_V:.*]] = constant dense<1.0{{.*}}> : vector<4xf32>
+  // CHECK: %[[SCALAR:.*]] = divf %[[CST_S]], %arg0
+  // CHECK: %[[VECTOR:.*]] = divf %[[CST_V]], %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %exp_s = constant -1.0 : f32
+  %exp_v = constant dense<-1.0> : vector<4xf32>
+  %res_s = math.powf %arg0, %exp_s : f32
+  %res_v = math.powf %arg1, %exp_v : vector<4xf32>
+  return %res_s, %res_v : f32, vector<4xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
index 600183145833d..64cae2f77c5e3 100644
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMathTestPasses
+  TestAlgebraicSimplification.cpp
   TestExpandTanh.cpp
   TestPolynomialApproximation.cpp
 

diff  --git a/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp
new file mode 100644
index 0000000000000..b73f1b2b9483f
--- /dev/null
+++ b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp
@@ -0,0 +1,50 @@
+//===- TestAlgebraicSimplification.cpp - Test algebraic simplification ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for algebraic simplification patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Function pass that greedily applies the math algebraic simplification
+/// patterns; exposed on the command line as
+/// `-test-math-algebraic-simplification`.
+struct TestMathAlgebraicSimplificationPass
+    : public PassWrapper<TestMathAlgebraicSimplificationPass, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-math-algebraic-simplification";
+  }
+
+  StringRef getDescription() const final {
+    return "Test math algebraic simplification";
+  }
+
+  // The rewrites can create vector and math dialect ops, so both dialects
+  // are declared as dependencies of this pass.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect, math::MathDialect>();
+  }
+
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+/// Collects the simplification patterns and applies them greedily to the
+/// current function. Whether the driver converged is intentionally ignored.
+void TestMathAlgebraicSimplificationPass::runOnFunction() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet simplifications(ctx);
+  populateMathAlgebraicSimplificationPatterns(simplifications);
+  LogicalResult converged =
+      applyPatternsAndFoldGreedily(getOperation(), std::move(simplifications));
+  (void)converged;
+}
+
+namespace mlir {
+namespace test {
+/// Adds the algebraic simplification test pass to the global pass registry
+/// so that tools such as mlir-opt can expose it.
+void registerTestMathAlgebraicSimplificationPass() {
+  PassRegistration<TestMathAlgebraicSimplificationPass> registration;
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index c30575b39fbe8..c5be1daec8304 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -92,6 +92,7 @@ void registerTestLivenessPass();
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
@@ -173,6 +174,7 @@ void registerTestPasses() {
   test::registerTestLoopFusion();
   test::registerTestLoopMappingPass();
   test::registerTestLoopUnrollingPass();
+  test::registerTestMathAlgebraicSimplificationPass();
   test::registerTestMathPolynomialApproximationPass();
   test::registerTestMemRefDependenceCheck();
   test::registerTestMemRefStrideCalculation();


        


More information about the Mlir-commits mailing list