[Mlir-commits] [mlir] [llvm] [mlir][Interfaces] LISH: Add helpers for hyperrectangular subsets (PR #70628)

Matthias Springer llvmlistbot at llvm.org
Tue Oct 31 19:22:29 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70628

>From 7566fe178f4f53e4fb019b63be861984d830e210 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 1 Nov 2023 11:17:44 +0900
Subject: [PATCH] [mlir][Interfaces] `SubsetOpInterface`: Add helpers for
 hyperrectangular subsets

---
 .../Bufferization/IR/BufferizationOps.td      |   3 +-
 mlir/include/mlir/IR/OpDefinition.h           |   5 +
 .../mlir/Interfaces/SubsetOpInterface.h       |  16 ++-
 .../mlir/Interfaces/SubsetOpInterface.td      |  53 +++++++--
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  53 ++++++++-
 .../SubsetInsertionOpInterfaceImpl.cpp        |  82 +------------
 mlir/lib/Interfaces/CMakeLists.txt            |   2 +
 mlir/lib/Interfaces/SubsetOpInterface.cpp     |  49 ++++++++
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 109 ++++++++++++++++--
 .../loop-invariant-subset-hoisting.mlir       |   9 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 11 files changed, 285 insertions(+), 97 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index e6b6d052df96a8c..9dc6afcaab31c86 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -220,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
          AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
-         DeclareOpInterfaceMethods<SubsetOpInterface>,
+         DeclareOpInterfaceMethods<SubsetOpInterface,
+            ["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
          DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
             ["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
              "buildSubsetExtraction", "isEquivalentSubset"]>,
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 8ab37c1d51d6b6c..bd68c27445744e3 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -268,6 +268,11 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
 
 public:
   void dump() const { llvm::errs() << *this << "\n"; }
+
+  MLIRContext *getContext() const {
+    return is<Attribute>() ? get<Attribute>().getContext()
+                           : get<Value>().getContext();
+  }
 };
 
 // Temporarily exit the MLIR namespace to add casting support as later code in
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.h b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
index 049cf2456a9c842..98c33ec65012fca 100644
--- a/mlir/include/mlir/Interfaces/SubsetOpInterface.h
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.h
@@ -10,6 +10,7 @@
 #define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 namespace mlir {
 class SubsetOpInterface;
@@ -27,10 +28,23 @@ OpOperand &defaultGetDestinationOperand(Operation *op);
 /// `DestinationStyleOpInterface`.
 OpResult defaultGetUpdatedDestination(Operation *op);
 
-/// Default implementation of `isEquivalentSubset`.
+/// Default implementation of `SubsetInsertionOpInterface::isEquivalentSubset`.
 bool defaultIsEquivalentSubset(Operation *op, Value candidate,
                                function_ref<bool(Value, Value)> equivalenceFn);
 
+/// Default implementation of `SubsetOpInterface::operatesOnEquivalentSubset`.
+bool defaultOperatesOnEquivalentSubset(
+    Operation *op, SubsetOpInterface candidate,
+    function_ref<bool(Value, Value)> equivalenceFn);
+
+/// Default implementation of `SubsetOpInterface::operatesOnDisjointSubset`.
+bool defaultOperatesOnDisjointSubset(
+    Operation *op, SubsetOpInterface candidate,
+    function_ref<bool(Value, Value)> equivalenceFn);
+
+/// Return the container that the given subset op is operating on.
+Value getTensorContainer(Operation *op);
+
 /// Verify `SubsetOpInterface`.
 LogicalResult verifySubsetOpInterface(SubsetOpInterface op);
 
diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.td b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
index 9ebed2c94818d19..7000e7dfc89cdbe 100644
--- a/mlir/include/mlir/Interfaces/SubsetOpInterface.td
+++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.td
@@ -32,11 +32,6 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
       hyperrectangular slice.
     - `tensor.gather/scatter` describe the subset as list of indices. (Not
       implemented yet.)
-
-    Note: This interface does not expose any interface methods to get a
-    description of the accessed subset. That is because there is currently no
-    efficient way to describe arbitrary subsets. This interface merely provides
-    interface methods to check if two subsets are equivalent or disjoint.
   }];
 
   let cppNamespace = "::mlir";
@@ -46,24 +41,59 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
           Return "true" if this op and the given candidate subset op operate on
           equivalent subsets. Return "false" if the two subsets are disjoint
           or cannot be proven to be equivalent.
+
+          This interface method does not have to be implemented if
+          `getAccessedHyperrectangularSlice` is implemented.
         }],
         /*retType=*/"bool",
         /*methodName=*/"operatesOnEquivalentSubset",
         /*args=*/(ins
             "::mlir::SubsetOpInterface":$candidate,
-            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::detail::defaultOperatesOnEquivalentSubset(
+              $_op, candidate, equivalenceFn);
+        }]
       >,
       InterfaceMethod<
         /*desc=*/[{
           Return "true" if this op and the given candidate subset op operate on
           disjoint subsets. Return "false" if the two subsets are equivalent,
           overlapping or cannot be proven to be disjoint.
+
+          This interface method does not have to be implemented if
+          `getAccessedHyperrectangularSlice` is implemented.
         }],
         /*retType=*/"bool",
         /*methodName=*/"operatesOnDisjointSubset",
         /*args=*/(ins
             "::mlir::SubsetOpInterface":$candidate,
-            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
+            "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::detail::defaultOperatesOnDisjointSubset(
+              $_op, candidate, equivalenceFn);
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          If this op operates on a hyperrectangular subset, return a
+          description of the subset in terms of offsets, sizes and strides.
+          Otherwise, return "failure".
+
+          This interface method is a convenience method for the most common case
+          of hyperrectangular subset ops. It is optional. If it is implemented,
+          `operatesOnEquivalentSubset` and `operatesOnDisjointSubset` do not
+          have to be implemented.
+        }],
+        /*retType=*/"::mlir::FailureOr<::mlir::HyperrectangularSlice>",
+        /*methodName=*/"getAccessedHyperrectangularSlice",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return ::mlir::failure();
+        }]
       >,
   ];
 
@@ -71,6 +101,15 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
     return ::mlir::detail::verifySubsetOpInterface(
         ::mlir::cast<::mlir::SubsetOpInterface>($_op));
   }];
+
+  let extraClassDeclaration = [{
+    /// Return the container that this operation is operating on. In case of an
+    /// extraction op, the container is the source tensor. In case of an
+    /// insertion op, the container is the destination tensor.
+    Value getTensorContainer() {
+      return ::mlir::detail::getTensorContainer(getOperation());
+    }
+  }];
 }
 
 def SubsetExtractionOpInterface
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 8e2986a2d1f05f6..28dadfb9ecf8688 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -21,6 +21,31 @@
 namespace mlir {
 class OffsetSizeAndStrideOpInterface;
 
+/// A hyperrectangular slice, represented as a list of offsets, sizes and
+/// strides.
+class HyperrectangularSlice {
+public:
+  HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        ArrayRef<OpFoldResult> strides);
+
+  /// Create a hyperrectangular slice with unit strides.
+  HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes);
+
+  /// Infer a hyperrectangular slice from `OffsetSizeAndStrideOpInterface`.
+  HyperrectangularSlice(OffsetSizeAndStrideOpInterface op);
+
+  ArrayRef<OpFoldResult> getMixedOffsets() const { return mixedOffsets; }
+  ArrayRef<OpFoldResult> getMixedSizes() const { return mixedSizes; }
+  ArrayRef<OpFoldResult> getMixedStrides() const { return mixedStrides; }
+
+private:
+  SmallVector<OpFoldResult> mixedOffsets;
+  SmallVector<OpFoldResult> mixedSizes;
+  SmallVector<OpFoldResult> mixedStrides;
+};
+
 using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
 
 /// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
@@ -182,12 +207,34 @@ class ValueBoundsConstraintSet {
                                   std::optional<int64_t> dim1 = std::nullopt,
                                   std::optional<int64_t> dim2 = std::nullopt);
 
+  /// Compute whether the given values/attributes are equal. Return "failure" if
+  /// equality could not be determined.
+  ///
+  /// `ofr1`/`ofr2` must be of index type.
+  static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
+
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
   /// Return "failure" if unknown.
-  static FailureOr<bool>
-  areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1,
-                       OffsetSizeAndStrideOpInterface slice2);
+  ///
+  /// Slices are overlapping if for all dimensions:
+  /// *      offset1 + size1 * stride1 > offset2
+  /// * and  offset2 + size2 * stride2 > offset1
+  ///
+  /// Slices are non-overlapping if the above constraint is not satisfied for
+  /// at least one dimension.
+  static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx,
+                                              HyperrectangularSlice slice1,
+                                              HyperrectangularSlice slice2);
+
+  /// Return "true" if the given slices are guaranteed to be equivalent.
+  /// Return "false" if the given slices are guaranteed to be non-equivalent.
+  /// Return "failure" if unknown.
+  ///
+  /// Slices are equivalent if their offsets, sizes and strides are equal.
+  static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx,
+                                             HyperrectangularSlice slice1,
+                                             HyperrectangularSlice slice2);
 
   /// Add a bound for the given index-typed value or shaped value. This function
   /// returns a builder that adds the bound.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
index 7a1bafd409eea60..d50d7c62b789c8c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp
@@ -17,73 +17,12 @@ using namespace mlir::tensor;
 
 namespace {
 
-/// Return the tensor that the given subset op operates on.
-Value getContainerOperand(SubsetOpInterface op) {
-  if (auto extractionOp =
-          dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
-    return extractionOp.getSourceOperand().get();
-  if (auto insertionOp =
-          dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
-    return insertionOp.getDestinationOperand().get();
-  llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
-}
-
-/// Return "true" if the two ops operate on an equivalent subset.
-/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
-/// if the two ops operate non-equivalent subsets, if equivalence cannot be
-/// determined or if `op1` is not a subset op.
-template <typename OpTy>
-bool operateOnEquivalentSubsets(
-    OpTy op1, SubsetOpInterface op2,
-    function_ref<bool(Value, Value)> equivalenceFn) {
-  auto offsetsSizesAndStrides2 =
-      dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
-  if (!offsetsSizesAndStrides2)
-    return false;
-  if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
-                                  isEqualConstantIntOrValue))
-    return false;
-  return equivalenceFn(
-      getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
-      getContainerOperand(op2));
-}
-
-/// Return "true" if the two ops operate on a disjoint subsets.
-/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
-/// if the two ops operate non-disjoint subsets, if disjointness cannot be
-/// determined or if `op1` is not a subset op.
-template <typename OpTy>
-bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
-                              function_ref<bool(Value, Value)> equivalenceFn) {
-  auto offsetsSizesAndStrides2 =
-      dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
-  if (!offsetsSizesAndStrides2)
-    return false;
-  FailureOr<bool> overlappingSlices =
-      ValueBoundsConstraintSet::areOverlappingSlices(op1,
-                                                     offsetsSizesAndStrides2);
-  if (failed(overlappingSlices) || *overlappingSlices)
-    return false;
-  return equivalenceFn(
-      getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
-      getContainerOperand(op2));
-}
-
 struct ExtractSliceOpSubsetOpInterface
     : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
                                               tensor::ExtractSliceOp> {
-  bool operatesOnEquivalentSubset(
-      Operation *op, SubsetOpInterface candidate,
-      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
-    return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn);
-  }
-
-  bool operatesOnDisjointSubset(
-      Operation *op, SubsetOpInterface candidate,
-      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
-    return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn);
+  FailureOr<HyperrectangularSlice>
+  getAccessedHyperrectangularSlice(Operation *op) const {
+    return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
   }
 };
 
@@ -99,18 +38,9 @@ template <typename OpTy>
 struct InsertSliceLikeOpSubsetOpInterface
     : public SubsetOpInterface::ExternalModel<
           InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
-  bool operatesOnEquivalentSubset(
-      Operation *op, SubsetOpInterface candidate,
-      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<OpTy>(op);
-    return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn);
-  }
-
-  bool operatesOnDisjointSubset(
-      Operation *op, SubsetOpInterface candidate,
-      function_ref<bool(Value, Value)> equivalenceFn) const {
-    auto insertSliceOp = cast<OpTy>(op);
-    return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn);
+  FailureOr<HyperrectangularSlice>
+  getAccessedHyperrectangularSlice(Operation *op) const {
+    return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
   }
 };
 
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 2652d261f480ba4..e7c76e70ed6b5d7 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -93,10 +93,12 @@ add_mlir_library(MLIRSubsetOpInterface
   DEPENDS
   MLIRDestinationStyleOpInterface
   MLIRSubsetOpInterfaceIncGen
+  MLIRValueBoundsOpInterface
 
   LINK_LIBS PUBLIC
   MLIRDestinationStyleOpInterface
   MLIRIR
+  MLIRValueBoundsOpInterface
   )
 
 add_mlir_interface_library(TilingInterface)
diff --git a/mlir/lib/Interfaces/SubsetOpInterface.cpp b/mlir/lib/Interfaces/SubsetOpInterface.cpp
index 7245ab20c499e20..d0bdadf500f6f6c 100644
--- a/mlir/lib/Interfaces/SubsetOpInterface.cpp
+++ b/mlir/lib/Interfaces/SubsetOpInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 #include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
 
@@ -40,6 +41,54 @@ bool detail::defaultIsEquivalentSubset(
       candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
 }
 
+bool detail::defaultOperatesOnEquivalentSubset(
+    Operation *op, SubsetOpInterface candidate,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  auto subsetOp = cast<SubsetOpInterface>(op);
+  FailureOr<HyperrectangularSlice> slice =
+      subsetOp.getAccessedHyperrectangularSlice();
+  assert(succeeded(slice) &&
+         "operatesOnEquivalentSubset must be implemented if "
+         "getAccessedHyperrectangularSlice is not implemented");
+  FailureOr<HyperrectangularSlice> otherSlice =
+      candidate.getAccessedHyperrectangularSlice();
+  if (failed(otherSlice))
+    return false;
+  if (!equivalenceFn(subsetOp.getTensorContainer(),
+                     candidate.getTensorContainer()))
+    return false;
+  FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
+      op->getContext(), *slice, *otherSlice);
+  return succeeded(equivalent) && *equivalent;
+}
+
+bool detail::defaultOperatesOnDisjointSubset(
+    Operation *op, SubsetOpInterface candidate,
+    function_ref<bool(Value, Value)> equivalenceFn) {
+  auto subsetOp = cast<SubsetOpInterface>(op);
+  FailureOr<HyperrectangularSlice> slice =
+      subsetOp.getAccessedHyperrectangularSlice();
+  assert(succeeded(slice) &&
+         "operatesOnDisjointSubset must be implemented if "
+         "getAccessedHyperrectangularSlice is not implemented");
+  FailureOr<HyperrectangularSlice> otherSlice =
+      candidate.getAccessedHyperrectangularSlice();
+  if (failed(otherSlice))
+    return false;
+  if (!equivalenceFn(subsetOp.getTensorContainer(),
+                     candidate.getTensorContainer()))
+    return false;
+  FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
+      op->getContext(), *slice, *otherSlice);
+  return succeeded(overlapping) && !*overlapping;
+}
+
+Value detail::getTensorContainer(Operation *op) {
+  if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
+    return insertionOp.getDestinationOperand().get();
+  return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
+}
+
 LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
   if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
         isa<SubsetInsertionOpInterface>(op.getOperation())))
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index f0c37c872e6d31d..62ba63402925e01 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,6 +25,32 @@ namespace mlir {
 #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
 } // namespace mlir
 
+HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+                                             ArrayRef<OpFoldResult> sizes,
+                                             ArrayRef<OpFoldResult> strides)
+    : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
+  assert(offsets.size() == sizes.size() &&
+         "expected same number of offsets, sizes, strides");
+  assert(offsets.size() == strides.size() &&
+         "expected same number of offsets, sizes, strides");
+}
+
+HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
+                                             ArrayRef<OpFoldResult> sizes)
+    : mixedOffsets(offsets), mixedSizes(sizes) {
+  assert(offsets.size() == sizes.size() &&
+         "expected same number of offsets and sizes");
+  // Assume that all strides are 1.
+  if (offsets.empty())
+    return;
+  MLIRContext *ctx = offsets.front().getContext();
+  mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
+}
+
+HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op)
+    : HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(),
+                            op.getMixedStrides()) {}
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.
@@ -524,19 +550,44 @@ ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
   return *delta == 0;
 }
 
-FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
-    OffsetSizeAndStrideOpInterface slice1,
-    OffsetSizeAndStrideOpInterface slice2) {
-  assert(slice1.getStaticOffsets().size() == slice1.getStaticOffsets().size() &&
+FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
+                                                   OpFoldResult ofr2) {
+  Builder b(ofr1.getContext());
+  AffineMap map =
+      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2,
+                     b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1));
+  SmallVector<OpFoldResult> ofrOperands;
+  ofrOperands.push_back(ofr1);
+  ofrOperands.push_back(ofr2);
+  SmallVector<Value> valueOperands;
+  AffineMap foldedMap =
+      foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
+  ValueDimList valueDims;
+  for (Value v : valueOperands) {
+    assert(v.getType().isIndex() && "expected index type");
+    valueDims.emplace_back(v, std::nullopt);
+  }
+  FailureOr<int64_t> delta =
+      computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims);
+  if (failed(delta))
+    return failure();
+  return *delta == 0;
+}
+
+FailureOr<bool>
+ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
+                                               HyperrectangularSlice slice1,
+                                               HyperrectangularSlice slice2) {
+  assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
          "expected slices of same rank");
-  assert(slice1.getStaticSizes().size() == slice1.getStaticSizes().size() &&
+  assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
          "expected slices of same rank");
-  assert(slice1.getStaticStrides().size() == slice1.getStaticStrides().size() &&
+  assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
          "expected slices of same rank");
 
-  Builder b(slice1.getContext());
+  Builder b(ctx);
   bool foundUnknownBound = false;
-  for (int64_t i = 0, e = slice1.getStaticOffsets().size(); i < e; ++i) {
+  for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
     AffineMap map =
         AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
                        b.getAffineSymbolExpr(0) +
@@ -588,6 +639,48 @@ FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
   return true;
 }
 
+FailureOr<bool>
+ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
+                                              HyperrectangularSlice slice1,
+                                              HyperrectangularSlice slice2) {
+  assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
+         "expected slices of same rank");
+  assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
+         "expected slices of same rank");
+  assert(slice1.getMixedStrides().size() == slice2.getMixedStrides().size() &&
+         "expected slices of same rank");
+
+  // The two slices are equivalent if all of their offsets, sizes and strides
+  // are equal. If equality cannot be determined for at least one of those
+  // values, equivalence cannot be determined and this function returns
+  // "failure".
+  for (auto [offset1, offset2] :
+       llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
+    FailureOr<bool> equal = areEqual(offset1, offset2);
+    if (failed(equal))
+      return failure();
+    if (!equal.value())
+      return false;
+  }
+  for (auto [size1, size2] :
+       llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
+    FailureOr<bool> equal = areEqual(size1, size2);
+    if (failed(equal))
+      return failure();
+    if (!equal.value())
+      return false;
+  }
+  for (auto [stride1, stride2] :
+       llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
+    FailureOr<bool> equal = areEqual(stride1, stride2);
+    if (failed(equal))
+      return failure();
+    if (!equal.value())
+      return false;
+  }
+  return true;
+}
+
 ValueBoundsConstraintSet::BoundBuilder &
 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
   assert(!this->dim.has_value() && "dim was already set");
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
index b9161f4e20d1927..bb60eeaba52455c 100644
--- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -7,6 +7,11 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
   %ub = "test.foo"() : () -> (index)
   %step = "test.foo"() : () -> (index)
 
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %add = arith.addi %c0, %c1 : index
+  %sub = arith.subi %add, %c1 : index
+
   // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
   // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
@@ -17,7 +22,9 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
     %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
     // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
     %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
-    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // Obfuscate the IR by inserting at offset %sub instead of 0; both of them
+    // have the same value.
+    %3 = tensor.insert_slice %2 into %t[%sub][5][1] : tensor<5xf32> into tensor<?xf32>
     // CHECK: scf.yield %[[t]], %[[foo]]
     scf.yield %3 : tensor<?xf32>
   }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 0a2ae427169a99a..7109b6c439057fb 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10230,6 +10230,7 @@ cc_library(
         ":IR",
         ":SubsetOpInterfaceIncGen",
         ":Support",
+        ":ValueBoundsOpInterface",
         "//llvm:Support",
     ],
 )



More information about the Mlir-commits mailing list