[Mlir-commits] [mlir] 9b078f8 - [MLIR][arith] Mark addf/mulf as commutative
Christian Sigg
llvmlistbot at llvm.org
Mon Jan 31 23:33:56 PST 2022
Author: Christian Sigg
Date: 2022-02-01T08:33:48+01:00
New Revision: 9b078f8fd26a392d0b51a3fe97f07b3c5ca30bc8
URL: https://github.com/llvm/llvm-project/commit/9b078f8fd26a392d0b51a3fe97f07b3c5ca30bc8
DIFF: https://github.com/llvm/llvm-project/commit/9b078f8fd26a392d0b51a3fe97f07b3c5ca30bc8.diff
LOG: [MLIR][arith] Mark addf/mulf as commutative
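Following the discussion in D118318, mark `arith.addf` and `arith.mulf` as commutative. Because constant operands of commutative ops are now moved to the right-hand side during folding, the separate `addf(-0, x) -> x` fold is redundant and is removed. This commit also adds a description for `arith.subf` and updates tests for the new operand order.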
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D118600
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/Dialect/Standard/expand-tanh.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 496c92db50813..b0c08c4062504 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -594,7 +594,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
// AddFOp
//===----------------------------------------------------------------------===//
-def Arith_AddFOp : Arith_FloatBinaryOp<"addf"> {
+def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
let summary = "floating point addition operation";
let description = [{
The `addf` operation takes two operands and returns one result, each of
@@ -627,6 +627,28 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf"> {
def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
let summary = "floating point subtraction operation";
+ let description = [{
+ The `subf` operation takes two operands and returns one result, each of
+ these is required to be the same type. This type may be a floating point
+ scalar type, a vector whose element type is a floating point type, or a
+ floating point tensor.
+
+ Example:
+
+ ```mlir
+ // Scalar subtraction.
+ %a = arith.subf %b, %c : f64
+
+ // SIMD vector subtraction, e.g. for Intel SSE.
+ %f = arith.subf %g, %h : vector<4xf32>
+
+ // Tensor subtraction.
+ %x = arith.subf %y, %z : tensor<4x?xbf16>
+ ```
+
+ TODO: In the distant future, this will accept optional attributes for fast
+ math, contraction, rounding mode, and other controls.
+ }];
let hasFolder = 1;
}
@@ -723,7 +745,7 @@ def Arith_MinUIOp : Arith_IntBinaryOp<"minui"> {
// MulFOp
//===----------------------------------------------------------------------===//
-def Arith_MulFOp : Arith_FloatBinaryOp<"mulf"> {
+def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
let summary = "floating point multiplication operation";
let description = [{
The `mulf` operation takes two operands and returns one result, each of
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 022a5674ef9ee..37395fa79911f 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -580,10 +580,6 @@ OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
if (matchPattern(getRhs(), m_NegZeroFloat()))
return getLhs();
- // addf(-0, x) -> x
- if (matchPattern(getLhs(), m_NegZeroFloat()))
- return getRhs();
-
return constFoldBinaryOp<FloatAttr>(
operands, [](const APFloat &a, const APFloat &b) { return a + b; });
}
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index cf87af2e30e3c..bb40f62a5e01f 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -658,7 +658,7 @@ func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
%c0 = arith.constant 0.0 : f32
%c-0 = arith.constant -0.0 : f32
%c1 = arith.constant 1.0 : f32
- %0 = arith.addf %arg0, %c0 : f32
+ %0 = arith.addf %c0, %arg0 : f32
%1 = arith.addf %arg0, %c-0 : f32
%2 = arith.addf %c-0, %arg0 : f32
%3 = arith.addf %c1, %c1 : f32
@@ -685,15 +685,18 @@ func @test_subf(%arg0 : f16) -> (f16, f16, f16) {
// -----
// CHECK-LABEL: @test_mulf(
-func @test_mulf(%arg0 : f32) -> (f32, f32, f32) {
- // CHECK-NEXT: %[[C4:.+]] = arith.constant 4.0
- // CHECK-NEXT: return %arg0, %arg0, %[[C4]]
+func @test_mulf(%arg0 : f32) -> (f32, f32, f32, f32) {
+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2.0
+ // CHECK-DAG: %[[C4:.+]] = arith.constant 4.0
+ // CHECK-NEXT: %[[X:.+]] = arith.mulf %arg0, %[[C2]]
+ // CHECK-NEXT: return %[[X]], %arg0, %arg0, %[[C4]]
%c1 = arith.constant 1.0 : f32
%c2 = arith.constant 2.0 : f32
- %0 = arith.mulf %arg0, %c1 : f32
- %1 = arith.mulf %c1, %arg0 : f32
- %2 = arith.mulf %c2, %c2 : f32
- return %0, %1, %2 : f32, f32, f32
+ %0 = arith.mulf %c2, %arg0 : f32
+ %1 = arith.mulf %arg0, %c1 : f32
+ %2 = arith.mulf %c1, %arg0 : f32
+ %3 = arith.mulf %c2, %c2 : f32
+ return %0, %1, %2, %3 : f32, f32, f32, f32
}
// -----
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 3f68820b18cc7..d36186c396f87 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -243,7 +243,7 @@ func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
// CHECK: %[[CST:.*]] = arith.constant {{.*}} : f32
// CHECK: linalg.generic
// CHECK: ^{{.+}}(%[[ARG1:[a-zA-Z0-9_]+]]: f32, %{{.+}}: f32):
-// CHECK: arith.mulf %[[CST]], %[[ARG1]]
+// CHECK: arith.mulf %[[ARG1]], %[[CST]]
// -----
@@ -275,7 +275,7 @@ func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
// CHECK: %[[CST:.*]] = arith.constant {{.*}} : f32
// CHECK: linalg.generic
// CHECK: ^{{.*}}(%[[ARG1:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
-// CHECK: arith.mulf %[[CST]], %[[ARG1]]
+// CHECK: arith.mulf %[[ARG1]], %[[CST]]
// -----
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index cccb38150a98d..b01191184b055 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -258,7 +258,7 @@ func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x
// CHECK: %[[C1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
// CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
-// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[C1]], %[[EXP]] : f32
+// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[EXP]], %[[C1]] : f32
// CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32
// CHECK-NEXT: linalg.yield %[[LOG]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>
diff --git a/mlir/test/Dialect/Standard/expand-tanh.mlir b/mlir/test/Dialect/Standard/expand-tanh.mlir
index 1aa70cb712c93..4f809b71bd54e 100644
--- a/mlir/test/Dialect/Standard/expand-tanh.mlir
+++ b/mlir/test/Dialect/Standard/expand-tanh.mlir
@@ -12,7 +12,7 @@ func @tanh(%arg: f32) -> f32 {
// CHECK: %[[NEGDOUBLEDX:.+]] = arith.negf %[[DOUBLEDX]] : f32
// CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
// CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
-// CHECK: %[[DIVISOR1:.+]] = arith.addf %[[ONE]], %[[EXP1]] : f32
+// CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
// CHECK: %[[RES1:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
// CHECK: %[[EXP2:.+]] = math.exp %[[DOUBLEDX]] : f32
// CHECK: %[[DIVIDEND2:.+]] = arith.subf %[[EXP2]], %[[ONE]] : f32
More information about the Mlir-commits mailing list