[Mlir-commits] [mlir] b80a9ca - [MLIR] Allow non-binary operations to be commutative

Stephen Neuendorffer llvmlistbot at llvm.org
Mon Feb 10 10:25:04 PST 2020


Author: Stephen Neuendorffer
Date: 2020-02-10T10:23:55-08:00
New Revision: b80a9ca8cbc19beb6117d7a05e3403adc598a059

URL: https://github.com/llvm/llvm-project/commit/b80a9ca8cbc19beb6117d7a05e3403adc598a059
DIFF: https://github.com/llvm/llvm-project/commit/b80a9ca8cbc19beb6117d7a05e3403adc598a059.diff

LOG: [MLIR] Allow non-binary operations to be commutative

NFC (no functional change) for binary operations.

Differential Revision: https://reviews.llvm.org/D73670

Added: 
    

Modified: 
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/test/Transforms/test-canonicalize.mlir
    mlir/test/lib/TestDialect/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index bf19a5af14ff..544cfe829b3c 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -56,9 +56,11 @@ class OwningRewritePatternList;
 //===----------------------------------------------------------------------===//
 
 enum class OperationProperty {
-  /// This bit is set for an operation if it is a commutative operation: that
-  /// is a binary operator (two inputs) where "a op b" and "b op a" produce the
-  /// same results.
+  /// This bit is set for an operation if it is a commutative
+  /// operation: that is, an operator where the order of operands does
+  /// not change the result of the operation.  For example, in a binary
+  /// commutative operation, "a op b" and "b op a" produce the same
+  /// results.
   Commutative = 0x1,
 
   /// This bit is set for operations that have no side effects: that means that

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index f05206a0814c..1caee370ff91 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -134,20 +134,19 @@ LogicalResult OperationFolder::tryToFold(
   SmallVector<Attribute, 8> operandConstants;
   SmallVector<OpFoldResult, 8> foldResults;
 
+  // If this is a commutative operation, move constants to be trailing operands.
+  if (op->getNumOperands() >= 2 && op->isCommutative()) {
+    std::stable_partition(
+        op->getOpOperands().begin(), op->getOpOperands().end(),
+        [&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); });
+  }
+
   // Check to see if any operands to the operation is constant and whether
   // the operation knows how to constant fold itself.
   operandConstants.assign(op->getNumOperands(), Attribute());
   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
 
-  // If this is a commutative binary operation with a constant on the left
-  // side move it to the right side.
-  if (operandConstants.size() == 2 && operandConstants[0] &&
-      !operandConstants[1] && op->isCommutative()) {
-    std::swap(op->getOpOperand(0), op->getOpOperand(1));
-    std::swap(operandConstants[0], operandConstants[1]);
-  }
-
   // Attempt to constant fold the operation.
   if (failed(op->fold(operandConstants, foldResults)))
     return failure();

diff  --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index dfcc156912b6..920fd8c5f989 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -35,3 +35,19 @@ func @remove_op_with_variadic_results_and_folder(%arg0 : i32, %arg1 : i32) -> (i
   %0, %1 = "test.op_with_variadic_results_and_folder"(%arg0, %arg1) : (i32, i32) -> (i32, i32)
   return %0, %1 : i32, i32
 }
+
+// CHECK-LABEL: func @test_commutative_multi
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: i32, %[[ARG_1:[a-z0-9]*]]: i32)
+func @test_commutative_multi(%arg0: i32, %arg1: i32) -> (i32, i32) {
+  // CHECK: %[[C42:.*]] = constant 42 : i32
+  %c42_i32 = constant 42 : i32
+  // CHECK: %[[C43:.*]] = constant 43 : i32
+  %c43_i32 = constant 43 : i32
+  // CHECK-NEXT: %[[O0:.*]] = "test.op_commutative"(%[[ARG_0]], %[[ARG_1]], %[[C42]], %[[C43]]) : (i32, i32, i32, i32) -> i32
+  %y = "test.op_commutative"(%c42_i32, %arg0, %arg1, %c43_i32) : (i32, i32, i32, i32) -> i32
+  // Both: non-constants first, then constants, each group in original order.
+  // CHECK-NEXT: %[[O1:.*]] = "test.op_commutative"(%[[ARG_0]], %[[ARG_1]], %[[C42]], %[[C43]]) : (i32, i32, i32, i32) -> i32
+  %z = "test.op_commutative"(%arg0, %c42_i32, %c43_i32, %arg1): (i32, i32, i32, i32) -> i32
+  // CHECK-NEXT: return %[[O0]], %[[O1]]
+  return %y, %z: i32, i32
+}

diff  --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index 4b5f1b5850d8..de7f1875ef05 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -639,6 +639,11 @@ def TestOpWithVariadicResultsAndFolder: TEST_Op<"op_with_variadic_results_and_fo
   let hasFolder = 1;
 }
 
+def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
+  let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4);
+  let results = (outs I32); // Non-binary Commutative op for tests.
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Symbol Binding)
 


        


More information about the Mlir-commits mailing list