[Mlir-commits] [mlir] dd115e5 - [mlir][IR] Implement proper folder for `IsCommutative` trait

Matthias Springer llvmlistbot at llvm.org
Thu Jul 20 01:19:59 PDT 2023


Author: Matthias Springer
Date: 2023-07-20T10:19:48+02:00
New Revision: dd115e5a9bc778b5a8c4e445b7cdafa27db54ddd

URL: https://github.com/llvm/llvm-project/commit/dd115e5a9bc778b5a8c4e445b7cdafa27db54ddd
DIFF: https://github.com/llvm/llvm-project/commit/dd115e5a9bc778b5a8c4e445b7cdafa27db54ddd.diff

LOG: [mlir][IR] Implement proper folder for `IsCommutative` trait

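Commutative ops were previously folded with a special rule in `OperationFolder::tryToFold`, which moved constant operands to be trailing operands. This change removes that special rule and turns the folding into a proper `OpTrait` folder: `IsCommutative` now provides `foldTrait`, implemented by `OpTrait::impl::foldCommutative`.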

Differential Revision: https://reviews.llvm.org/D155687

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/IR/Operation.cpp
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
    mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index d895679be3ca2c..221c607c15f4c9 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -314,6 +314,8 @@ namespace OpTrait {
 // corresponding trait classes.  This avoids them being template
 // instantiated/duplicated.
 namespace impl {
+LogicalResult foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+                              SmallVectorImpl<OpFoldResult> &results);
 OpFoldResult foldIdempotent(Operation *op);
 OpFoldResult foldInvolution(Operation *op);
 LogicalResult verifyZeroOperands(Operation *op);
@@ -1148,7 +1150,13 @@ class ResultsAreSignlessIntegerLike
 
 /// This class adds property that the operation is commutative.
 template <typename ConcreteType>
-class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
+class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
+public:
+  static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
+                                 SmallVectorImpl<OpFoldResult> &results) {
+    return impl::foldCommutative(op, operands, results);
+  }
+};
 
 /// This class adds property that the operation is an involution.
 /// This means a unary to unary operation "f" that satisfies f(f(x)) = x

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 449c97d469bf62..efce8d92015087 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
@@ -790,6 +791,24 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
 // Op Trait implementations
 //===----------------------------------------------------------------------===//
 
+LogicalResult
+OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+                               SmallVectorImpl<OpFoldResult> &results) {
+  // Nothing to fold if there are not at least 2 operands.
+  if (op->getNumOperands() < 2)
+    return failure();
+  // Move all constant operands to the end.
+  OpOperand *operandsBegin = op->getOpOperands().begin();
+  auto isNonConstant = [&](OpOperand &o) {
+    return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]);
+  };
+  auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant);
+  auto *newConstantIt = std::stable_partition(
+      firstConstantIt, op->getOpOperands().end(), isNonConstant);
+  // Return success if the op was modified.
+  return success(firstConstantIt != newConstantIt);
+}
+
 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
   if (op->getNumOperands() == 1) {
     auto *argumentOp = op->getOperand(0).getDefiningOp();

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index e9e59cfeed79ea..ad1e0436c64d61 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -217,21 +217,6 @@ LogicalResult OperationFolder::tryToFold(Operation *op,
                                          SmallVectorImpl<Value> &results) {
   SmallVector<Attribute, 8> operandConstants;
 
-  // If this is a commutative operation, move constants to be trailing operands.
-  bool updatedOpOperands = false;
-  if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
-    auto isNonConstant = [&](OpOperand &o) {
-      return !matchPattern(o.get(), m_Constant());
-    };
-    auto *firstConstantIt =
-        llvm::find_if_not(op->getOpOperands(), isNonConstant);
-    auto *newConstantIt = std::stable_partition(
-        firstConstantIt, op->getOpOperands().end(), isNonConstant);
-
-    // Remember if we actually moved anything.
-    updatedOpOperands = firstConstantIt != newConstantIt;
-  }
-
   // Check to see if any operands to the operation is constant and whether
   // the operation knows how to constant fold itself.
   operandConstants.assign(op->getNumOperands(), Attribute());
@@ -244,7 +229,7 @@ LogicalResult OperationFolder::tryToFold(Operation *op,
   SmallVector<OpFoldResult, 8> foldResults;
   if (failed(op->fold(operandConstants, foldResults)) ||
       failed(processFoldResults(op, results, foldResults)))
-    return success(updatedOpOperands);
+    return failure();
   return success();
 }
 

diff  --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index 710f70401589d0..1ed67708875604 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -127,7 +127,7 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
@@ -159,7 +159,7 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
 
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]

diff  --git a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
index f11abe4348a93f..238c0c51312a6b 100644
--- a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
+++ b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
@@ -27,7 +27,7 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
 // CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
-// CHECK:           %[[OUT_DIM2:.*]] = arith.addi %[[C2]], %[[OFFSET]] : index
+// CHECK:           %[[OUT_DIM2:.*]] = arith.addi %[[OFFSET]], %[[C2]] : index
 // CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
 // CHECK:           %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
 // CHECK:           %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index


        


More information about the Mlir-commits mailing list