[Mlir-commits] [mlir] 42263fb - [mlir][MemRef] Make reinterpret_cast(extract_strided_metadata) more robust

Quentin Colombet llvmlistbot at llvm.org
Mon Nov 14 10:03:21 PST 2022


Author: Quentin Colombet
Date: 2022-11-14T18:02:15Z
New Revision: 42263fb52d627e19a5d00fe57dcd2fe607732ff9

URL: https://github.com/llvm/llvm-project/commit/42263fb52d627e19a5d00fe57dcd2fe607732ff9
DIFF: https://github.com/llvm/llvm-project/commit/42263fb52d627e19a5d00fe57dcd2fe607732ff9.diff

LOG: [mlir][MemRef] Make reinterpret_cast(extract_strided_metadata) more robust

Prior to this patch, the canonicalization pattern that turns
`reinterpret_cast(extract_strided_metadata)` into a cast was only applied
when the input operands of the `reinterpret_cast` were exactly the
output results of the `extract_strided_metadata`.

This missed simplification opportunities when the operands held the same
constant values as those results but came from different SSA values.

E.g., prior to this patch, a pattern of the form:
```
%base, %offset = extract_strided_metadata %source : memref<i16>
reinterpret_cast %base to offset:[0]
```
would not have been simplified into a simple cast, because the static
offset 0 is not the same value object as %offset.

This patch teaches the pattern to check whether the constant values
match what the results of the `extract_strided_metadata` operation would
hold.
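
As a rough illustration of the idea (not the MLIR implementation), the
sketch below models "constify, then compare" with toy types. Every name
in it (`ToyFoldResult`, `ToyConstants`, `constify`,
`commitMessageExampleMatches`) is hypothetical and stands in loosely for
`OpFoldResult`, `getConstantIntValue`, and the `getConstifiedMixedXXX`
helpers added by this patch; a `std::nullopt` entry plays the role of a
dynamic dimension in the memref type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <vector>

// Toy stand-in for OpFoldResult: either a static constant or the id of an
// SSA value. Hypothetical type, not part of MLIR.
struct ToyFoldResult {
  bool isStatic;
  int64_t payload; // The constant if isStatic, otherwise a toy value id.

  static ToyFoldResult attr(int64_t c) { return {true, c}; }
  static ToyFoldResult value(int64_t id) { return {false, id}; }

  bool operator==(const ToyFoldResult &other) const {
    return isStatic == other.isStatic && payload == other.payload;
  }
};

// Toy stand-in for "this SSA value is defined by a constant op":
// maps a value id to its constant.
using ToyConstants = std::map<int64_t, int64_t>;

// Best-effort replacement of entries by constants, using first the static
// information of the type (std::nullopt means dynamic), then the constants
// known for the remaining SSA values.
std::vector<ToyFoldResult>
constify(std::vector<ToyFoldResult> values,
         const std::vector<std::optional<int64_t>> &fromType,
         const ToyConstants &constants) {
  assert((fromType.empty() || fromType.size() == values.size()) &&
         "type information must be empty or match the number of values");
  for (std::size_t i = 0; i < fromType.size(); ++i)
    if (fromType[i].has_value())
      values[i] = ToyFoldResult::attr(*fromType[i]);
  for (ToyFoldResult &v : values) {
    if (v.isStatic)
      continue;
    auto it = constants.find(v.payload);
    if (it != constants.end())
      v = ToyFoldResult::attr(it->second);
  }
  return values;
}

// Models the example above: %offset (toy id 1) comes from a source whose
// type pins the offset to 0, while the reinterpret_cast uses the static
// offset 0 and has a dynamic offset in its result type.
bool commitMessageExampleMatches() {
  const ToyConstants noConstants;
  const std::vector<ToyFoldResult> extractRaw = {ToyFoldResult::value(1)};
  const std::vector<ToyFoldResult> reinterpretRaw = {ToyFoldResult::attr(0)};
  const std::vector<std::optional<int64_t>> sourceType = {0};
  const std::vector<std::optional<int64_t>> resultType = {std::nullopt};

  // The old check compared the raw entries, which differ here.
  const bool rawMatch = extractRaw == reinterpretRaw;
  // The new check compares the constified entries, which agree.
  const bool constifiedMatch =
      constify(extractRaw, sourceType, noConstants) ==
      constify(reinterpretRaw, resultType, noConstants);
  return !rawMatch && constifiedMatch;
}
```

In this toy model, `commitMessageExampleMatches()` returns `true`: the
raw comparison fails while the constified comparison succeeds, which is
the situation the patch is meant to handle.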

Differential Revision: https://reviews.llvm.org/D135736

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index cfc9d2a773087..b364f68ca2541 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -913,6 +913,23 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
     $source `:` type($source) `->` type(results) attr-dict
   }];
 
+  let extraClassDeclaration = [{
+    /// Return a vector of all the static or dynamic sizes of the op, while
+    /// statically inferring the sizes of the dynamic sizes, when possible.
+    /// This is best effort.
+    /// E.g., if `getSizes` returns `[%dyn_size0, %dyn_size1]`, but the
+    /// source memref type is `memref<2x8xi16>`, this method will
+    /// return `[2, 8]`.
+    /// Similarly if the resulting memref type is `memref<2x?xi16>`, but
+    /// `%dyn_size1` can statically be pinned to a constant value, this
+    /// constant value is returned instead of `%dyn_size`.
+    SmallVector<OpFoldResult> getConstifiedMixedSizes();
+    /// Similar to `getConstifiedMixedSizes` but for strides.
+    SmallVector<OpFoldResult> getConstifiedMixedStrides();
+    /// Similar to `getConstifiedMixedSizes` but for the offset.
+    OpFoldResult getConstifiedMixedOffset();
+  }];
+
   let hasFolder = 1;
 }
 
@@ -1301,6 +1318,20 @@ def MemRef_ReinterpretCastOp
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return a vector of all the static or dynamic sizes of the op, while
+    /// statically inferring the sizes of the dynamic sizes, when possible.
+    /// This is best effort.
+    /// E.g., if `getMixedSizes` returns `[2, %dyn_size]`, but the resulting
+    /// memref type is `memref<2x8xi16>`, this method will return `[2, 8]`.
+    /// Similarly if the resulting memref type is `memref<2x?xi16>`, but
+    /// `%dyn_size` can statically be pinned to a constant value, this
+    /// constant value is returned instead of `%dyn_size`.
+    SmallVector<OpFoldResult> getConstifiedMixedSizes();
+    /// Similar to `getConstifiedMixedSizes` but for strides.
+    SmallVector<OpFoldResult> getConstifiedMixedStrides();
+    /// Similar to `getConstifiedMixedSizes` but for the offset.
+    OpFoldResult getConstifiedMixedOffset();
   }];
 
   let hasFolder = 1;

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index fb317374b4246..11aabd59af3bb 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -109,6 +109,105 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
   return NoneType::get(type.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// Utility functions for propagating static information
+//===----------------------------------------------------------------------===//
+
+/// Helper function that infers the constant values from a list of \p values,
+/// a \p memRefTy, and another helper function \p getAttributes.
+/// The inferred constant values replace the related `OpFoldResult` in
+/// \p values.
+///
+/// \note This function shouldn't be used directly, instead, use the
+/// `getConstifiedMixedXXX` methods from the related operations.
+///
+/// \p getAttributes retuns a list of potentially constant values, as determined
+/// by \p isDynamic, from the given \p memRefTy. The returned list must have as
+/// many elements as \p values or be empty.
+///
+/// E.g., consider the following example:
+/// ```
+/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
+///     memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+/// ```
+/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
+/// Now using this helper function with:
+/// - `values == [2, %dyn_stride]`,
+/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
+/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
+/// `getStridesAndOffset`), and
+/// - `isDynamic == isDynamicStrideOrOffset`
+/// Will yield: `values == [2, 1]`
+static void constifyIndexValues(
+    SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
+    MLIRContext *ctxt,
+    llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
+    llvm::function_ref<bool(int64_t)> isDynamic) {
+  SmallVector<int64_t> constValues = getAttributes(memRefTy);
+  Builder builder(ctxt);
+  for (const auto &it : llvm::enumerate(constValues)) {
+    int64_t constValue = it.value();
+    if (!isDynamic(constValue))
+      values[it.index()] = builder.getIndexAttr(constValue);
+  }
+  for (OpFoldResult &ofr : values) {
+    if (ofr.is<Attribute>()) {
+      // FIXME: We shouldn't need to do that, but right now, the static indices
+      // are created with the wrong type: `i64` instead of `index`.
+      // As a result, if we were to keep the attribute as is, we may fail to see
+      // that two attributes are equal because one would have the i64 type and
+      // the other the index type.
+      // The alternative would be to create constant indices with getI64Attr in
+      // this and the previous loop, but it doesn't logically make sense (we are
+      // dealing with indices here) and would only strenghten the inconsistency
+      // around how static indices are created (some places use getI64Attr,
+      // others use getIndexAttr).
+      // The workaround here is to stick to the IndexAttr type for all the
+      // values, hence we recreate the attribute even when it is already static
+      // to make sure the type is consistent.
+      ofr = builder.getIndexAttr(
+          ofr.get<Attribute>().cast<IntegerAttr>().getInt());
+      continue;
+    }
+    Optional<int64_t> maybeConstant = getConstantIntValue(ofr.get<Value>());
+    if (maybeConstant)
+      ofr = builder.getIndexAttr(*maybeConstant);
+  }
+}
+
+/// Wrapper around `getShape` that conforms to the function signature
+/// expected for `getAttributes` in `constifyIndexValues`.
+static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
+  ArrayRef<int64_t> sizes = memRefTy.getShape();
+  return SmallVector<int64_t>(sizes.begin(), sizes.end());
+}
+
+/// Wrapper around `getStridesAndOffset` that returns only the offset and
+/// conforms to the function signature expected for `getAttributes` in
+/// `constifyIndexValues`.
+static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  LogicalResult hasStaticInformation =
+      getStridesAndOffset(memrefType, strides, offset);
+  if (failed(hasStaticInformation))
+    return SmallVector<int64_t>();
+  return SmallVector<int64_t>(1, offset);
+}
+
+/// Wrapper around `getStridesAndOffset` that returns only the strides and
+/// conforms to the function signature expected for `getAttributes` in
+/// `constifyIndexValues`.
+static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  LogicalResult hasStaticInformation =
+      getStridesAndOffset(memrefType, strides, offset);
+  if (failed(hasStaticInformation))
+    return SmallVector<int64_t>();
+  return strides;
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -1293,18 +1392,22 @@ void ExtractStridedMetadataOp::getAsmResultNames(
 template <typename Container>
 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                   Container values,
-                                  ArrayRef<int64_t> maybeConstants,
-                                  llvm::function_ref<bool(int64_t)> isDynamic) {
+                                  ArrayRef<OpFoldResult> maybeConstants) {
   assert(values.size() == maybeConstants.size() &&
          " expected values and maybeConstants of the same size");
   bool atLeastOneReplacement = false;
   for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
     // Don't materialize a constant if there are no uses: this would indice
     // infinite loops in the driver.
-    if (isDynamic(maybeConstant) || result.use_empty())
+    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
       continue;
-    Value constantVal =
-        rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
+    assert(maybeConstant.template is<Attribute>() &&
+           "The constified value should be either unchanged (i.e., == result) "
+           "or a constant");
+    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
+        loc, maybeConstant.template get<Attribute>()
+                 .template cast<IntegerAttr>()
+                 .getInt());
     for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
       // updateRootInplace: lambda cannot capture structured bindings in C++17
       // yet.
@@ -1319,26 +1422,41 @@ LogicalResult
 ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
                                SmallVectorImpl<OpFoldResult> &results) {
   OpBuilder builder(*this);
-  auto memrefType = getSource().getType().cast<MemRefType>();
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
-  (void)res;
-  assert(succeeded(res) && "must be a strided memref type");
 
   bool atLeastOneReplacement = replaceConstantUsesOf(
       builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
-      ArrayRef<int64_t>(offset), ShapedType::isDynamicStrideOrOffset);
-  atLeastOneReplacement |=
-      replaceConstantUsesOf(builder, getLoc(), getSizes(),
-                            memrefType.getShape(), ShapedType::isDynamic);
-  atLeastOneReplacement |=
-      replaceConstantUsesOf(builder, getLoc(), getStrides(), strides,
-                            ShapedType::isDynamicStrideOrOffset);
+      getConstifiedMixedOffset());
+  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
+                                                 getConstifiedMixedSizes());
+  atLeastOneReplacement |= replaceConstantUsesOf(
+      builder, getLoc(), getStrides(), getConstifiedMixedStrides());
 
   return success(atLeastOneReplacement);
 }
 
+SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
+  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
+  constifyIndexValues(values, getSource().getType(), getContext(),
+                      getConstantSizes, ShapedType::isDynamic);
+  return values;
+}
+
+SmallVector<OpFoldResult>
+ExtractStridedMetadataOp::getConstifiedMixedStrides() {
+  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
+  constifyIndexValues(values, getSource().getType(), getContext(),
+                      getConstantStrides, ShapedType::isDynamicStrideOrOffset);
+  return values;
+}
+
+OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
+  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
+  SmallVector<OpFoldResult> values(1, offsetOfr);
+  constifyIndexValues(values, getSource().getType(), getContext(),
+                      getConstantOffset, ShapedType::isDynamicStrideOrOffset);
+  return values[0];
+}
+
 //===----------------------------------------------------------------------===//
 // GenericAtomicRMWOp
 //===----------------------------------------------------------------------===//
@@ -1781,8 +1899,67 @@ OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
   return nullptr;
 }
 
+SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
+  SmallVector<OpFoldResult> values = getMixedSizes();
+  constifyIndexValues(values, getType(), getContext(), getConstantSizes,
+                      ShapedType::isDynamic);
+  return values;
+}
+
+SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
+  SmallVector<OpFoldResult> values = getMixedStrides();
+  constifyIndexValues(values, getType(), getContext(), getConstantStrides,
+                      ShapedType::isDynamicStrideOrOffset);
+  return values;
+}
+
+OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
+  SmallVector<OpFoldResult> values = getMixedOffsets();
+  assert(values.size() == 1 &&
+         "reinterpret_cast must have one and only one offset");
+  constifyIndexValues(values, getType(), getContext(), getConstantOffset,
+                      ShapedType::isDynamicStrideOrOffset);
+  return values[0];
+}
+
 namespace {
-/// Replace reinterpret_cast(extract_strided_metadata memref) -> memref.
+/// Replace the sequence:
+/// ```
+/// base, offset, sizes, strides = extract_strided_metadata src
+/// dst = reinterpret_cast base to offset, sizes, strides
+/// ```
+/// With
+///
+/// ```
+/// dst = memref.cast src
+/// ```
+///
+/// Note: The cast operation is only inserted when the type of dst and src
+/// are not the same. E.g., when going from <4xf32> to <?xf32>.
+///
+/// This pattern also matches when the offset, sizes, and strides don't come
+/// directly from the `extract_strided_metadata`'s results but it can be
+/// statically proven that they would hold the same values.
+///
+/// For instance, the following sequence would be replaced:
+/// ```
+/// base, offset, sizes, strides =
+///   extract_strided_metadata memref : memref<3x4xty>
+/// dst = reinterpret_cast base to 0, [3, 4], strides
+/// ```
+/// Because we know (thanks to the type of the input memref) that variable
+/// `offset` and `sizes` will respectively hold 0 and [3, 4].
+///
+/// Similarly, the following sequence would be replaced:
+/// ```
+/// c0 = arith.constant 0
+/// c4 = arith.constant 4
+/// base, offset, sizes, strides =
+///   extract_strided_metadata memref : memref<3x4xty>
+/// dst = reinterpret_cast base to c0, [3, c4], strides
+/// ```
+/// Because we know that `offset`and `c0` will hold 0
+/// and `c4` will hold 4.
 struct ReinterpretCastOpExtractStridedMetadataFolder
     : public OpRewritePattern<ReinterpretCastOp> {
 public:
@@ -1798,24 +1975,39 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
     // properties as the extract strided metadata.
 
     // First, check that the strides are the same.
-    if (extractStridedMetadata.getStrides().size() != op.getStrides().size())
+    SmallVector<OpFoldResult> extractStridesOfr =
+        extractStridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> reinterpretStridesOfr =
+        op.getConstifiedMixedStrides();
+    if (extractStridesOfr.size() != reinterpretStridesOfr.size())
       return failure();
-    for (auto [extractStride, reinterpretStride] :
-         llvm::zip(extractStridedMetadata.getStrides(), op.getStrides()))
-      if (extractStride != reinterpretStride)
+
+    unsigned rank = op.getType().getRank();
+    for (unsigned i = 0; i < rank; ++i) {
+      if (extractStridesOfr[i] != reinterpretStridesOfr[i])
         return failure();
+    }
 
     // Second, check the sizes.
-    if (extractStridedMetadata.getSizes().size() != op.getSizes().size())
-      return failure();
-    for (auto [extractSize, reinterpretSize] :
-         llvm::zip(extractStridedMetadata.getSizes(), op.getSizes()))
-      if (extractSize != reinterpretSize)
+    assert(extractStridedMetadata.getSizes().size() ==
+               op.getMixedSizes().size() &&
+           "Strides and sizes rank must match");
+    SmallVector<OpFoldResult> extractSizesOfr =
+        extractStridedMetadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> reinterpretSizesOfr =
+        op.getConstifiedMixedSizes();
+    for (unsigned i = 0; i < rank; ++i) {
+      if (extractSizesOfr[i] != reinterpretSizesOfr[i])
         return failure();
-
+    }
     // Finally, check the offset.
-    if (op.getOffsets().size() != 1 &&
-        extractStridedMetadata.getOffset() != *op.getOffsets().begin())
+    assert(op.getMixedOffsets().size() == 1 &&
+           "reinterpret_cast with more than one offset should have been "
+           "rejected by the verifier");
+    OpFoldResult extractOffsetOfr =
+        extractStridedMetadata.getConstifiedMixedOffset();
+    OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
+    if (extractOffsetOfr != reinterpretOffsetOfr)
       return failure();
 
     // At this point, we know that the back and forth between extract strided

diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3fd4ae1c81c96..d34d0d95d0392 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -758,13 +758,9 @@ func.func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: in
 // Check that a reinterpret cast of an equivalent extract strided metadata
 // is canonicalized to a plain cast when the destination type is different
 // than the type of the original memref.
-// This pattern is currently defeated by the constant folding that happens
-// with extract_strided_metadata.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mistach
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0
-//   CHECK-DAG: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]]
+//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
 //       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -774,6 +770,23 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
 
 // -----
 
+// Similar to reinterpret_of_extract_strided_metadata_w_type_mistach except that
+// we check that the match happen when the static information has been folded.
+// E.g., in this case, we know that size of dim 0 is 8 and size of dim 1 is 2.
+// So even if we don't use the values sizes#0, sizes#1, as long as they have the
+// same constant value, the match is valid.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
+//       CHECK: return %[[CAST]]
+func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+  %c8 = arith.constant 8: index
+  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+// -----
+
 // Check that a reinterpret cast of an equivalent extract strided metadata
 // is completely removed when the original memref has the same type.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type


        

