[Mlir-commits] [mlir] 892bf09 - [mlir][arith] Allow to specify `constFoldBinaryOp` result type
Jakub Kuderski
llvmlistbot at llvm.org
Mon Feb 13 11:20:03 PST 2023
Author: Jakub Kuderski
Date: 2023-02-13T14:18:14-05:00
New Revision: 892bf09606b654321c475b503210f30600c3ff7f
URL: https://github.com/llvm/llvm-project/commit/892bf09606b654321c475b503210f30600c3ff7f
DIFF: https://github.com/llvm/llvm-project/commit/892bf09606b654321c475b503210f30600c3ff7f.diff
LOG: [mlir][arith] Allow to specify `constFoldBinaryOp` result type
This enables us to use the common fold helpers on elementwise ops that
produce different result type than operand types, e.g., `arith.cmpi` or
`arith.addui_extended`.
Use the updated helper to teach `arith.cmpi` to fold constant vectors.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D143779
Added:
Modified:
mlir/include/mlir/Dialect/CommonFolders.h
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 425a8c2686adb..4f24580d02a70 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -24,14 +24,16 @@
namespace mlir {
/// Performs constant folding `calculate` with element-wise behavior on the two
/// attributes in `operands` and returns the result if possible.
+/// Uses `resultType` for the type of the returned attribute.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT = function_ref<
std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
+ Type resultType,
const CalculationT &calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
- if (!operands[0] || !operands[1])
+ if (!resultType || !operands[0] || !operands[1])
return {};
if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
@@ -45,7 +47,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!calRes)
return {};
- return AttrElementT::get(lhs.getType(), *calRes);
+ return AttrElementT::get(resultType, *calRes);
}
if (operands[0].isa<SplatElementsAttr>() &&
@@ -62,9 +64,10 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!elementResult)
return {};
- return DenseElementsAttr::get(lhs.getType(), *elementResult);
- } else if (operands[0].isa<ElementsAttr>() &&
- operands[1].isa<ElementsAttr>()) {
+ return DenseElementsAttr::get(resultType, *elementResult);
+ }
+
+ if (operands[0].isa<ElementsAttr>() && operands[1].isa<ElementsAttr>()) {
// Operands are ElementsAttr-derived; perform an element-wise fold by
// expanding the values.
auto lhs = operands[0].cast<ElementsAttr>();
@@ -83,11 +86,53 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
elementResults.push_back(*elementResult);
}
- return DenseElementsAttr::get(lhs.getType(), elementResults);
+ return DenseElementsAttr::get(resultType, elementResults);
}
return {};
}
+/// Performs constant folding `calculate` with element-wise behavior on the two
+/// attributes in `operands` and returns the result if possible.
+/// Uses the operand element type for the element type of the returned
+/// attribute.
+template <class AttrElementT,
+ class ElementValueT = typename AttrElementT::ValueType,
+ class CalculationT = function_ref<
+ std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
+Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
+ const CalculationT &calculate) {
+ assert(operands.size() == 2 && "binary op takes two operands");
+ auto getResultType = [](Attribute attr) -> Type {
+ if (auto typed = attr.dyn_cast_or_null<TypedAttr>())
+ return typed.getType();
+ return {};
+ };
+
+ Type lhsType = getResultType(operands[0]);
+ Type rhsType = getResultType(operands[1]);
+ if (!lhsType || !rhsType)
+ return {};
+ if (lhsType != rhsType)
+ return {};
+
+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT,
+ CalculationT>(operands, lhsType,
+ calculate);
+}
+
+template <class AttrElementT,
+ class ElementValueT = typename AttrElementT::ValueType,
+ class CalculationT =
+ function_ref<ElementValueT(ElementValueT, ElementValueT)>>
+Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
+ const CalculationT &calculate) {
+ return constFoldBinaryOpConditional<AttrElementT>(
+ operands, resultType,
+ [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
+ return calculate(a, b);
+ });
+}
+
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT =
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d3739f8dbae61..775ee8466beef 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -13,11 +13,14 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
@@ -107,6 +110,23 @@ namespace {
#include "ArithCanonicalization.inc"
} // namespace
+//===----------------------------------------------------------------------===//
+// Common helpers
+//===----------------------------------------------------------------------===//
+
+/// Return the type of the same shape (scalar, vector or tensor) containing i1.
+static Type getI1SameShape(Type type) {
+ auto i1Type = IntegerType::get(type.getContext(), 1);
+ if (auto tensorType = type.dyn_cast<RankedTensorType>())
+ return RankedTensorType::get(tensorType.getShape(), i1Type);
+ if (type.isa<UnrankedTensorType>())
+ return UnrankedTensorType::get(i1Type);
+ if (auto vectorType = type.dyn_cast<VectorType>())
+ return VectorType::get(vectorType.getShape(), i1Type,
+ vectorType.getNumScalableDims());
+ return i1Type;
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -276,41 +296,16 @@ arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
// addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
// Let the `constFoldBinaryOp` utility attempt to fold the sum of both
// operands. If that succeeds, calculate the overflow bit based on the sum
- // and the first (constant) operand, `lhs`. Note that we cannot simply call
- // `constFoldBinaryOp` again to calculate the overflow bit because the
- // constructed attribute is of the same element type as both operands.
+ // and the first (constant) operand, `lhs`.
if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(),
[](APInt a, const APInt &b) { return std::move(a) + b; })) {
- Attribute overflowAttr;
- if (auto lhs = adaptor.getLhs().dyn_cast<IntegerAttr>()) {
- // Both arguments are scalars, calculate the scalar overflow value.
- auto sum = sumAttr.cast<IntegerAttr>();
- overflowAttr = IntegerAttr::get(
- overflowTy,
- calculateUnsignedOverflow(sum.getValue(), lhs.getValue()));
- } else if (auto lhs = adaptor.getLhs().dyn_cast<SplatElementsAttr>()) {
- // Both arguments are splats, calculate the splat overflow value.
- auto sum = sumAttr.cast<SplatElementsAttr>();
- APInt overflow = calculateUnsignedOverflow(sum.getSplatValue<APInt>(),
- lhs.getSplatValue<APInt>());
- overflowAttr = SplatElementsAttr::get(overflowTy, overflow);
- } else if (auto lhs = adaptor.getLhs().dyn_cast<ElementsAttr>()) {
- // Othwerwise calculate element-wise overflow values.
- auto sum = sumAttr.cast<ElementsAttr>();
- const auto numElems = static_cast<size_t>(sum.getNumElements());
- SmallVector<APInt> overflowValues;
- overflowValues.reserve(numElems);
-
- auto sumIt = sum.value_begin<APInt>();
- auto lhsIt = lhs.value_begin<APInt>();
- for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
- overflowValues.push_back(calculateUnsignedOverflow(*sumIt, *lhsIt));
-
- overflowAttr = DenseElementsAttr::get(overflowTy, overflowValues);
- } else {
+ Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
+ ArrayRef({sumAttr, adaptor.getLhs()}),
+ getI1SameShape(sumAttr.cast<TypedAttr>().getType()),
+ calculateUnsignedOverflow);
+ if (!overflowAttr)
return failure();
- }
results.push_back(sumAttr);
results.push_back(overflowAttr);
@@ -1534,23 +1529,6 @@ void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<BitcastOfBitcast>(context);
}
-//===----------------------------------------------------------------------===//
-// Helpers for compare ops
-//===----------------------------------------------------------------------===//
-
-/// Return the type of the same shape (scalar, vector or tensor) containing i1.
-static Type getI1SameShape(Type type) {
- auto i1Type = IntegerType::get(type.getContext(), 1);
- if (auto tensorType = type.dyn_cast<RankedTensorType>())
- return RankedTensorType::get(tensorType.getShape(), i1Type);
- if (type.isa<UnrankedTensorType>())
- return UnrankedTensorType::get(i1Type);
- if (auto vectorType = type.dyn_cast<VectorType>())
- return VectorType::get(vectorType.getShape(), i1Type,
- vectorType.getNumScalableDims());
- return i1Type;
-}
-
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
@@ -1671,16 +1649,18 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
llvm_unreachable("unknown cmpi predicate kind");
}
- auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
- if (!lhs)
- return {};
-
// We are moving constants to the right side; So if lhs is constant rhs is
// guaranteed to be a constant.
- auto rhs = adaptor.getRhs().cast<IntegerAttr>();
+ if (auto lhs = adaptor.getLhs().dyn_cast_or_null<TypedAttr>()) {
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), getI1SameShape(lhs.getType()),
+ [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
+ return APInt(1,
+ static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
+ });
+ }
- auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
- return BoolAttr::get(getContext(), val);
+ return {};
}
void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 0ee1b0ba73333..355e7a8753ff6 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -322,6 +322,46 @@ func.func @cmpIExtUIEQ(%arg0: i8, %arg1: i8) -> i1 {
return %res : i1
}
+// CHECK-LABEL: @cmpIFoldEQ
+// CHECK: %[[res:.+]] = arith.constant dense<[true, true, false]> : vector<3xi1>
+// CHECK: return %[[res]]
+func.func @cmpIFoldEQ() -> vector<3xi1> {
+ %lhs = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+ %rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32>
+ %res = arith.cmpi eq, %lhs, %rhs : vector<3xi32>
+ return %res : vector<3xi1>
+}
+
+// CHECK-LABEL: @cmpIFoldNE
+// CHECK: %[[res:.+]] = arith.constant dense<[false, false, true]> : vector<3xi1>
+// CHECK: return %[[res]]
+func.func @cmpIFoldNE() -> vector<3xi1> {
+ %lhs = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+ %rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32>
+ %res = arith.cmpi ne, %lhs, %rhs : vector<3xi32>
+ return %res : vector<3xi1>
+}
+
+// CHECK-LABEL: @cmpIFoldSGE
+// CHECK: %[[res:.+]] = arith.constant dense<[true, true, false]> : vector<3xi1>
+// CHECK: return %[[res]]
+func.func @cmpIFoldSGE() -> vector<3xi1> {
+ %lhs = arith.constant dense<2> : vector<3xi32>
+ %rhs = arith.constant dense<[1, 2, 4]> : vector<3xi32>
+ %res = arith.cmpi sge, %lhs, %rhs : vector<3xi32>
+ return %res : vector<3xi1>
+}
+
+// CHECK-LABEL: @cmpIFoldULT
+// CHECK: %[[res:.+]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[res]]
+func.func @cmpIFoldULT() -> vector<3xi1> {
+ %lhs = arith.constant dense<2> : vector<3xi32>
+ %rhs = arith.constant dense<1> : vector<3xi32>
+ %res = arith.cmpi ult, %lhs, %rhs : vector<3xi32>
+ return %res : vector<3xi1>
+}
+
// -----
// CHECK-LABEL: @andOfExtSI
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 220adc57fbcd8..7ee88b64ea8ac 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1070,13 +1070,13 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3
// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
// CHECK-LABEL: @while_loop_invariant_argument_different_order
-func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
%cst_0 = arith.constant dense<0> : tensor<i32>
%cst_1 = arith.constant dense<1> : tensor<i32>
%cst_42 = arith.constant dense<42> : tensor<i32>
%0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
- %1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
+ %1 = arith.cmpi slt, %arg0, %arg : tensor<i32>
%2 = tensor.extract %1[] : tensor<i1>
scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
} do {
@@ -1087,11 +1087,11 @@ func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tens
}
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
+// CHECK-SAME: (%[[ARG:.+]]: tensor<i32>)
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
-// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
-// CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]]
+// CHECK: arith.cmpi sgt, %[[ARG]], %[[ZERO]]
// CHECK: tensor.extract %{{.*}}[]
// CHECK: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
// CHECK: } do {
More information about the Mlir-commits
mailing list