[Mlir-commits] [mlir] dd115e5 - [mlir][IR] Implement proper folder for `IsCommutative` trait
Matthias Springer
llvmlistbot at llvm.org
Thu Jul 20 01:19:59 PDT 2023
Author: Matthias Springer
Date: 2023-07-20T10:19:48+02:00
New Revision: dd115e5a9bc778b5a8c4e445b7cdafa27db54ddd
URL: https://github.com/llvm/llvm-project/commit/dd115e5a9bc778b5a8c4e445b7cdafa27db54ddd
DIFF: https://github.com/llvm/llvm-project/commit/dd115e5a9bc778b5a8c4e445b7cdafa27db54ddd.diff
LOG: [mlir][IR] Implement proper folder for `IsCommutative` trait
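Commutative ops were previously folded with a special rule in `OperationFolder` that moved constant operands to the end of the operand list. This change turns that rule into a proper `OpTrait` folder (`IsCommutative::foldTrait`). Because the reordering now happens in the op's own fold hook, the FileCheck expectations of two conversion tests were updated to the new operand order.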
Differential Revision: https://reviews.llvm.org/D155687
Added:
Modified:
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Operation.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index d895679be3ca2c..221c607c15f4c9 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -314,6 +314,8 @@ namespace OpTrait {
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
+/// Folder for the IsCommutative trait: moves the operands whose entry in
+/// `operands` is non-null (i.e. known constants) to the end of `op`'s operand
+/// list, in place.
+LogicalResult foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results);
OpFoldResult foldIdempotent(Operation *op);
OpFoldResult foldInvolution(Operation *op);
LogicalResult verifyZeroOperands(Operation *op);
@@ -1148,7 +1150,13 @@ class ResultsAreSignlessIntegerLike
/// This class adds property that the operation is commutative.
template <typename ConcreteType>
-class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
+class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
+public:
+ /// Trait folding hook. Forwards to impl::foldCommutative, which moves the
+ /// operands that have a known constant value (non-null entries of
+ /// `operands`) to the end of the operand list.
+ static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ return impl::foldCommutative(op, operands, results);
+ }
+};
/// This class adds property that the operation is an involution.
/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 449c97d469bf62..efce8d92015087 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
@@ -790,6 +791,24 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
// Op Trait implementations
//===----------------------------------------------------------------------===//
+LogicalResult
+OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+                               SmallVectorImpl<OpFoldResult> &results) {
+  assert(operands.size() == op->getNumOperands() &&
+         "expected one constant attribute slot per operand");
+  // Nothing to fold if there are not at least 2 operands.
+  if (op->getNumOperands() < 2)
+    return failure();
+
+  // Move all constant operands to the end, preserving the relative order of
+  // both the non-constant and the constant operands. `operands[i]` holds the
+  // constant value of operand #i, or null if that operand is not a constant.
+  //
+  // Operands are classified by index up front instead of inside a
+  // std::stable_partition predicate. Recovering the index from an OpOperand's
+  // address relies on the predicate only ever being applied to elements at
+  // their original position in the operand list. The C++ standard does not
+  // guarantee that, since the algorithm may relocate elements through a
+  // temporary buffer.
+  SmallVector<Value> nonConstants, constants;
+  bool needsReorder = false;
+  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) {
+    Value value = op->getOperand(i);
+    if (operands[i]) {
+      constants.push_back(value);
+      continue;
+    }
+    // A non-constant operand that follows a constant has to move forward.
+    needsReorder |= !constants.empty();
+    nonConstants.push_back(value);
+  }
+
+  // Constants are already trailing (or absent): the op is left unchanged.
+  if (!needsReorder)
+    return failure();
+
+  llvm::append_range(nonConstants, constants);
+  op->setOperands(nonConstants);
+  // Folded in place: no replacement values are pushed to `results`.
+  return success();
+}
+
OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
if (op->getNumOperands() == 1) {
auto *argumentOp = op->getOperand(0).getDefiningOp();
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index e9e59cfeed79ea..ad1e0436c64d61 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -217,21 +217,6 @@ LogicalResult OperationFolder::tryToFold(Operation *op,
SmallVectorImpl<Value> &results) {
SmallVector<Attribute, 8> operandConstants;
- // If this is a commutative operation, move constants to be trailing operands.
- bool updatedOpOperands = false;
- if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
- auto isNonConstant = [&](OpOperand &o) {
- return !matchPattern(o.get(), m_Constant());
- };
- auto *firstConstantIt =
- llvm::find_if_not(op->getOpOperands(), isNonConstant);
- auto *newConstantIt = std::stable_partition(
- firstConstantIt, op->getOpOperands().end(), isNonConstant);
-
- // Remember if we actually moved anything.
- updatedOpOperands = firstConstantIt != newConstantIt;
- }
-
// Check to see if any operands to the operation is constant and whether
// the operation knows how to constant fold itself.
operandConstants.assign(op->getNumOperands(), Attribute());
@@ -244,7 +229,7 @@ LogicalResult OperationFolder::tryToFold(Operation *op,
SmallVector<OpFoldResult, 8> foldResults;
if (failed(op->fold(operandConstants, foldResults)) ||
failed(processFoldResults(op, results, foldResults)))
- return success(updatedOpOperands);
+ return failure();
return success();
}
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index 710f70401589d0..1ed67708875604 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -127,7 +127,7 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
// CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
// CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
- // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+ // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
// CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
// CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
@@ -159,7 +159,7 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
// CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
- // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+ // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
// CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
// CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
diff --git a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
index f11abe4348a93f..238c0c51312a6b 100644
--- a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
+++ b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
@@ -27,7 +27,7 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[C2]], %[[OFFSET]] : index
+// CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[OFFSET]], %[[C2]] : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
// CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index
More information about the Mlir-commits
mailing list