[Mlir-commits] [mlir] 9b078f8 - [MLIR][arith] Mark addf/mulf as commutative

Christian Sigg llvmlistbot at llvm.org
Mon Jan 31 23:33:56 PST 2022


Author: Christian Sigg
Date: 2022-02-01T08:33:48+01:00
New Revision: 9b078f8fd26a392d0b51a3fe97f07b3c5ca30bc8

URL: https://github.com/llvm/llvm-project/commit/9b078f8fd26a392d0b51a3fe97f07b3c5ca30bc8
DIFF: https://github.com/llvm/llvm-project/commit/9b078f8fd26a392d0b51a3fe97f07b3c5ca30bc8.diff

LOG: [MLIR][arith] Mark addf/mulf as commutative

Following the discussion in D118318, mark `arith.addf` and `arith.mulf` as commutative.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D118600

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/Dialect/Standard/expand-tanh.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 496c92db50813..b0c08c4062504 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -594,7 +594,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
 // AddFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AddFOp : Arith_FloatBinaryOp<"addf"> {
+def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
   let summary = "floating point addition operation";
   let description = [{
     The `addf` operation takes two operands and returns one result, each of
@@ -627,6 +627,28 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf"> {
 
 def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
   let summary = "floating point subtraction operation";
+  let description = [{
+    The `subf` operation takes two operands and returns one result, each of
+    these is required to be the same type. This type may be a floating point
+    scalar type, a vector whose element type is a floating point type, or a
+    floating point tensor.
+
+    Example:
+
+    ```mlir
+    // Scalar subtraction.
+    %a = arith.subf %b, %c : f64
+
+    // SIMD vector subtraction, e.g. for Intel SSE.
+    %f = arith.subf %g, %h : vector<4xf32>
+
+    // Tensor subtraction.
+    %x = arith.subf %y, %z : tensor<4x?xbf16>
+    ```
+
+    TODO: In the distant future, this will accept optional attributes for fast
+    math, contraction, rounding mode, and other controls.
+  }];
   let hasFolder = 1;
 }
 
@@ -723,7 +745,7 @@ def Arith_MinUIOp : Arith_IntBinaryOp<"minui"> {
 // MulFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MulFOp : Arith_FloatBinaryOp<"mulf"> {
+def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
   let summary = "floating point multiplication operation";
   let description = [{
     The `mulf` operation takes two operands and returns one result, each of

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 022a5674ef9ee..37395fa79911f 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -580,10 +580,6 @@ OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
   if (matchPattern(getRhs(), m_NegZeroFloat()))
     return getLhs();
 
-  // addf(-0, x) -> x
-  if (matchPattern(getLhs(), m_NegZeroFloat()))
-    return getRhs();
-
   return constFoldBinaryOp<FloatAttr>(
       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
 }

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index cf87af2e30e3c..bb40f62a5e01f 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -658,7 +658,7 @@ func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
   %c0 = arith.constant 0.0 : f32
   %c-0 = arith.constant -0.0 : f32
   %c1 = arith.constant 1.0 : f32
-  %0 = arith.addf %arg0, %c0 : f32
+  %0 = arith.addf %c0, %arg0 : f32
   %1 = arith.addf %arg0, %c-0 : f32
   %2 = arith.addf %c-0, %arg0 : f32
   %3 = arith.addf %c1, %c1 : f32
@@ -685,15 +685,18 @@ func @test_subf(%arg0 : f16) -> (f16, f16, f16) {
 // -----
 
 // CHECK-LABEL: @test_mulf(
-func @test_mulf(%arg0 : f32) -> (f32, f32, f32) {
-  // CHECK-NEXT:  %[[C4:.+]] = arith.constant 4.0
-  // CHECK-NEXT:   return %arg0, %arg0, %[[C4]]
+func @test_mulf(%arg0 : f32) -> (f32, f32, f32, f32) {
+  // CHECK-DAG:   %[[C2:.+]] = arith.constant 2.0
+  // CHECK-DAG:   %[[C4:.+]] = arith.constant 4.0
+  // CHECK-NEXT:  %[[X:.+]] = arith.mulf %arg0, %[[C2]]
+  // CHECK-NEXT:  return %[[X]], %arg0, %arg0, %[[C4]]
   %c1 = arith.constant 1.0 : f32
   %c2 = arith.constant 2.0 : f32
-  %0 = arith.mulf %arg0, %c1 : f32
-  %1 = arith.mulf %c1, %arg0 : f32
-  %2 = arith.mulf %c2, %c2 : f32
-  return %0, %1, %2 : f32, f32, f32
+  %0 = arith.mulf %c2, %arg0 : f32
+  %1 = arith.mulf %arg0, %c1 : f32
+  %2 = arith.mulf %c1, %arg0 : f32
+  %3 = arith.mulf %c2, %c2 : f32
+  return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 3f68820b18cc7..d36186c396f87 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -243,7 +243,7 @@ func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
 //       CHECK:   %[[CST:.*]] = arith.constant {{.*}} : f32
 //       CHECK:   linalg.generic
 //       CHECK:   ^{{.+}}(%[[ARG1:[a-zA-Z0-9_]+]]: f32, %{{.+}}: f32):
-//       CHECK:     arith.mulf %[[CST]], %[[ARG1]]
+//       CHECK:     arith.mulf %[[ARG1]], %[[CST]]
 
 // -----
 
@@ -275,7 +275,7 @@ func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
 //       CHECK:   %[[CST:.*]] = arith.constant {{.*}} : f32
 //       CHECK:   linalg.generic
 //       CHECK:   ^{{.*}}(%[[ARG1:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
-//       CHECK:     arith.mulf %[[CST]], %[[ARG1]]
+//       CHECK:     arith.mulf %[[ARG1]], %[[CST]]
 
 // -----
 

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index cccb38150a98d..b01191184b055 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -258,7 +258,7 @@ func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x
 //      CHECK: %[[C1:.+]] = arith.constant 1.000000e+00 : f32
 //      CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
 // CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
-// CHECK-NEXT:   %[[SUM:.+]] = arith.addf %[[C1]], %[[EXP]] : f32
+// CHECK-NEXT:   %[[SUM:.+]] = arith.addf %[[EXP]], %[[C1]] : f32
 // CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
 // CHECK-NEXT:   linalg.yield %[[LOG]] : f32
 // CHECK-NEXT: -> tensor<16x32xf32>

diff  --git a/mlir/test/Dialect/Standard/expand-tanh.mlir b/mlir/test/Dialect/Standard/expand-tanh.mlir
index 1aa70cb712c93..4f809b71bd54e 100644
--- a/mlir/test/Dialect/Standard/expand-tanh.mlir
+++ b/mlir/test/Dialect/Standard/expand-tanh.mlir
@@ -12,7 +12,7 @@ func @tanh(%arg: f32) -> f32 {
 // CHECK: %[[NEGDOUBLEDX:.+]] = arith.negf %[[DOUBLEDX]] : f32
 // CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
 // CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
-// CHECK: %[[DIVISOR1:.+]] = arith.addf %[[ONE]], %[[EXP1]] : f32
+// CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
 // CHECK: %[[RES1:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
 // CHECK: %[[EXP2:.+]] = math.exp %[[DOUBLEDX]] : f32
 // CHECK: %[[DIVIDEND2:.+]] = arith.subf %[[EXP2]], %[[ONE]] : f32


        


More information about the Mlir-commits mailing list