[Mlir-commits] [mlir] 42263fb - [mlir][MemRef] Make reinterpret_cast(extract_strided_metadata) more robust
Quentin Colombet
llvmlistbot at llvm.org
Mon Nov 14 10:03:21 PST 2022
Author: Quentin Colombet
Date: 2022-11-14T18:02:15Z
New Revision: 42263fb52d627e19a5d00fe57dcd2fe607732ff9
URL: https://github.com/llvm/llvm-project/commit/42263fb52d627e19a5d00fe57dcd2fe607732ff9
DIFF: https://github.com/llvm/llvm-project/commit/42263fb52d627e19a5d00fe57dcd2fe607732ff9.diff
LOG: [mlir][MemRef] Make reinterpret_cast(extract_strided_metadata) more robust
Prior to this patch, the canonicalization pattern that turns
`reinterpret_cast(extract_strided_metadata)` into a cast was only applied
when all the input operands of the `reinterpret_cast` were exactly the
output results of the `extract_strided_metadata`.
This missed simplification opportunities where the values hold the same
constant values but come from different SSA values.
E.g., prior to this patch, a pattern of the form:
```
%base, %offset = extract_strided_metadata %source : memref<i16>
reinterpret_cast %base to offset:[0]
```
wouldn't have been simplified into a plain cast, because `%offset` is not
directly the same SSA value as the constant 0.
This patch teaches the pattern to check whether the constant values
match what the results of the `extract_strided_metadata` operation would
hold.
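The core idea can be sketched in plain C++. This is a minimal illustrative stand-in, not MLIR's actual API: `FoldResult`, `kDynamic`, and this simplified `constifyIndexValues` are hypothetical names modeling the mixed static/dynamic entries and the dynamic sentinel that the real code handles through `OpFoldResult` and `ShapedType` predicates:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for MLIR's OpFoldResult: each entry is either an
// already-known constant or an opaque dynamic value.
struct FoldResult {
  std::optional<int64_t> constant; // engaged when the value is statically known
};

// Stand-in for the "dynamic" sentinel carried by memref types (the real
// code consults ShapedType::isDynamic / isDynamicStrideOrOffset instead).
constexpr int64_t kDynamic = INT64_MIN;

// Sketch of the "constify" step: walk the entries and, wherever the memref
// type pins a position to a static value, replace the entry with that
// constant. After this pass, two lists holding the same constants compare
// equal even if the original SSA values differed.
void constifyIndexValues(std::vector<FoldResult> &values,
                         const std::vector<int64_t> &typeValues) {
  assert(values.size() == typeValues.size() &&
         "type info must cover every entry");
  for (size_t i = 0; i < values.size(); ++i)
    if (typeValues[i] != kDynamic)
      values[i].constant = typeValues[i];
}
```

For example, with a source type like `memref<2x8xi16>` the per-dimension type info would be `{2, 8}`, so both size entries get pinned to constants and a comparison against `reinterpret_cast`'s operands can succeed on values rather than on SSA identity.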
Differential Revision: https://reviews.llvm.org/D135736
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index cfc9d2a773087..b364f68ca2541 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -913,6 +913,23 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
$source `:` type($source) `->` type(results) attr-dict
}];
+ let extraClassDeclaration = [{
+ /// Return a vector of all the static or dynamic sizes of the op, while
+ /// statically inferring the values of the dynamic sizes, when possible.
+ /// This is best effort.
+ /// E.g., if `getSizes` returns `[%dyn_size0, %dyn_size1]`, but the
+ /// source memref type is `memref<2x8xi16>`, this method will
+ /// return `[2, 8]`.
+ /// Similarly if the resulting memref type is `memref<2x?xi16>`, but
+ /// `%dyn_size1` can statically be pinned to a constant value, this
+ /// constant value is returned instead of `%dyn_size1`.
+ SmallVector<OpFoldResult> getConstifiedMixedSizes();
+ /// Similar to `getConstifiedMixedSizes` but for strides.
+ SmallVector<OpFoldResult> getConstifiedMixedStrides();
+ /// Similar to `getConstifiedMixedSizes` but for the offset.
+ OpFoldResult getConstifiedMixedOffset();
+ }];
+
let hasFolder = 1;
}
@@ -1301,6 +1318,20 @@ def MemRef_ReinterpretCastOp
/// Return the number of leading operands before the `offsets`, `sizes` and
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+ /// Return a vector of all the static or dynamic sizes of the op, while
+ /// statically inferring the values of the dynamic sizes, when possible.
+ /// This is best effort.
+ /// E.g., if `getMixedSizes` returns `[2, %dyn_size]`, but the resulting
+ /// memref type is `memref<2x8xi16>`, this method will return `[2, 8]`.
+ /// Similarly if the resulting memref type is `memref<2x?xi16>`, but
+ /// `%dyn_size` can statically be pinned to a constant value, this
+ /// constant value is returned instead of `%dyn_size`.
+ SmallVector<OpFoldResult> getConstifiedMixedSizes();
+ /// Similar to `getConstifiedMixedSizes` but for strides.
+ SmallVector<OpFoldResult> getConstifiedMixedStrides();
+ /// Similar to `getConstifiedMixedSizes` but for the offset.
+ OpFoldResult getConstifiedMixedOffset();
}];
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index fb317374b4246..11aabd59af3bb 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -109,6 +109,105 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
+//===----------------------------------------------------------------------===//
+// Utility functions for propagating static information
+//===----------------------------------------------------------------------===//
+
+/// Helper function that infers the constant values from a list of \p values,
+/// a \p memRefTy, and another helper function \p getAttributes.
+/// The inferred constant values replace the related `OpFoldResult` in
+/// \p values.
+///
+/// \note This function shouldn't be used directly, instead, use the
+/// `getConstifiedMixedXXX` methods from the related operations.
+///
+/// \p getAttributes returns a list of potentially constant values, as determined
+/// by \p isDynamic, from the given \p memRefTy. The returned list must have as
+/// many elements as \p values or be empty.
+///
+/// E.g., consider the following example:
+/// ```
+/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
+/// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+/// ```
+/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
+/// Now using this helper function with:
+/// - `values == [2, %dyn_stride]`,
+/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
+/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
+/// `getStridesAndOffset`), and
+/// - `isDynamic == isDynamicStrideOrOffset`
+/// Will yield: `values == [2, 1]`
+static void constifyIndexValues(
+ SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
+ MLIRContext *ctxt,
+ llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
+ SmallVector<int64_t> constValues = getAttributes(memRefTy);
+ Builder builder(ctxt);
+ for (const auto &it : llvm::enumerate(constValues)) {
+ int64_t constValue = it.value();
+ if (!isDynamic(constValue))
+ values[it.index()] = builder.getIndexAttr(constValue);
+ }
+ for (OpFoldResult &ofr : values) {
+ if (ofr.is<Attribute>()) {
+ // FIXME: We shouldn't need to do that, but right now, the static indices
+ // are created with the wrong type: `i64` instead of `index`.
+ // As a result, if we were to keep the attribute as is, we may fail to see
+ // that two attributes are equal because one would have the i64 type and
+ // the other the index type.
+ // The alternative would be to create constant indices with getI64Attr in
+ // this and the previous loop, but it doesn't logically make sense (we are
+ // dealing with indices here) and would only strengthen the inconsistency
+ // around how static indices are created (some places use getI64Attr,
+ // others use getIndexAttr).
+ // The workaround here is to stick to the IndexAttr type for all the
+ // values, hence we recreate the attribute even when it is already static
+ // to make sure the type is consistent.
+ ofr = builder.getIndexAttr(
+ ofr.get<Attribute>().cast<IntegerAttr>().getInt());
+ continue;
+ }
+ Optional<int64_t> maybeConstant = getConstantIntValue(ofr.get<Value>());
+ if (maybeConstant)
+ ofr = builder.getIndexAttr(*maybeConstant);
+ }
+}
+
+/// Wrapper around `getShape` that conforms to the function signature
+/// expected for `getAttributes` in `constifyIndexValues`.
+static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
+ ArrayRef<int64_t> sizes = memRefTy.getShape();
+ return SmallVector<int64_t>(sizes.begin(), sizes.end());
+}
+
+/// Wrapper around `getStridesAndOffset` that returns only the offset and
+/// conforms to the function signature expected for `getAttributes` in
+/// `constifyIndexValues`.
+static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ LogicalResult hasStaticInformation =
+ getStridesAndOffset(memrefType, strides, offset);
+ if (failed(hasStaticInformation))
+ return SmallVector<int64_t>();
+ return SmallVector<int64_t>(1, offset);
+}
+
+/// Wrapper around `getStridesAndOffset` that returns only the strides and
+/// conforms to the function signature expected for `getAttributes` in
+/// `constifyIndexValues`.
+static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ LogicalResult hasStaticInformation =
+ getStridesAndOffset(memrefType, strides, offset);
+ if (failed(hasStaticInformation))
+ return SmallVector<int64_t>();
+ return strides;
+}
+
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
@@ -1293,18 +1392,22 @@ void ExtractStridedMetadataOp::getAsmResultNames(
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
Container values,
- ArrayRef<int64_t> maybeConstants,
- llvm::function_ref<bool(int64_t)> isDynamic) {
+ ArrayRef<OpFoldResult> maybeConstants) {
assert(values.size() == maybeConstants.size() &&
" expected values and maybeConstants of the same size");
bool atLeastOneReplacement = false;
for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
// Don't materialize a constant if there are no uses: this would induce
// infinite loops in the driver.
- if (isDynamic(maybeConstant) || result.use_empty())
+ if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
continue;
- Value constantVal =
- rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
+ assert(maybeConstant.template is<Attribute>() &&
+ "The constified value should be either unchanged (i.e., == result) "
+ "or a constant");
+ Value constantVal = rewriter.create<arith::ConstantIndexOp>(
+ loc, maybeConstant.template get<Attribute>()
+ .template cast<IntegerAttr>()
+ .getInt());
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
// updateRootInplace: lambda cannot capture structured bindings in C++17
// yet.
@@ -1319,26 +1422,41 @@ LogicalResult
ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
OpBuilder builder(*this);
- auto memrefType = getSource().getType().cast<MemRefType>();
- SmallVector<int64_t> strides;
- int64_t offset;
- LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
- (void)res;
- assert(succeeded(res) && "must be a strided memref type");
bool atLeastOneReplacement = replaceConstantUsesOf(
builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
- ArrayRef<int64_t>(offset), ShapedType::isDynamicStrideOrOffset);
- atLeastOneReplacement |=
- replaceConstantUsesOf(builder, getLoc(), getSizes(),
- memrefType.getShape(), ShapedType::isDynamic);
- atLeastOneReplacement |=
- replaceConstantUsesOf(builder, getLoc(), getStrides(), strides,
- ShapedType::isDynamicStrideOrOffset);
+ getConstifiedMixedOffset());
+ atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
+ getConstifiedMixedSizes());
+ atLeastOneReplacement |= replaceConstantUsesOf(
+ builder, getLoc(), getStrides(), getConstifiedMixedStrides());
return success(atLeastOneReplacement);
}
+SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
+ SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
+ constifyIndexValues(values, getSource().getType(), getContext(),
+ getConstantSizes, ShapedType::isDynamic);
+ return values;
+}
+
+SmallVector<OpFoldResult>
+ExtractStridedMetadataOp::getConstifiedMixedStrides() {
+ SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
+ constifyIndexValues(values, getSource().getType(), getContext(),
+ getConstantStrides, ShapedType::isDynamicStrideOrOffset);
+ return values;
+}
+
+OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
+ OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
+ SmallVector<OpFoldResult> values(1, offsetOfr);
+ constifyIndexValues(values, getSource().getType(), getContext(),
+ getConstantOffset, ShapedType::isDynamicStrideOrOffset);
+ return values[0];
+}
+
//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//
@@ -1781,8 +1899,67 @@ OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
return nullptr;
}
+SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
+ SmallVector<OpFoldResult> values = getMixedSizes();
+ constifyIndexValues(values, getType(), getContext(), getConstantSizes,
+ ShapedType::isDynamic);
+ return values;
+}
+
+SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
+ SmallVector<OpFoldResult> values = getMixedStrides();
+ constifyIndexValues(values, getType(), getContext(), getConstantStrides,
+ ShapedType::isDynamicStrideOrOffset);
+ return values;
+}
+
+OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
+ SmallVector<OpFoldResult> values = getMixedOffsets();
+ assert(values.size() == 1 &&
+ "reinterpret_cast must have one and only one offset");
+ constifyIndexValues(values, getType(), getContext(), getConstantOffset,
+ ShapedType::isDynamicStrideOrOffset);
+ return values[0];
+}
+
namespace {
-/// Replace reinterpret_cast(extract_strided_metadata memref) -> memref.
+/// Replace the sequence:
+/// ```
+/// base, offset, sizes, strides = extract_strided_metadata src
+/// dst = reinterpret_cast base to offset, sizes, strides
+/// ```
+/// With
+///
+/// ```
+/// dst = memref.cast src
+/// ```
+///
+/// Note: The cast operation is only inserted when the type of dst and src
+/// are not the same. E.g., when going from <4xf32> to <?xf32>.
+///
+/// This pattern also matches when the offset, sizes, and strides don't come
+/// directly from the `extract_strided_metadata`'s results but it can be
+/// statically proven that they would hold the same values.
+///
+/// For instance, the following sequence would be replaced:
+/// ```
+/// base, offset, sizes, strides =
+/// extract_strided_metadata memref : memref<3x4xty>
+/// dst = reinterpret_cast base to 0, [3, 4], strides
+/// ```
+/// Because we know (thanks to the type of the input memref) that variable
+/// `offset` and `sizes` will respectively hold 0 and [3, 4].
+///
+/// Similarly, the following sequence would be replaced:
+/// ```
+/// c0 = arith.constant 0
+/// c4 = arith.constant 4
+/// base, offset, sizes, strides =
+/// extract_strided_metadata memref : memref<3x4xty>
+/// dst = reinterpret_cast base to c0, [3, c4], strides
+/// ```
+/// Because we know that `offset` and `c0` will hold 0
+/// and `c4` will hold 4.
struct ReinterpretCastOpExtractStridedMetadataFolder
: public OpRewritePattern<ReinterpretCastOp> {
public:
@@ -1798,24 +1975,39 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
// properties as the extract strided metadata.
// First, check that the strides are the same.
- if (extractStridedMetadata.getStrides().size() != op.getStrides().size())
+ SmallVector<OpFoldResult> extractStridesOfr =
+ extractStridedMetadata.getConstifiedMixedStrides();
+ SmallVector<OpFoldResult> reinterpretStridesOfr =
+ op.getConstifiedMixedStrides();
+ if (extractStridesOfr.size() != reinterpretStridesOfr.size())
return failure();
- for (auto [extractStride, reinterpretStride] :
- llvm::zip(extractStridedMetadata.getStrides(), op.getStrides()))
- if (extractStride != reinterpretStride)
+
+ unsigned rank = op.getType().getRank();
+ for (unsigned i = 0; i < rank; ++i) {
+ if (extractStridesOfr[i] != reinterpretStridesOfr[i])
return failure();
+ }
// Second, check the sizes.
- if (extractStridedMetadata.getSizes().size() != op.getSizes().size())
- return failure();
- for (auto [extractSize, reinterpretSize] :
- llvm::zip(extractStridedMetadata.getSizes(), op.getSizes()))
- if (extractSize != reinterpretSize)
+ assert(extractStridedMetadata.getSizes().size() ==
+ op.getMixedSizes().size() &&
+ "Strides and sizes rank must match");
+ SmallVector<OpFoldResult> extractSizesOfr =
+ extractStridedMetadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> reinterpretSizesOfr =
+ op.getConstifiedMixedSizes();
+ for (unsigned i = 0; i < rank; ++i) {
+ if (extractSizesOfr[i] != reinterpretSizesOfr[i])
return failure();
-
+ }
// Finally, check the offset.
- if (op.getOffsets().size() != 1 &&
- extractStridedMetadata.getOffset() != *op.getOffsets().begin())
+ assert(op.getMixedOffsets().size() == 1 &&
+ "reinterpret_cast with more than one offset should have been "
+ "rejected by the verifier");
+ OpFoldResult extractOffsetOfr =
+ extractStridedMetadata.getConstifiedMixedOffset();
+ OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
+ if (extractOffsetOfr != reinterpretOffsetOfr)
return failure();
// At this point, we know that the back and forth between extract strided
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3fd4ae1c81c96..d34d0d95d0392 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -758,13 +758,9 @@ func.func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: in
// Check that a reinterpret cast of an equivalent extract strided metadata
// is canonicalized to a plain cast when the destination type is different
// than the type of the original memref.
-// This pattern is currently defeated by the constant folding that happens
-// with extract_strided_metadata.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mistach
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0
-// CHECK-DAG: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
-// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]]
+// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
// CHECK: return %[[CAST]]
func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -774,6 +770,23 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
// -----
+// Similar to reinterpret_of_extract_strided_metadata_w_type_mistach except that
+// we check that the match happens when the static information has been folded.
+// E.g., in this case, we know that the size of dim 0 is 8 and the size of dim 1 is 2.
+// So even if we don't use the values sizes#0, sizes#1, as long as they have the
+// same constant value, the match is valid.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
+// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
+// CHECK: return %[[CAST]]
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+ %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+ %c8 = arith.constant 8: index
+ %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+// -----
+
// Check that a reinterpret cast of an equivalent extract strided metadata
// is completely removed when the original memref has the same type.
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type