[Mlir-commits] [mlir] e8c2877 - [mlir] Reuse the code between `getMixed*s()` funcs in ViewLikeInterface.cpp.
Alexander Belyaev
llvmlistbot at llvm.org
Sun Jul 31 12:10:32 PDT 2022
Author: Alexander Belyaev
Date: 2022-07-31T21:09:30+02:00
New Revision: e8c2877565149587fd66fbee591b7d44eecd667d
URL: https://github.com/llvm/llvm-project/commit/e8c2877565149587fd66fbee591b7d44eecd667d
DIFF: https://github.com/llvm/llvm-project/commit/e8c2877565149587fd66fbee591b7d44eecd667d.diff
LOG: [mlir] Reuse the code between `getMixed*s()` funcs in ViewLikeInterface.cpp.
Differential Revision: https://reviews.llvm.org/D130706
Added:
Modified:
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index bde7f476bde50..fc152f31cf0b3 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -39,6 +39,19 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
SmallVectorImpl<int64_t> &staticVec,
int64_t sentinel);
+/// Return a vector of OpFoldResults given the special value
+/// that indicates whether of the value is dynamic or not.
+SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
+ ValueRange dynamicValues,
+ int64_t dynamicValueIndicator);
+
+/// Decompose a vector of mixed static or dynamic values into the corresponding
+/// pair of arrays. This is the inverse function of `getMixedValues`.
+std::pair<ArrayAttr, SmallVector<Value>>
+decomposeMixedValues(Builder &b,
+ const SmallVectorImpl<OpFoldResult> &mixedValues,
+ const int64_t dynamicValueIndicator);
+
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index 8be74c928c488..ed4ba5505c581 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -237,7 +237,30 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
return ::mlir::ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
}]
>,
-
+ StaticInterfaceMethod<
+ /*desc=*/"Return constant that indicates the offset is dynamic",
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getDynamicOffsetIndicator",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/"Return constant that indicates the size is dynamic",
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getDynamicSizeIndicator",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicSize; }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/"Return constant that indicates the stride is dynamic",
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getDynamicStrideIndicator",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
+ >,
InterfaceMethod<
/*desc=*/[{
Assert the offset `idx` is a static constant and return its value.
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 80e93553858fd..7771245e3fe2e 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
@@ -109,4 +110,40 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
return v1 && v1 == v2;
}
+
+/// Return a vector of OpFoldResults given the special value
+/// that indicates whether of the value is dynamic or not.
+SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
+ ValueRange dynamicValues,
+ int64_t dynamicValueIndicator) {
+ SmallVector<OpFoldResult, 4> res;
+ res.reserve(staticValues.size());
+ unsigned numDynamic = 0;
+ unsigned count = static_cast<unsigned>(staticValues.size());
+ for (unsigned idx = 0; idx < count; ++idx) {
+ APInt value = staticValues[idx].cast<IntegerAttr>().getValue();
+ res.push_back(value.getSExtValue() == dynamicValueIndicator
+ ? OpFoldResult{dynamicValues[numDynamic++]}
+ : OpFoldResult{staticValues[idx]});
+ }
+ return res;
+}
+
+std::pair<ArrayAttr, SmallVector<Value>>
+decomposeMixedValues(Builder &b,
+ const SmallVectorImpl<OpFoldResult> &mixedValues,
+ const int64_t dynamicValueIndicator) {
+ SmallVector<int64_t> staticValues;
+ SmallVector<Value> dynamicValues;
+ for (const auto &it : mixedValues) {
+ if (it.is<Attribute>()) {
+ staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
+ } else {
+ staticValues.push_back(dynamicValueIndicator);
+ dynamicValues.push_back(it.get<Value>());
+ }
+ }
+ return {b.getI64ArrayAttr(staticValues), dynamicValues};
+}
+
} // namespace mlir
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index dfeda72b38119..49331b5164682 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -182,72 +182,29 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
SmallVector<OpFoldResult, 4>
mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
ArrayAttr staticOffsets, ValueRange offsets) {
- SmallVector<OpFoldResult, 4> res;
- unsigned numDynamic = 0;
- unsigned count = static_cast<unsigned>(staticOffsets.size());
- for (unsigned idx = 0; idx < count; ++idx) {
- if (op.isDynamicOffset(idx))
- res.push_back(offsets[numDynamic++]);
- else
- res.push_back(staticOffsets[idx]);
- }
- return res;
+ return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator());
}
SmallVector<OpFoldResult, 4>
mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
ValueRange sizes) {
- SmallVector<OpFoldResult, 4> res;
- unsigned numDynamic = 0;
- unsigned count = static_cast<unsigned>(staticSizes.size());
- for (unsigned idx = 0; idx < count; ++idx) {
- if (op.isDynamicSize(idx))
- res.push_back(sizes[numDynamic++]);
- else
- res.push_back(staticSizes[idx]);
- }
- return res;
+ return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator());
}
SmallVector<OpFoldResult, 4>
mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
ArrayAttr staticStrides, ValueRange strides) {
- SmallVector<OpFoldResult, 4> res;
- unsigned numDynamic = 0;
- unsigned count = static_cast<unsigned>(staticStrides.size());
- for (unsigned idx = 0; idx < count; ++idx) {
- if (op.isDynamicStride(idx))
- res.push_back(strides[numDynamic++]);
- else
- res.push_back(staticStrides[idx]);
- }
- return res;
-}
-
-static std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedImpl(OpBuilder &b,
- const SmallVectorImpl<OpFoldResult> &mixedValues,
- const int64_t dynamicValuePlaceholder) {
- SmallVector<int64_t> staticValues;
- SmallVector<Value> dynamicValues;
- for (const auto &it : mixedValues) {
- if (it.is<Attribute>()) {
- staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
- } else {
- staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
- dynamicValues.push_back(it.get<Value>());
- }
- }
- return {b.getI64ArrayAttr(staticValues), dynamicValues};
+ return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator());
}
std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
- return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicStrideOrOffset);
+ return decomposeMixedValues(b, mixedValues,
+ ShapedType::kDynamicStrideOrOffset);
}
std::pair<ArrayAttr, SmallVector<Value>>
mlir::decomposeMixedSizes(OpBuilder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues) {
- return decomposeMixedImpl(b, mixedValues, ShapedType::kDynamicSize);
+ return decomposeMixedValues(b, mixedValues, ShapedType::kDynamicSize);
}
More information about the Mlir-commits
mailing list