[llvm] Allow empty dimension arrays in `linalg::inferContractionDims` (PR #69496)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 18 11:39:26 PDT 2023


https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/69496

This function was returning failure when any of the intersection sets was empty, but this is actually legitimate in "matrix times vector" cases, where some of the operands have lower dimensionality, implying unit-dimension semantics for the "missing" dimensions.

Example:

```mlir
func.func @transpose_extend_batch_matmul(
    %vec: tensor<32x128xi16>,
    %mat: tensor<11008x32x128xi4>) -> tensor<11008x32xi32> {
  %c0_i32 = arith.constant 0 : i32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<11008x32xi32>
  %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
  %2 = tensor.empty() : tensor<11008xf32>
  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<11008xf32>) -> tensor<11008xf32>
  %batch_matmul_result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1)>], 
                                         iterator_types = ["parallel", "parallel", "reduction"]} 
                                         ins(%vec, %mat : tensor<32x128xi16>, tensor<11008x32x128xi4>) 
                                         outs(%1 : tensor<11008x32xi32>) {
  ^bb0(%in: i16, %in_3: i4, %out: i32):
      %19 = arith.extsi %in : i16 to i32
      %20 = arith.extui %in_3 : i4 to i32
      %21 = arith.muli %19, %20 : i32
      %22 = arith.addi %21, %out : i32
      linalg.yield %22 : i32
  } -> tensor<11008x32xi32>
  return %batch_matmul_result : tensor<11008x32xi32>
}
```

Here, we were returning failure because `ac` (the set from which the `m` dimensions are derived) is empty. With this PR, we instead return this useful information:

```
batch: [ 1 ]
m: [ ]
n: [ 0 ]
k: [ 2 ]
```

>From 015cacd87c618ce63f3016ae61db0cf483e82935 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 12 Oct 2023 18:17:37 -0700
Subject: [PATCH 1/5] [mlir][affine] ValueBoundsConstraintSet: Fully compose
 affine.apply (#68899)

Fully compose `affine.apply` ops before adding them to the underlying
`FlatLinearConstraints`. This works around a limitation of
`FlatLinearConstraints`, which cannot deduce a constant bound if it
involves two identical local variables.

Details for future improvements of `FlatLinearConstraints`: The
constraint set infrastructure fails to compute a constant bound of -8
for the first variable.
```
Domain: 0, Range: 1, Symbols: 4, Locals: 2
8 constraints
(None    None    None    None    None    Local    Local    const)
 1    -1    0    0    0    0    0    0    = 0
 0    1    -1    1    0    0    0    0    = 0
 0    0    1    0    0    0    -16    0    = 0
 0    0    0    1    0    -16    0    -8    = 0
 0    0    0    0    -1    0    32    31    >= 0
 0    0    0    0    1    0    -32    0    >= 0
 0    0    0    0    -1    32    0    31    >= 0
 0    0    0    0    1    -32    0    0    >= 0
```
---
 .../Affine/IR/ValueBoundsOpInterfaceImpl.h    | 14 ++++++
 .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp  | 47 +++++++++++++++++--
 .../value-bounds-op-interface-impl.mlir       | 32 +++++++++++++
 .../Dialect/Affine/TestReifyValueBounds.cpp   | 11 ++++-
 4 files changed, 97 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
index 2abbabc5bb2868c..5d4774861bdfd37 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
@@ -9,11 +9,25 @@
 #ifndef MLIR_DIALECT_AFFINE_IR_VALUEBOUNDSOPINTERFACEIMPL_H
 #define MLIR_DIALECT_AFFINE_IR_VALUEBOUNDSOPINTERFACEIMPL_H
 
+#include "mlir/Support/LogicalResult.h"
+
 namespace mlir {
 class DialectRegistry;
+class Value;
 
 namespace affine {
 void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+
+/// Compute whether the given values are equal. Return "failure" if equality
+/// could not be determined. `value1`/`value2` must be index-typed.
+///
+/// This function is similar to `ValueBoundsConstraintSet::areEqual`. To work
+/// around limitations in `FlatLinearConstraints`, this function fully composes
+/// `value1` and `value2` (if they are the result of affine.apply ops) before
+/// populating the constraint set. The folding/composing logic can see
+/// opportunities for simplifications that the constraint set implementation
+/// cannot see.
+FailureOr<bool> fullyComposeAndCheckIfEqual(Value value1, Value value2);
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index 97dd70e4f1d2b7e..d47c8eb8ccb4272 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -27,12 +27,22 @@ struct AffineApplyOpInterface
     assert(applyOp.getAffineMap().getNumResults() == 1 &&
            "expected single result");
 
+    // Fully compose this affine.apply with other ops because the folding logic
+    // can see opportunities for simplifying the affine map that
+    // `FlatLinearConstraints` can currently not see.
+    AffineMap map = applyOp.getAffineMap();
+    SmallVector<Value> operands = llvm::to_vector(applyOp.getOperands());
+    fullyComposeAffineMapAndOperands(&map, &operands);
+
     // Align affine map result with dims/symbols in the constraint set.
-    AffineExpr expr = applyOp.getAffineMap().getResult(0);
-    SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range(
-        applyOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); }));
-    SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range(
-        applyOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); }));
+    AffineExpr expr = map.getResult(0);
+    SmallVector<AffineExpr> dimReplacements, symReplacements;
+    for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
+      dimReplacements.push_back(cstr.getExpr(operands[i]));
+    for (int64_t i = map.getNumDims(),
+                 e = map.getNumDims() + map.getNumSymbols();
+         i < e; ++i)
+      symReplacements.push_back(cstr.getExpr(operands[i]));
     AffineExpr bound =
         expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
     cstr.bound(value) == bound;
@@ -92,3 +102,30 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
     AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
   });
 }
+
+FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
+                                                          Value value2) {
+  assert(value1.getType().isIndex() && "expected index type");
+  assert(value2.getType().isIndex() && "expected index type");
+
+  // Subtract the two values/dimensions from each other. If the result is 0,
+  // both are equal.
+  Builder b(value1.getContext());
+  AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
+                                 b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
+  // Fully compose the affine map with other ops because the folding logic
+  // can see opportunities for simplifying the affine map that
+  // `FlatLinearConstraints` can currently not see.
+  SmallVector<Value> mapOperands;
+  mapOperands.push_back(value1);
+  mapOperands.push_back(value2);
+  affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
+  ValueDimList valueDims;
+  for (Value v : mapOperands)
+    valueDims.push_back({v, std::nullopt});
+  FailureOr<int64_t> bound = ValueBoundsConstraintSet::computeConstantBound(
+      presburger::BoundType::EQ, map, valueDims);
+  if (failed(bound))
+    return failure();
+  return *bound == 0;
+}
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 338c48c5b210bc1..8acf358c887a987 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -58,3 +58,35 @@ func.func @affine_min_lb(%a: index) -> (index) {
   %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index)
   return %2 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @composed_affine_apply(
+//       CHECK:   %[[cst:.*]] = arith.constant -8 : index
+//       CHECK:   return %[[cst]]
+func.func @composed_affine_apply(%i1 : index) -> (index) {
+  // The ValueBoundsOpInterface implementation of affine.apply fully composes
+  // the affine map (and its operands) with other affine.apply ops drawn from
+  // its operands before adding it to the constraint set. This is to work
+  // around a limitation in `FlatLinearConstraints`, which can currently not
+  // compute a constant bound for %s. (The affine map simplification logic can
+  // simplify %s to -8.)
+  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
+  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+  %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
+  %reified = "test.reify_constant_bound"(%s) {type = "EQ"} : (index) -> (index)
+  return %reified : index
+}
+
+
+// -----
+
+// Test for affine::fullyComposeAndCheckIfEqual
+func.func @composed_are_equal(%i1 : index) {
+  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
+  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+  %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
+  // expected-remark @below{{different}}
+   "test.are_equal"(%i2, %i3) {compose} : (index, index) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index ad017cef1b9bace..6e3c3dff759a2ed 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Arith/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -186,8 +187,14 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
         op->emitOpError("invalid op");
         return WalkResult::skip();
       }
-      FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
-          op->getOperand(0), op->getOperand(1));
+      FailureOr<bool> equal = failure();
+      if (op->hasAttr("compose")) {
+        equal = affine::fullyComposeAndCheckIfEqual(op->getOperand(0),
+                                                    op->getOperand(1));
+      } else {
+        equal = ValueBoundsConstraintSet::areEqual(op->getOperand(0),
+                                                   op->getOperand(1));
+      }
       if (failed(equal)) {
         op->emitError("could not determine equality");
       } else if (*equal) {

>From 5ebd8216a35eb7a3de019519756342e9ba4c7db9 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sun, 15 Oct 2023 19:37:54 -0400
Subject: [PATCH 2/5] [mlir][vector] Enable transfer op hoisting with dynamic
 indices (#68500)

Recent changes (https://github.com/llvm/llvm-project/pull/66930)
disabled hoisting of vector transfer ops in the presence of view-like
intermediate ops. The recommended way is to fold subview ops into
transfer op indices before invoking hoisting. As a result, transfer op
indices now involve dynamic values, whereas previously (with subview ops
in between) they were static constants. Hoisting therefore no longer
kicks in, which breaks downstream users.

To fix this, this commit enables hoisting transfer ops with dynamic
indices by using `ValueBoundsConstraintSet` to prove that the accessed
ranges are disjoint in `isDisjointTransferIndices`. Because that utility
is used in many places, including op folders, we introduce a flag for it
and only set it to true for "heavy" transforms: hoisting, dead-store
elimination, and store-to-load forwarding.
---
 .../Affine/IR/ValueBoundsOpInterfaceImpl.h    |  12 +-
 .../mlir/Dialect/Vector/IR/VectorOps.h        |  19 ++-
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  10 ++
 .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp  |   9 +-
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  12 +-
 mlir/lib/Dialect/Vector/IR/CMakeLists.txt     |   2 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  65 +++++++--
 .../Transforms/VectorTransferOpTransforms.cpp |   6 +-
 .../lib/Interfaces/ValueBoundsOpInterface.cpp |  27 ++--
 mlir/test/Dialect/Linalg/hoisting.mlir        | 132 ++++++++++++++++++
 .../Dialect/Vector/vector-transferop-opt.mlir | 104 ++++++++++++++
 .../Dialect/Affine/TestReifyValueBounds.cpp   |  30 ++--
 .../llvm-project-overlay/mlir/BUILD.bazel     |   2 +
 13 files changed, 370 insertions(+), 60 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
index 5d4774861bdfd37..6e617ef40a53d7d 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
@@ -18,16 +18,18 @@ class Value;
 namespace affine {
 void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
 
-/// Compute whether the given values are equal. Return "failure" if equality
-/// could not be determined. `value1`/`value2` must be index-typed.
+/// Compute a constant delta of the given two values. Return "failure" if we
+/// cannot determine a constant delta. `value1`/`value2` must be index-typed.
 ///
-/// This function is similar to `ValueBoundsConstraintSet::areEqual`. To work
-/// around limitations in `FlatLinearConstraints`, this function fully composes
+/// This function is similar to
+/// `ValueBoundsConstraintSet::computeConstantDistance`. To work around
+/// limitations in `FlatLinearConstraints`, this function fully composes
 /// `value1` and `value2` (if they are the result of affine.apply ops) before
 /// populating the constraint set. The folding/composing logic can see
 /// opportunities for simplifications that the constraint set implementation
 /// cannot see.
-FailureOr<bool> fullyComposeAndCheckIfEqual(Value value1, Value value2);
+FailureOr<int64_t> fullyComposeAndComputeConstantDelta(Value value1,
+                                                       Value value2);
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index fc0c80036ff79ad..9ab20e20d975429 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -105,16 +105,23 @@ bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
 /// op.
 bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
 
-/// Same behavior as `isDisjointTransferSet` but doesn't require the operations
-/// to have the same tensor/memref. This allows comparing operations accessing
-/// different tensors.
+/// Return true if we can prove that the transfer operations access disjoint
+/// memory, without requiring the accessed tensor/memref to be the same.
+///
+/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
+/// via ValueBoundsOpInterface.
 bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
-                               VectorTransferOpInterface transferB);
+                               VectorTransferOpInterface transferB,
+                               bool testDynamicValueUsingBounds = false);
 
 /// Return true if we can prove that the transfer operations access disjoint
-/// memory.
+/// memory, requiring the operations to access the same tensor/memref.
+///
+/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
+/// via ValueBoundsOpInterface.
 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
-                           VectorTransferOpInterface transferB);
+                           VectorTransferOpInterface transferB,
+                           bool testDynamicValueUsingBounds = false);
 
 /// Return the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 2687d79aec68ebb..8f11c563e0cbd91 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -176,6 +176,16 @@ class ValueBoundsConstraintSet {
       presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
       StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
+  /// Compute a constant delta between the given two values. Return "failure"
+  /// if a constant delta could not be determined.
+  ///
+  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
+  /// index-typed.
+  static FailureOr<int64_t>
+  computeConstantDelta(Value value1, Value value2,
+                       std::optional<int64_t> dim1 = std::nullopt,
+                       std::optional<int64_t> dim2 = std::nullopt);
+
   /// Compute whether the given values/dimensions are equal. Return "failure" if
   /// equality could not be determined.
   ///
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index d47c8eb8ccb4272..e0c3abe7a0f71d1 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -103,8 +103,8 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
   });
 }
 
-FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
-                                                          Value value2) {
+FailureOr<int64_t>
+mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
   assert(value1.getType().isIndex() && "expected index type");
   assert(value2.getType().isIndex() && "expected index type");
 
@@ -123,9 +123,6 @@ FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
   ValueDimList valueDims;
   for (Value v : mapOperands)
     valueDims.push_back({v, std::nullopt});
-  FailureOr<int64_t> bound = ValueBoundsConstraintSet::computeConstantBound(
+  return ValueBoundsConstraintSet::computeConstantBound(
       presburger::BoundType::EQ, map, valueDims);
-  if (failed(bound))
-    return failure();
-  return *bound == 0;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 221bec713b38aa3..cbb2c507de69f9e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -173,16 +173,16 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
         if (auto transferWriteUse =
                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
           if (!vector::isDisjointTransferSet(
-                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
-                  cast<VectorTransferOpInterface>(
-                      transferWriteUse.getOperation())))
+                  cast<VectorTransferOpInterface>(*transferWrite),
+                  cast<VectorTransferOpInterface>(*transferWriteUse),
+                  /*testDynamicValueUsingBounds=*/true))
             return WalkResult::advance();
         } else if (auto transferReadUse =
                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
           if (!vector::isDisjointTransferSet(
-                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
-                  cast<VectorTransferOpInterface>(
-                      transferReadUse.getOperation())))
+                  cast<VectorTransferOpInterface>(*transferWrite),
+                  cast<VectorTransferOpInterface>(*transferReadUse),
+                  /*testDynamicValueUsingBounds=*/true))
             return WalkResult::advance();
         } else {
           // Unknown use, we cannot prove that it doesn't alias with the
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 9ec919423b3428f..70f3fa8c297d4bc 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorDialect
   MLIRVectorAttributesIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAffineDialect
   MLIRArithDialect
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
@@ -22,5 +23,6 @@ add_mlir_dialect_library(MLIRVectorDialect
   MLIRMemRefDialect
   MLIRSideEffectInterfaces
   MLIRTensorDialect
+  MLIRValueBoundsOpInterface
   MLIRVectorInterfaces
   )
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 044b6cc07d3d629..68a5cf209f2fb49 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
+#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -30,6 +31,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -168,39 +170,76 @@ bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
 }
 
 bool mlir::vector::isDisjointTransferIndices(
-    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
+    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
+    bool testDynamicValueUsingBounds) {
   // For simplicity only look at transfer of same type.
   if (transferA.getVectorType() != transferB.getVectorType())
     return false;
   unsigned rankOffset = transferA.getLeadingShapedRank();
   for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
-    auto indexA = getConstantIntValue(transferA.indices()[i]);
-    auto indexB = getConstantIntValue(transferB.indices()[i]);
-    // If any of the indices are dynamic we cannot prove anything.
-    if (!indexA.has_value() || !indexB.has_value())
-      continue;
+    Value indexA = transferA.indices()[i];
+    Value indexB = transferB.indices()[i];
+    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
+    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

     if (i < rankOffset) {
       // For leading dimensions, if we can prove that index are different we
       // know we are accessing disjoint slices.
-      if (*indexA != *indexB)
-        return true;
+      if (cstIndexA.has_value() && cstIndexB.has_value()) {
+        if (*cstIndexA != *cstIndexB)
+          return true;
+        continue;
+      }
+      if (testDynamicValueUsingBounds) {
+        // First try to see if we can fully compose and simplify the affine
+        // expression as a fast track.
+        FailureOr<int64_t> delta =
+            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
+        if (succeeded(delta) && *delta != 0)
+          return true;
+
+        FailureOr<bool> testEqual =
+            ValueBoundsConstraintSet::areEqual(indexA, indexB);
+        if (succeeded(testEqual) && !testEqual.value())
+          return true;
+      }
     } else {
       // For this dimension, we slice a part of the memref we need to make sure
       // the intervals accessed don't overlap.
-      int64_t distance = std::abs(*indexA - *indexB);
-      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
-        return true;
+      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
+      if (cstIndexA.has_value() && cstIndexB.has_value()) {
+        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
+        if (distance >= vectorDim)
+          return true;
+        continue;
+      }
+      if (testDynamicValueUsingBounds) {
+        // First try to see if we can fully compose and simplify the affine
+        // expression as a fast track.
+        FailureOr<int64_t> delta =
+            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
+        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
+          return true;
+
+        FailureOr<int64_t> computeDelta =
+            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
+        if (succeeded(computeDelta)) {
+          if (std::abs(computeDelta.value()) >= vectorDim)
+            return true;
+        }
+      }
     }
   }
   return false;
 }
 
 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
-                                         VectorTransferOpInterface transferB) {
+                                         VectorTransferOpInterface transferB,
+                                         bool testDynamicValueUsingBounds) {
   if (transferA.source() != transferB.source())
     return false;
-  return isDisjointTransferIndices(transferA, transferB);
+  return isDisjointTransferIndices(transferA, transferB,
+                                   testDynamicValueUsingBounds);
 }
 
 // Helper to iterate over n-D vector slice elements. Calculate the next
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 603b88f11c8e007..a5f1b28152b9bde 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -142,7 +142,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
       // Don't need to consider disjoint accesses.
       if (vector::isDisjointTransferSet(
               cast<VectorTransferOpInterface>(write.getOperation()),
-              cast<VectorTransferOpInterface>(transferOp.getOperation())))
+              cast<VectorTransferOpInterface>(transferOp.getOperation()),
+              /*testDynamicValueUsingBounds=*/true))
         continue;
     }
     blockingAccesses.push_back(user);
@@ -217,7 +218,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
       // the write.
       if (vector::isDisjointTransferSet(
               cast<VectorTransferOpInterface>(write.getOperation()),
-              cast<VectorTransferOpInterface>(read.getOperation())))
+              cast<VectorTransferOpInterface>(read.getOperation()),
+              /*testDynamicValueUsingBounds=*/true))
         continue;
       if (write.getSource() == read.getSource() &&
           dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index c00ee0315a9639a..ff941115219f68b 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -484,25 +484,32 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   return failure();
 }
 
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
-                                   std::optional<int64_t> dim1,
-                                   std::optional<int64_t> dim2) {
+FailureOr<int64_t>
+ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
+                                               std::optional<int64_t> dim1,
+                                               std::optional<int64_t> dim2) {
 #ifndef NDEBUG
   assertValidValueDim(value1, dim1);
   assertValidValueDim(value2, dim2);
 #endif // NDEBUG
 
-  // Subtract the two values/dimensions from each other. If the result is 0,
-  // both are equal.
   Builder b(value1.getContext());
   AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                  b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
-  FailureOr<int64_t> bound = computeConstantBound(
-      presburger::BoundType::EQ, map, {{value1, dim1}, {value2, dim2}});
-  if (failed(bound))
+  return computeConstantBound(presburger::BoundType::EQ, map,
+                              {{value1, dim1}, {value2, dim2}});
+}
+
+FailureOr<bool>
+ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+                                   std::optional<int64_t> dim1,
+                                   std::optional<int64_t> dim2) {
+  // Subtract the two values/dimensions from each other. If the result is 0,
+  // both are equal.
+  FailureOr<int64_t> delta = computeConstantDelta(value1, value2, dim1, dim2);
+  if (failed(delta))
     return failure();
-  return *bound == 0;
+  return *delta == 0;
 }
 
 ValueBoundsConstraintSet::BoundBuilder &
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 7d0c3648c344b1d..11bf4b58b95c82e 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -872,3 +872,135 @@ transform.sequence failures(propagate) {
   transform.structured.hoist_redundant_vector_transfers %0
     : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+// Test that we can hoist out 1-D read-write pairs whose indices are dynamic values.
+
+// CHECK: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+
+//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
+//    CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)
+
+//         CHECK:   %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
+//         CHECK:   %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
+//         CHECK:   %2 = vector.transfer_read %[[BUFFER]][%[[I0]], %[[I0]]]
+//         CHECK:   %3 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[I0]]]
+//         CHECK:   %4 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
+// CHECK-COUNT-2:   scf.for %{{.+}} = {{.+}} -> (vector<4xf32>, vector<4xf32>, vector<4xf32>)
+// CHECK-COUNT-3:     "some_use"
+// CHECK-COUNT-2:   scf.yield {{.+}} : vector<4xf32>, vector<4xf32>, vector<4xf32>
+//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
+//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[I0]]]
+//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]
+
+func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+  %cst = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
+  %i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
+
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+      // Disjoint leading dim
+      %r1 = vector.transfer_read %buffer[%i1, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+      // Non-overlap trailing dim
+      %r2 = vector.transfer_read %buffer[%i1, %i2], %cst: memref<?x?xf32>, vector<4xf32>
+      %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
+      %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
+      %u2 = "some_use"(%r2) : (vector<4xf32>) -> vector<4xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %buffer[%i1, %i0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u2, %buffer[%i1, %i2] : vector<4xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Test that we cannot hoist out read-write pairs whose indices are overlapping.
+
+//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_overlapping_dynamic
+// CHECK-COUNT-2:   scf.for
+// CHECK-COUNT-2:     vector.transfer_read
+// CHECK-COUNT-2:     vector.transfer_write
+
+func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+  %cst = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
+
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+      // Overlapping range with the above
+      %r1 = vector.transfer_read %buffer[%i0, %i1], %cst: memref<?x?xf32>, vector<4xf32>
+      %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
+      %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %buffer[%i0, %i1] : vector<4xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Test that we can hoist out 2-D read-write pairs whose indices are dynamic values.
+
+//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
+// CHECK-COUNT-3:   vector.transfer_read
+// CHECK-COUNT-2:   %{{.+}}:3 = scf.for {{.+}} -> (vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>)
+// CHECK-COUNT-2:   scf.yield {{.+}} : vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>
+// CHECK-COUNT-3:   vector.transfer_write
+//         CHECK:   return
+
+func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
+  %cst = arith.constant 0.0 : f32
+  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
+  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+  %i4 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 16)>(%i1)
+
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %buffer[%i0, %i2], %cst: memref<?x?xf32>, vector<16x8xf32>
+      %r1 = vector.transfer_read %buffer[%i0, %i3], %cst: memref<?x?xf32>, vector<16x8xf32>
+      %r2 = vector.transfer_read %buffer[%i0, %i4], %cst: memref<?x?xf32>, vector<16x8xf32>
+      %u0 = "some_use"(%r0) : (vector<16x8xf32>) -> vector<16x8xf32>
+      %u1 = "some_use"(%r1) : (vector<16x8xf32>) -> vector<16x8xf32>
+      %u2 = "some_use"(%r2) : (vector<16x8xf32>) -> vector<16x8xf32>
+      vector.transfer_write %u2, %buffer[%i0, %i4] : vector<16x8xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %buffer[%i0, %i3] : vector<16x8xf32>, memref<?x?xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i2] : vector<16x8xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index f43367ab4aeba7d..13957af014b89ed 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -256,3 +256,107 @@ func.func @collapse_shape(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
   }
   return
 }
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_same_index
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_dynamic_same_index(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  vector.transfer_write %v0, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op reads/writes to the same address so that we can forward.
+  %0 = vector.transfer_read %buffer[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
+
+//   CHECK-LABEL: func @dont_forward_dead_store_dynamic_overlap
+// CHECK-COUNT-2:   vector.transfer_write
+//         CHECK:   vector.transfer_read
+//         CHECK:   scf.for
+//         CHECK:   }
+//         CHECK:   vector.transfer_write
+//         CHECK:   return
+func.func @dont_forward_dead_store_dynamic_overlap(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
+  vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to an overlapping range so we cannot forward.
+  vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_leading_dim
+//       CHECK:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_dynamic_non_overlap_leading_dim(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
+  vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to an non-overlapping range so we can forward.
+  vector.transfer_write %v0, %buffer[%i1, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_trailing_dim
+//       CHECK:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
+  vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to an non-overlapping range so we can forward.
+  vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 6e3c3dff759a2ed..2f1631cbdb02e01 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -187,20 +187,26 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
         op->emitOpError("invalid op");
         return WalkResult::skip();
       }
-      FailureOr<bool> equal = failure();
       if (op->hasAttr("compose")) {
-        equal = affine::fullyComposeAndCheckIfEqual(op->getOperand(0),
-                                                    op->getOperand(1));
-      } else {
-        equal = ValueBoundsConstraintSet::areEqual(op->getOperand(0),
-                                                   op->getOperand(1));
-      }
-      if (failed(equal)) {
-        op->emitError("could not determine equality");
-      } else if (*equal) {
-        op->emitRemark("equal");
+        FailureOr<int64_t> equal = affine::fullyComposeAndComputeConstantDelta(
+            op->getOperand(0), op->getOperand(1));
+        if (failed(equal)) {
+          op->emitError("could not determine equality");
+        } else if (*equal == 0) {
+          op->emitRemark("equal");
+        } else {
+          op->emitRemark("different");
+        }
       } else {
-        op->emitRemark("different");
+        FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
+            op->getOperand(0), op->getOperand(1));
+        if (failed(equal)) {
+          op->emitError("could not determine equality");
+        } else if (*equal) {
+          op->emitRemark("equal");
+        } else {
+          op->emitRemark("different");
+        }
       }
     }
     return WalkResult::advance();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 51ea4a28cc8fa0b..60252d2ef7efc09 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4422,6 +4422,7 @@ cc_library(
     ]),
     includes = ["include"],
     deps = [
+        ":AffineDialect",
         ":ArithDialect",
         ":ArithUtils",
         ":ControlFlowInterfaces",
@@ -4436,6 +4437,7 @@ cc_library(
         ":SideEffectInterfaces",
         ":Support",
         ":TensorDialect",
+        ":ValueBoundsOpInterface",
         ":VectorAttributesIncGen",
         ":VectorDialectIncGen",
         ":VectorInterfaces",

>From 36dfa7d88c99446c51e00951d8346e8912504bf3 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo <stellaraccident at gmail.com>
Date: Mon, 16 Oct 2023 17:59:39 -0700
Subject: [PATCH 3/5] Revert "[mlir][tosa][linalg] Apply direct tosa -> linalg
 Conv2D lowering (#68304)"

This reverts commit e29a253c9ebaded53a823def985364392c4ba4ec.

The reverted change breaks the TFLite MobileNet test; it needs triage.
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 137 ------------------
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  43 +++---
 .../linalg/opdsl/ops/core_named_ops.py        |  30 ----
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  20 ++-
 4 files changed, 34 insertions(+), 196 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index cd64b813c11e532..44bcbbab2bbe9de 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2575,143 +2575,6 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: KZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: conv_2d_nhwc_fhwc_q
-  cpp_class_name: Conv2DNhwcFhwcQOp
-  doc: |-
-    Performs 2-D convolution with zero point offsets.
-
-    Layout:
-      * Input: NHWC.
-      * Kernel: FHWC.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output. This includes the zero
-    point offsets common to quantized operations.
-  implements:
-  - LinalgConvolutionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: I
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
-      s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
-  - !LinalgOperandDefConfig
-    name: K
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
-      s3, s7, s9)>
-  - !LinalgOperandDefConfig
-    name: IZp
-    kind: scalar
-    type_var: I32
-  - !LinalgOperandDefConfig
-    name: KZp
-    kind: scalar
-    type_var: I32
-  - !LinalgOperandDefConfig
-    name: O
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
-      s1, s5, s10)>
-  - !LinalgOperandDefConfig
-    name: strides
-    kind: index_attr
-    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
-      (s2, s6)>
-    default_indices:
-    - 1
-    - 1
-  - !LinalgOperandDefConfig
-    name: dilations
-    kind: index_attr
-    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
-      (s4, s8)>
-    default_indices:
-    - 1
-    - 1
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> (d3, d4, d5, d6)>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> ()>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> ()>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> (d0, d1, d2, d3)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  - reduction
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: O
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: O
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: binary
-                fn_name: sub
-                operands:
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: I
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: IZp
-            - !ScalarExpression
-              scalar_fn:
-                kind: binary
-                fn_name: sub
-                operands:
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: K
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: KZp
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw_fchw
   cpp_class_name: Conv2DNchwFchwOp
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 4214bb57563285c..62ec44bf9c1e1e1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -248,28 +248,25 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
-    // For Conv3D transpose the kernel to match dimension ordering of the linalg
-    // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
-    // map directly and then transpose later if desired.
-    if (5 == inputTy.getRank()) {
-      // TODO(suderman): See if this can be efficiently folded - check whether
-      // the input is used anywhere else, if not fold the constant.
-      SmallVector<int64_t> weightPerm;
-      for (int i = 1; i < resultTy.getRank(); i++)
-        weightPerm.push_back(i);
-      weightPerm.push_back(0);
-
-      SmallVector<int64_t> newWeightShape;
-      for (auto dim : weightPerm)
-        newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
-      Value weightPermValue =
-          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
-      Type newWeightTy =
-          RankedTensorType::get(newWeightShape, weightTy.getElementType());
-      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                  weightPermValue);
-    }
+    // Transpose the kernel to match dimension ordering of the linalg
+    // convolution operation.
+    // TODO(suderman): See if this can be efficiently folded - check whether
+    // the input is used anywhere else, if not fold the constant.
+    SmallVector<int64_t> weightPerm;
+    for (int i = 1; i < resultTy.getRank(); i++)
+      weightPerm.push_back(i);
+    weightPerm.push_back(0);
+
+    SmallVector<int64_t> newWeightShape;
+    for (auto dim : weightPerm)
+      newWeightShape.push_back(weightShape[dim]);
+    auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+    Value weightPermValue =
+        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+    Type newWeightTy =
+        RankedTensorType::get(newWeightShape, weightTy.getElementType());
+    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                weightPermValue);
 
     auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
@@ -980,7 +977,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<
       // clang-format off
-      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
+      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index a8f8f8e0fbd68b4..6eae3d916c92882 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -693,36 +693,6 @@ def conv_2d_nhwc_hwcf_q(
     ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
 
 
-@linalg_structured_op
-def conv_2d_nhwc_fhwc_q(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
-    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
-    IZp=ScalarDef(I32),
-    KZp=ScalarDef(I32),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
-):
-    """Performs 2-D convolution with zero point offsets.
-
-    Layout:
-      * Input: NHWC.
-      * Kernel: FHWC.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output. This includes the zero
-    point offsets common to quantized operations.
-    """
-    implements(ConvolutionOpInterface)
-    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-    O[D.n, D.oh, D.ow, D.f] += (
-        TypeFn.cast_signed(
-            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
-        )
-        - TypeFn.cast_signed(U, IZp)
-    ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))
-
-
 @linalg_structured_op
 def conv_2d_nchw_fchw(
     I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index b601bfb28a4f280..bf970c84832e9e5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -363,11 +363,13 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK:   arith.extsi
   // CHECK:   arith.addi
@@ -383,11 +385,13 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK:   arith.addf
   // CHECK:   linalg.yield
@@ -404,11 +408,13 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
 func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
   // CHECK:   %[[ADD:.+]] = arith.addf
   // CHECK:   linalg.yield %[[ADD]] : f32
@@ -462,11 +468,13 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
 
   // Running convolution
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
   // CHECK:   %[[ADD:.+]] = arith.addf
   // CHECK:   linalg.yield %[[ADD]] : f32
@@ -481,7 +489,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK:   tensor.yield %[[C0]]
-  // CHECK: linalg.conv_2d_nhwc_fhwc
+  // CHECK: linalg.conv_2d_nhwc_hwcf
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
@@ -493,7 +501,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
   // CHECK:   %[[C22:.+]] = arith.constant -22
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK:   tensor.yield %[[C22]]
-  // CHECK: linalg.conv_2d_nhwc_fhwc_q
+  // CHECK: linalg.conv_2d_nhwc_hwcf_q
   %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
   return
 }

>From b7677049c54b6ecf34fbe1659cac56cd3750bfc1 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <43187390+bviyer at users.noreply.github.com>
Date: Fri, 13 Oct 2023 00:47:36 -0500
Subject: [PATCH 4/5] [ArmSVE][NVVM][Bazel] Added Features to BUILD.bazel file
 (#68949)

Added VectorOps support for ArmSVE in BUILD.bazel.
Added BasicPtxBuilderInterface support for NVVM in BUILD.bazel.
---
 .../llvm-project-overlay/mlir/BUILD.bazel     | 80 +++++++++++++++----
 1 file changed, 63 insertions(+), 17 deletions(-)

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 60252d2ef7efc09..63f9cdafce88b90 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2097,6 +2097,7 @@ cc_library(
         ":IR",
         ":LLVMDialect",
         ":SideEffectInterfaces",
+        ":VectorDialect",
         "//llvm:Core",
         "//llvm:Support",
     ],
@@ -2109,13 +2110,12 @@ cc_library(
     includes = ["include"],
     deps = [
         ":ArmSVEDialect",
+        ":DialectUtils",
         ":FuncDialect",
         ":IR",
         ":LLVMCommonConversion",
         ":LLVMDialect",
-        ":TransformUtils",
-        "//llvm:Core",
-        "//llvm:Support",
+        ":VectorDialect",
     ],
 )
 
@@ -4818,6 +4818,7 @@ cc_library(
             "lib/Dialect/LLVMIR/IR/NVVM*.cpp",
             "lib/Dialect/LLVMIR/IR/NVVM*.h",
             "lib/Dialect/LLVMIR/IR/ROCDL*.cpp",
+            "lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp",
             "lib/Dialect/LLVMIR/IR/ROCDL*.h",
             "lib/Dialect/LLVMIR/IR/*X86Vector*.cpp",
             "lib/Dialect/LLVMIR/IR/*X86Vector*.h",
@@ -4829,6 +4830,7 @@ cc_library(
             "include/mlir/Dialect/LLVMIR/*AMX*.h",
             "include/mlir/Dialect/LLVMIR/*ArmSVE*.h",
             "include/mlir/Dialect/LLVMIR/NVVM*.h",
+            "include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h",
             "include/mlir/Dialect/LLVMIR/ROCDL*.h",
             "include/mlir/Dialect/LLVMIR/*X86Vector*.h",
         ],
@@ -5770,6 +5772,7 @@ cc_library(
     hdrs = ["include/mlir/Dialect/LLVMIR/NVVMDialect.h"],
     includes = ["include"],
     deps = [
+        ":BasicPtxBuilderInterface",
         ":ConvertToLLVM",
         ":DialectUtils",
         ":GPUDialect",
@@ -5824,11 +5827,25 @@ cc_library(
     ],
 )
 
+td_library(
+    name = "BasicPtxBuilderIntTdFiles",
+    srcs = [
+        "include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":GPUOpsTdFiles",
+        ":LLVMOpsTdFiles",
+        ":OpBaseTdFiles",
+    ],
+)
+
 td_library(
     name = "NVVMOpsTdFiles",
     srcs = ["include/mlir/Dialect/LLVMIR/NVVMOps.td"],
     includes = ["include"],
     deps = [
+        ":BasicPtxBuilderIntTdFiles",
         ":GPUOpsTdFiles",
         ":LLVMOpsTdFiles",
         ":OpBaseTdFiles",
@@ -5836,6 +5853,31 @@ td_library(
     ],
 )
 
+gentbl_cc_library(
+    name = "BasicPtxBuilderIntGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-op-interface-decls",
+            ],
+            "include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h.inc",
+        ),
+        (
+            [
+                "-gen-op-interface-defs",
+            ],
+            "include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td",
+    deps = [
+        ":BasicPtxBuilderIntTdFiles",
+        ":GPUOpsTdFiles",
+        ":LLVMOpsTdFiles",
+    ],
+)
+
 gentbl_cc_library(
     name = "NVVMOpsIncGen",
     tbl_outs = [
@@ -5883,20 +5925,6 @@ gentbl_cc_library(
             ],
             "include/mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc",
         ),
-        (
-            [
-                "-gen-op-interface-decls",
-                "-attrdefs-dialect=nvvm",
-            ],
-            "include/mlir/Dialect/LLVMIR/NVVMOpsInterface.h.inc",
-        ),
-        (
-            [
-                "-gen-op-interface-defs",
-                "-attrdefs-dialect=nvvm",
-            ],
-            "include/mlir/Dialect/LLVMIR/NVVMOpsInterface.cpp.inc",
-        ),
     ],
     tblgen = ":mlir-tblgen",
     td_file = "include/mlir/Dialect/LLVMIR/NVVMOps.td",
@@ -5916,6 +5944,22 @@ gentbl_cc_library(
     deps = [":NVVMOpsTdFiles"],
 )
 
+cc_library(
+    name = "BasicPtxBuilderInterface",
+    srcs = ["lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp"],
+    hdrs = [
+        "include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":BasicPtxBuilderIntGen",
+        ":IR",
+        ":LLVMDialect",
+        ":Support",
+    ],
+)
+
+
 cc_library(
     name = "NVVMToLLVM",
     srcs = glob(["lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp"]),
@@ -7998,6 +8042,7 @@ cc_library(
         ":LLVMIntrinsicConversionIncGen",
         ":OpenMPDialect",
         ":Support",
+        ":TransformUtils",
         "//llvm:Core",
         "//llvm:FrontendOpenMP",
         "//llvm:Support",
@@ -8203,6 +8248,7 @@ cc_library(
         ":OpenMPCommon",
         ":Support",
         ":ToLLVMIRTranslation",
+        ":TransformUtils",
         "//llvm:Core",
         "//llvm:FrontendOpenMP",
         "//llvm:Support",

>From 7d0cb451374ff89367dd20a5a5191994629a7841 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 18 Oct 2023 14:36:14 -0400
Subject: [PATCH 5/5] allow-empty

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ea50e1232a4c74a..5fde8d71cac3e75 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -227,9 +227,6 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
       linalgOp, linalgOp.getDpsInputOperand(1), red);
   llvm::set_intersect(ra, rb);
 
-  if (ac.empty() || bc.empty() || ra.empty())
-    return failure();
-
   // Return each set in sorted order.
   ContractionDimensions dimensions{
       SmallVector<unsigned, 2>(batches.begin(), batches.end()),



More information about the llvm-commits mailing list