[Mlir-commits] [mlir] 6e7bbdd - [mlir] Add offset/stride helper functions to OffsetSizeAndStrideOpInterface
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 7 04:11:58 PDT 2021
Author: Matthias Springer
Date: 2021-06-07T20:11:41+09:00
New Revision: 6e7bbdd6e7f7649bccc4f981520ed916e21d7058
URL: https://github.com/llvm/llvm-project/commit/6e7bbdd6e7f7649bccc4f981520ed916e21d7058
DIFF: https://github.com/llvm/llvm-project/commit/6e7bbdd6e7f7649bccc4f981520ed916e21d7058.diff
LOG: [mlir] Add offset/stride helper functions to OffsetSizeAndStrideOpInterface
* Add hasUnitStride and hasZeroOffset to OffsetSizeAndStrideOpInterface. These functions are useful for various patterns. E.g., some vectorization patterns apply only to tensor ops with zero offsets and/or unit stride.
* Add getConstantIntValue and isEqualConstantInt helper functions, which are useful for implementing the two above functions, as well as various patterns.
Differential Revision: https://reviews.llvm.org/D103763
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 65cac17098c86..ee1cc67dee457 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -122,6 +122,15 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs);
+/// If ofr is an IntegerAttr, or a Value defined by a ConstantOp that holds an
+/// IntegerAttr, return its sign-extended value; otherwise return llvm::None.
+llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
+
+/// Return true if ofr is a constant integer equal to value, false otherwise.
+/// Integer bitwidth and type mismatches are ignored, because there is no
+/// IndexAttr and IndexType has no bitwidth.
+bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
+
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
/// or the same SSA value.
/// Ignore integer bitwitdh and type mismatch that come from the fact there is
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 0094fffeea966..8d58570148910 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -30,6 +30,8 @@ struct Range {
class OffsetSizeAndStrideOpInterface;
+bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
+
namespace detail {
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index e26a02f61966a..62f24f2b97362 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -436,6 +436,30 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
$_op.getOperation()), other, cmp);
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{ Return true if all strides are guaranteed to be 1. }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasUnitStride",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) {
+ return ::mlir::isEqualConstantInt(ofr, 1);
+ });
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{ Return true if all offsets are guaranteed to be 0. }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasZeroOffset",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) {
+ return ::mlir::isEqualConstantInt(ofr, 0);
+ });
+ }]
+ >,
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index f115fd09213e8..a3c2513e28bcc 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -60,24 +60,35 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
}
+/// If ofr is an IntegerAttr, or a Value defined by a ConstantOp that holds an
+/// IntegerAttr, return its sign-extended value; otherwise return llvm::None.
+llvm::Optional<int64_t> mlir::getConstantIntValue(OpFoldResult ofr) {
+ Attribute attr = ofr.dyn_cast<Attribute>();
+ // No Attribute: use the defining ConstantOp's value (op is looked up twice).
+ if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
+ attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
+ if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+ return intAttr.getValue().getSExtValue();
+ return llvm::None;
+}
+
+/// Return true if ofr is a constant integer equal to value, false otherwise.
+/// Integer bitwidth and type mismatches are ignored, because there is no
+/// IndexAttr and IndexType has no bitwidth.
+bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) {
+ auto ofrValue = getConstantIntValue(ofr);
+ return ofrValue && *ofrValue == value;
+}
+
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
/// or the same SSA value.
-/// Ignore integer bitwitdh and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType have no bitwidth.
-bool mlir::isEqualConstantIntOrValue(OpFoldResult op1, OpFoldResult op2) {
- auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
- Attribute attr = ofr.dyn_cast<Attribute>();
- // Note: isa+cast-like pattern allows writing the condition below as 1 line.
- if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
- attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
- if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
- return intAttr.getValue().getSExtValue();
- return llvm::None;
- };
- auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
+ auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
if (cst1 && cst2 && *cst1 == *cst2)
return true;
- auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
+ auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
return v1 && v2 && v1 == v2;
}
More information about the Mlir-commits
mailing list