[Mlir-commits] [mlir] [mlir][vector] Relax the requirements on broadcast dims (PR #99341)

Andrzej WarzyƄski llvmlistbot at llvm.org
Mon Sep 30 10:03:15 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/99341

>From c6797cb6b9d1bef13fe10507e60766be80e3a8d1 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 16 Jul 2024 21:28:01 +0100
Subject: [PATCH 1/2] [mlir][vector] Relax the requirements on broadcast dims

NOTE: This is a follow-up for #97049 in which the `in_bounds` attribute
was made mandatory.

This PR updates the semantics of the `in_bounds` attribute so that
broadcast dimensions are no longer required to be "in bounds".
Specifically, these xfer_read/xfer_write Ops become valid after this
change:

```mlir
  %read = vector.transfer_read %A[%base1, %base2], %pad
      {in_bounds = [false], permutation_map = affine_map<(d0, d1) -> (0)>}
      : memref<?x?xf32>, vector<9xf32>

  vector.transfer_write %vec, %A[%base1, %base2]
      {in_bounds = [false], permutation_map = affine_map<(d0, d1) -> (0)>}
      : vector<9xf32>, memref<?x?xf32>
```

Note that the value `false` merely means "may run out-of-bounds", i.e.,
the corresponding access can still be "in bounds". In fact, the folder
for xfer Ops is also updated (*) and will set the attribute value
corresponding to broadcast dims to `true` - such dims are never
out-of-bounds in practice. Still, there's no need to require Op "users"
to always set the corresponding `in_bounds` flag to `true`.

Note that this PR doesn't change any of the lowerings. The changes in
"SuperVectorize.cpp", "Vectorization.cpp" and "AffineMap.cpp" are simple
reverts of recent changes in #97049. Those were only meant to facilitate
making `in_bounds` mandatory and to work around the extra requirements
for broadcast dims (those requirements are removed in this PR). All
changes in tests are also reverts of changes from #97049.

For context, here's a PR in which "broadcast" dims where forced to
always be "in-bounds":
  * https://reviews.llvm.org/D102566

(*) See `foldTransferInBoundsAttribute`.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 24 +++++++++----------
 mlir/include/mlir/IR/AffineMap.h              |  8 -------
 .../mlir/Interfaces/VectorInterfaces.td       |  7 ++----
 .../Affine/Transforms/SuperVectorize.cpp      | 13 +---------
 .../Linalg/Transforms/Vectorization.cpp       | 11 +--------
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 23 ++++++++++--------
 mlir/lib/IR/AffineMap.cpp                     | 13 ----------
 .../Conversion/VectorToSCF/vector-to-scf.mlir |  6 ++---
 .../Affine/SuperVectorize/vectorize_1d.mlir   |  6 ++---
 .../Affine/SuperVectorize/vectorize_2d.mlir   |  4 ++--
 .../vectorize_affine_apply.mlir               |  2 +-
 mlir/test/Dialect/Linalg/hoisting.mlir        |  2 +-
 mlir/test/Dialect/Linalg/vectorization.mlir   |  2 +-
 mlir/test/Dialect/Vector/invalid.mlir         | 10 --------
 mlir/test/Dialect/Vector/ops.mlir             |  2 +-
 .../vector-transfer-permutation-lowering.mlir |  2 --
 .../Vector/vector-transfer-unroll.mlir        |  4 ++--
 .../Dialect/Vector/CPU/transfer-read-1d.mlir  |  2 +-
 .../Dialect/Vector/CPU/transfer-read-2d.mlir  |  6 ++---
 .../Dialect/Vector/CPU/transfer-read-3d.mlir  |  4 ++--
 20 files changed, 49 insertions(+), 102 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b6d2cc29cd1bf8..45fd1c6e3f9384 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1290,12 +1290,12 @@ def Vector_TransferReadOp :
     specifies if the transfer is guaranteed to be within the source bounds. If
     set to "false", accesses (including the starting point) may run
     out-of-bounds along the respective vector dimension as the index increases.
-    Non-vector and broadcast dimensions *must* always be in-bounds. The
-    `in_bounds` array length has to be equal to the vector rank. This attribute
-    has a default value: `false` (i.e. "out-of-bounds"). When skipped in the
-    textual IR, the default value is assumed. Similarly, the OP printer will
-    omit this attribute when all dimensions are out-of-bounds (i.e. the default
-    value is used).
+    Non-vector dimensions *must* always be in-bounds. The `in_bounds` array
+    length has to be equal to the vector rank. This attribute has a default
+    value: `false` (i.e. "out-of-bounds"). When skipped in the textual IR, the
+    default value is assumed. Similarly, the OP printer will omit this
+    attribute when all dimensions are out-of-bounds (i.e. the default value is
+    used).
 
     A `vector.transfer_read` can be lowered to a simple load if all dimensions
     are specified to be within bounds and no `mask` was specified.
@@ -1535,12 +1535,12 @@ def Vector_TransferWriteOp :
     specifies if the transfer is guaranteed to be within the source bounds. If
     set to "false", accesses (including the starting point) may run
     out-of-bounds along the respective vector dimension as the index increases.
-    Non-vector and broadcast dimensions *must* always be in-bounds. The
-    `in_bounds` array length has to be equal to the vector rank. This attribute
-    has a default value: `false` (i.e. "out-of-bounds"). When skipped in the
-    textual IR, the default value is assumed. Similarly, the OP printer will
-    omit this attribute when all dimensions are out-of-bounds (i.e. the default
-    value is used).
+    Non-vector dimensions *must* always be in-bounds. The `in_bounds` array
+    length has to be equal to the vector rank. This attribute has a default
+    value: `false` (i.e. "out-of-bounds"). When skipped in the textual IR, the
+    default value is assumed. Similarly, the OP printer will omit this
+    attribute when all dimensions are out-of-bounds (i.e. the default value is
+    used).
 
      A `vector.transfer_write` can be lowered to a simple store if all
      dimensions are specified to be within bounds and no `mask` was specified.
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index e30950bbf292d6..f74fc9c3fe7dbd 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -146,14 +146,6 @@ class AffineMap {
   /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
   bool isMinorIdentity() const;
 
-  /// Returns the list of broadcast dimensions (i.e. dims indicated by value 0
-  /// in the result).
-  /// Ex:
-  ///  * (d0, d1, d2) -> (0, d1) gives [0]
-  ///  * (d0, d1, d2) -> (d2, d1) gives []
-  ///  * (d0, d1, d2, d4) -> (d0, 0, d1, 0) gives [1, 3]
-  SmallVector<unsigned> getBroadcastDims() const;
-
   /// Returns true if this affine map is a minor identity up to broadcasted
   /// dimensions which are indicated by value 0 in the result. If
   /// `broadcastedDims` is not null, it will be populated with the indices of
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 7ea62c2ae2ab13..be939bad14b7bf 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -234,12 +234,9 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       return constExpr && constExpr.getValue() == 0;
     }
 
-    /// Return "true" if the vector transfer dimension `dim` is in-bounds. Also
-    /// return "true" if the dimension is a broadcast dimension. Return "false"
-    /// otherwise.
+    /// Return "true" if the vector transfer dimension `dim` is in-bounds.
+    /// Return "false" otherwise.
     bool isDimInBounds(unsigned dim) {
-      if ($_op.isBroadcastDim(dim))
-        return true;
       auto inBounds = $_op.getInBounds();
       return ::llvm::cast<::mlir::BoolAttr>(inBounds[dim]).getValue();
     }
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 6bb8dfecba0ec5..71e9648a5e00fa 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1223,19 +1223,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
   LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
   LLVM_DEBUG(permutationMap.print(dbgs()));
 
-  // Make sure that the in_bounds attribute corresponding to a broadcast dim
-  // is set to `true` - that's required by the xfer Op.
-  // FIXME: We're not veryfying whether the corresponding access is in bounds.
-  // TODO: Use masking instead.
-  SmallVector<unsigned> broadcastedDims = permutationMap.getBroadcastDims();
-  SmallVector<bool> inBounds(vectorType.getRank(), false);
-
-  for (auto idx : broadcastedDims)
-    inBounds[idx] = true;
-
   auto transfer = state.builder.create<vector::TransferReadOp>(
-      loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap,
-      inBounds);
+      loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
 
   // Register replacement for future uses in the scope.
   state.registerOpVectorReplacement(loadOp, transfer);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ca85f4b9b9c156..427f2059b5137c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1380,17 +1380,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 
     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
 
-    // Make sure that the in_bounds attribute corresponding to a broadcast dim
-    // is `true`
-    SmallVector<unsigned> broadcastedDims = readMap.getBroadcastDims();
-    SmallVector<bool> inBounds(readType.getRank(), false);
-
-    for (auto idx : broadcastedDims)
-      inBounds[idx] = true;
-
     Operation *read = rewriter.create<vector::TransferReadOp>(
-        loc, readType, opOperand->get(), indices, readMap,
-        ArrayRef<bool>(inBounds));
+        loc, readType, opOperand->get(), indices, readMap);
     read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
     Value readValue = read->getResult(0);
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cac6b955457049..8b3bb7f1024880 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3947,10 +3947,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                            "as permutation_map results: ")
            << AffineMapAttr::get(permutationMap)
            << " vs inBounds of size: " << inBounds.size();
-  for (unsigned int i = 0, e = permutationMap.getNumResults(); i < e; ++i)
-    if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
-        !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
-      return op->emitOpError("requires broadcast dimensions to be in-bounds");
 
   return success();
 }
@@ -4139,17 +4135,24 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
   SmallVector<bool, 4> newInBounds;
   newInBounds.reserve(op.getTransferRank());
   for (unsigned i = 0; i < op.getTransferRank(); ++i) {
-    // Already marked as in-bounds, nothing to see here.
+    // 1. Already marked as in-bounds, nothing to see here.
     if (op.isDimInBounds(i)) {
       newInBounds.push_back(true);
       continue;
     }
-    // Currently out-of-bounds, check whether we can statically determine it is
-    // inBounds.
+    // 2. Currently out-of-bounds, check whether we can statically determine it
+    // is inBounds.
+    bool inBounds = false;
     auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
-    assert(dimExpr && "Broadcast dims must be in-bounds");
-    auto inBounds =
-        isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
+    if (dimExpr) {
+      // 2.a Non-broadcast dim
+      inBounds = isInBounds(op, /*resultIdx=*/i,
+                            /*indicesIdx=*/dimExpr.getPosition());
+    } else {
+      // 2.b Broadcast dim
+      inBounds = true;
+    }
+
     newInBounds.push_back(inBounds);
     // We commit the pattern if it is "more inbounds".
     changed |= inBounds;
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index ea3c0723b07759..7221f4943eaaf2 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -158,19 +158,6 @@ bool AffineMap::isMinorIdentity() const {
              getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
 }
 
-SmallVector<unsigned> AffineMap::getBroadcastDims() const {
-  SmallVector<unsigned> broadcastedDims;
-  for (const auto &[resIdx, expr] : llvm::enumerate(getResults())) {
-    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
-      if (constExpr.getValue() != 0)
-        continue;
-      broadcastedDims.push_back(resIdx);
-    }
-  }
-
-  return broadcastedDims;
-}
-
 /// Returns true if this affine map is a minor identity up to broadcasted
 /// dimensions which are indicated by value 0 in the result.
 bool AffineMap::isMinorIdentityWithBroadcasting(
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 30df419822994c..c55a0c558bc2f1 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -133,7 +133,7 @@ func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
     affine.for %i1 = 0 to %N {
       affine.for %i2 = 0 to %O {
         affine.for %i3 = 0 to %P step 5 {
-          %f = vector.transfer_read %A[%i0, %i1, %i2, %i3], %f0 {in_bounds = [false, true, false], permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, 0, d0)>} : memref<?x?x?x?xf32>, vector<5x4x3xf32>
+          %f = vector.transfer_read %A[%i0, %i1, %i2, %i3], %f0 {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, 0, d0)>} : memref<?x?x?x?xf32>, vector<5x4x3xf32>
           // Add a dummy use to prevent dead code elimination from removing
           // transfer read ops.
           "dummy_use"(%f) : (vector<5x4x3xf32>) -> ()
@@ -507,7 +507,7 @@ func.func @transfer_read_with_tensor(%arg: tensor<f32>) -> vector<1xf32> {
     // CHECK-NEXT: %[[RESULT:.*]] = vector.broadcast %[[EXTRACTED]] : f32 to vector<1xf32>
     // CHECK-NEXT: return %[[RESULT]] : vector<1xf32>
     %f0 = arith.constant 0.0 : f32
-    %0 = vector.transfer_read %arg[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
+    %0 = vector.transfer_read %arg[], %f0 {permutation_map = affine_map<()->(0)>} :
       tensor<f32>, vector<1xf32>
     return %0: vector<1xf32>
 }
@@ -746,7 +746,7 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
 func.func @does_not_crash_on_unpack_one_dim(%subview:  memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {in_bounds = [false, true, true, false], permutation_map = #map1}
+  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
           : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
   return %3 : vector<1x1x1x1xi32>
 }
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
index 0a077624d18f88..9244604128cb72 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
@@ -22,7 +22,7 @@ func.func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
 // CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
 // CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
-// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true], permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
+// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
    affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector
      %a0 = affine.load %A[%c0, %c0] : memref<?x?xf32>
    }
@@ -425,7 +425,7 @@ func.func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK:     %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
 // CHECK:     %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
 // CHECK:     %{{.*}} = arith.constant 0.0{{.*}}: f32
-// CHECK:     {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true], permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
+// CHECK:     {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
    affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %{{.*}} in DFS post-order prevents vectorizing %{{.*}}
      affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
        %a18 = affine.load %A[%c0, %c0] : memref<?x?xf32>
@@ -459,7 +459,7 @@ func.func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK:      %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
 // CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
 // CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
-// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true], permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
+// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
    affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %{{.*}}
      affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
        %a18 = affine.load %A[%c0, %c0] : memref<?x?xf32>
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_2d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_2d.mlir
index eb5120a49e3d4b..83916e755363ba 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_2d.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_2d.mlir
@@ -123,8 +123,8 @@ func.func @vectorize_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg
   //      VECT:  affine.for %[[I2:.*]] = #[[$map_id1]](%[[C0]]) to #[[$map_id1]](%[[M]]) step 4 {
   // VECT-NEXT:    affine.for %[[I3:.*]] = #[[$map_id1]](%[[C0]]) to #[[$map_id1]](%[[N]]) step 8 {
   // VECT-NEXT:      affine.for %[[I4:.*]] = #[[$map_id1]](%[[C0]]) to #[[$map_id1]](%[[K]]) {
-  //      VECT:        %[[A:.*]] = vector.transfer_read %{{.*}}[%[[I4]], %[[I3]]], %{{.*}} {in_bounds = [true, false], permutation_map = #[[$map_proj_d0d1_zerod1]]} : memref<?x?xf32>, vector<4x8xf32>
-  //      VECT:        %[[B:.*]] = vector.transfer_read %{{.*}}[%[[I2]], %[[I4]]], %{{.*}} {in_bounds = [false, true], permutation_map = #[[$map_proj_d0d1_d0zero]]} : memref<?x?xf32>, vector<4x8xf32>
+  //      VECT:        %[[A:.*]] = vector.transfer_read %{{.*}}[%[[I4]], %[[I3]]], %{{.*}} {permutation_map = #[[$map_proj_d0d1_zerod1]]} : memref<?x?xf32>, vector<4x8xf32>
+  //      VECT:        %[[B:.*]] = vector.transfer_read %{{.*}}[%[[I2]], %[[I4]]], %{{.*}} {permutation_map = #[[$map_proj_d0d1_d0zero]]} : memref<?x?xf32>, vector<4x8xf32>
   // VECT-NEXT:        %[[C:.*]] = arith.mulf %[[B]], %[[A]] : vector<4x8xf32>
   //      VECT:        %[[D:.*]] = vector.transfer_read %{{.*}}[%[[I2]], %[[I3]]], %{{.*}} : memref<?x?xf32>, vector<4x8xf32>
   // VECT-NEXT:        %[[E:.*]] = arith.addf %[[D]], %[[C]] : vector<4x8xf32>
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
index 16ade6455d6974..15a7133cf0f65f 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
@@ -141,7 +141,7 @@ func.func @affine_map_with_expr_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24
 // CHECK-NEXT:       %[[S1:.*]] = affine.apply #[[$MAP_ID4]](%[[ARG3]], %[[ARG4]], %[[I0]])
 // CHECK-NEXT:       %[[S2:.*]] = affine.apply #[[$MAP_ID5]](%[[ARG3]], %[[ARG4]], %[[I0]])
 // CHECK-NEXT:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT:       %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {in_bounds = [true], permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32>
+// CHECK-NEXT:       %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32>
 // CHECK-NEXT:       vector.transfer_write %[[S3]], %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]]] : vector<8xf32>, memref<8x24x48xf32>
 // CHECK-NEXT:     }
 // CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 44c15c272bb3ef..241b8a486c012e 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -200,7 +200,7 @@ func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi3
   affine.for %arg3 = 0 to 64 {
     affine.for %arg4 = 0 to 64 step 16 {
       affine.for %arg5 = 0 to 64 {
-        %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
+        %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
         %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
         %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
         %3 = arith.muli %0, %1 : vector<16xi32>
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 0e2b2458d29cdb..2464759522c0f8 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -130,7 +130,7 @@ func.func @vectorize_dynamic_1d_broadcast(%arg0: tensor<?xf32>,
 // CHECK-LABEL:   @vectorize_dynamic_1d_broadcast
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
-// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %{{.*}} {in_bounds = {{.*}}, permutation_map = #{{.*}}} : tensor<?xf32>, vector<4xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %{{.*}} {permutation_map = #{{.*}}} : tensor<?xf32>, vector<4xf32>
 // CHECK:           %[[VAL_9:.*]] = vector.create_mask %[[VAL_4]] : vector<4xi1>
 // CHECK:           %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
 // CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e2bc5ef6128e7d..9319aa0debdd12 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -505,16 +505,6 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 
 // -----
 
-func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
-  %c3 = arith.constant 3 : index
-  %f0 = arith.constant 0.0 : f32
-  %vf0 = vector.splat %f0 : vector<2x3xf32>
-  // expected-error at +1 {{requires broadcast dimensions to be in-bounds}}
-  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [false, true], permutation_map = affine_map<(d0, d1)->(0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
-}
-
-// -----
-
 func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 08d1a189231bcc..3baacba9b61243 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -70,7 +70,7 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<5xf32>
   %8 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref<?x?xf32>, vector<5xf32>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?x?xf32>, vector<5x4x8xf32>
-  %9 = vector.transfer_read %arg4[%c3, %c3, %c3], %f0, %m2 {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2)->(d1, d0, 0)>} : memref<?x?x?xf32>, vector<5x4x8xf32>
+  %9 = vector.transfer_read %arg4[%c3, %c3, %c3], %f0, %m2 {permutation_map = affine_map<(d0, d1, d2)->(d1, d0, 0)>} : memref<?x?x?xf32>, vector<5x4x8xf32>
 
   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 15000d706adfc8..0feaf690af2510 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -327,13 +327,11 @@ func.func @masked_permutation_xfer_read_fixed_width(
   %c0 = arith.constant 0 : index
   %3 = vector.mask %mask {
     vector.transfer_read %dest[%c0, %c0], %cst {
-      in_bounds = [false, true, false],
       permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>
     } : tensor<?x1xf32>, vector<1x4x4xf32>
   } : vector<4x1xi1> -> vector<1x4x4xf32>
 
   "test.some_use"(%3) : (vector<1x4x4xf32>) -> ()
-
   return
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 75c5ad26fcf231..9be1a95d0aa01c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -207,7 +207,7 @@ func.func @transfer_read_unroll_permutation(%mem : memref<6x4xf32>) -> vector<4x
 func.func @transfer_read_unroll_broadcast(%mem : memref<6x4xf32>) -> vector<6x4xf32> {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%c0, %c0], %cf0 {in_bounds = [true, false], permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32>
+  %res = vector.transfer_read %mem[%c0, %c0], %cf0 permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32>
   return %res : vector<6x4xf32>
 }
 
@@ -234,7 +234,7 @@ func.func @transfer_read_unroll_broadcast(%mem : memref<6x4xf32>) -> vector<6x4x
 func.func @transfer_read_unroll_broadcast_permuation(%mem : memref<6x4xf32>) -> vector<4x6xf32> {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%c0, %c0], %cf0 {in_bounds = [true, false], permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
+  %res = vector.transfer_read %mem[%c0, %c0], %cf0 permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
   return %res : vector<4x6xf32>
 }
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir
index 12b0511d486ea0..8a98d39e657f2c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir
@@ -82,7 +82,7 @@ func.func @transfer_read_1d_broadcast(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = arith.constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
-      {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (0)>}
+      {permutation_map = affine_map<(d0, d1) -> (0)>}
       : memref<?x?xf32>, vector<9xf32>
   vector.print %f: vector<9xf32>
   return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir
index 9f8849fa9a1489..cb8a8ce8ab0b0e 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-2d.mlir
@@ -57,7 +57,7 @@ func.func @transfer_read_2d_mask_broadcast(
   %fm42 = arith.constant -42.0: f32
   %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
   %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
-      {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
+      {permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
     memref<?x?xf32>, vector<4x9xf32>
   vector.print %f: vector<4x9xf32>
   return
@@ -69,7 +69,7 @@ func.func @transfer_read_2d_mask_transpose_broadcast_last_dim(
   %fm42 = arith.constant -42.0: f32
   %mask = arith.constant dense<[1, 0, 1, 1]> : vector<4xi1>
   %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
-      {in_bounds = [false, true], permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
+      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
     memref<?x?xf32>, vector<4x9xf32>
   vector.print %f: vector<4x9xf32>
   return
@@ -91,7 +91,7 @@ func.func @transfer_read_2d_broadcast(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = arith.constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
-      {in_bounds = [false, true], permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
+      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
     memref<?x?xf32>, vector<4x9xf32>
   vector.print %f: vector<4x9xf32>
   return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-3d.mlir
index 466afeec459b43..4aecca3d6891eb 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-3d.mlir
@@ -32,7 +32,7 @@ func.func @transfer_read_3d_broadcast(%A : memref<?x?x?x?xf32>,
                                  %o: index, %a: index, %b: index, %c: index) {
   %fm42 = arith.constant -42.0: f32
   %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42
-      {in_bounds = [false, true, false], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>}
+      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>}
       : memref<?x?x?x?xf32>, vector<2x5x3xf32>
   vector.print %f: vector<2x5x3xf32>
   return
@@ -43,7 +43,7 @@ func.func @transfer_read_3d_mask_broadcast(
   %fm42 = arith.constant -42.0: f32
   %mask = arith.constant dense<[0, 1]> : vector<2xi1>
   %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42, %mask
-      {in_bounds = [false, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>}
+      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>}
       : memref<?x?x?x?xf32>, vector<2x5x3xf32>
   vector.print %f: vector<2x5x3xf32>
   return

>From cfec2c78894dfd008d999b5fe798dcf4b0d2a5b1 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 23 Jul 2024 20:13:33 +0100
Subject: [PATCH 2/2] Refine how bcast dims are handled

Only mark bcast dims as "in bounds" when all non-bcast dims are "in
bounds".
---
 mlir/include/mlir/IR/AffineMap.h              |  8 ++++++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 19 ++++++++++++++++---
 mlir/lib/IR/AffineMap.cpp                     | 13 +++++++++++++
 .../Vector/vector-transfer-unroll.mlir        |  4 ++--
 4 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index f74fc9c3fe7dbd..e30950bbf292d6 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -146,6 +146,14 @@ class AffineMap {
   /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
   bool isMinorIdentity() const;
 
+  /// Returns the list of broadcast dimensions (i.e. dims indicated by value 0
+  /// in the result).
+  /// Ex:
+  ///  * (d0, d1, d2) -> (0, d1) gives [0]
+  ///  * (d0, d1, d2) -> (d2, d1) gives []
+  ///  * (d0, d1, d2, d4) -> (d0, 0, d1, 0) gives [1, 3]
+  SmallVector<unsigned> getBroadcastDims() const;
+
   /// Returns true if this affine map is a minor identity up to broadcasted
   /// dimensions which are indicated by value 0 in the result. If
   /// `broadcastedDims` is not null, it will be populated with the indices of
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8b3bb7f1024880..cfb897039b7eff 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4134,6 +4134,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
   bool changed = false;
   SmallVector<bool, 4> newInBounds;
   newInBounds.reserve(op.getTransferRank());
+  SmallVector<unsigned> nonBcastDims;
   for (unsigned i = 0; i < op.getTransferRank(); ++i) {
     // 1. Already marked as in-bounds, nothing to see here.
     if (op.isDimInBounds(i)) {
@@ -4148,15 +4149,27 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
       // 2.a Non-broadcast dim
       inBounds = isInBounds(op, /*resultIdx=*/i,
                             /*indicesIdx=*/dimExpr.getPosition());
-    } else {
-      // 2.b Broadcast dim
-      inBounds = true;
+      // 2.b Broadcast dims are handled after processing non-bcast dims
+      // FIXME: constant expr != 0 are not broadcasts - should such
+      // constants be allowed at all?
+      nonBcastDims.push_back(i);
     }
 
     newInBounds.push_back(inBounds);
     // We commit the pattern if it is "more inbounds".
     changed |= inBounds;
   }
+
+  // Handle broadcast dims: if all non-broadcast dims are "in
+  // bounds", then all bcast dims should be "in bounds" as well.
+  bool allNonBcastDimsInBounds = llvm::all_of(
+      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
+  if (allNonBcastDimsInBounds)
+    llvm::for_each(permutationMap.getBroadcastDims(), [&](unsigned idx) {
+      changed |= !newInBounds[idx];
+      newInBounds[idx] = true;
+    });
+
   if (!changed)
     return failure();
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 7221f4943eaaf2..ea3c0723b07759 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -158,6 +158,19 @@ bool AffineMap::isMinorIdentity() const {
              getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
 }
 
+SmallVector<unsigned> AffineMap::getBroadcastDims() const {
+  SmallVector<unsigned> broadcastedDims;
+  for (const auto &[resIdx, expr] : llvm::enumerate(getResults())) {
+    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
+      if (constExpr.getValue() != 0)
+        continue;
+      broadcastedDims.push_back(resIdx);
+    }
+  }
+
+  return broadcastedDims;
+}
+
 /// Returns true if this affine map is a minor identity up to broadcasted
 /// dimensions which are indicated by value 0 in the result.
 bool AffineMap::isMinorIdentityWithBroadcasting(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 9be1a95d0aa01c..5dd65ea132d080 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -207,7 +207,7 @@ func.func @transfer_read_unroll_permutation(%mem : memref<6x4xf32>) -> vector<4x
 func.func @transfer_read_unroll_broadcast(%mem : memref<6x4xf32>) -> vector<6x4xf32> {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%c0, %c0], %cf0 permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32>
+  %res = vector.transfer_read %mem[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32>
   return %res : vector<6x4xf32>
 }
 
@@ -234,7 +234,7 @@ func.func @transfer_read_unroll_broadcast(%mem : memref<6x4xf32>) -> vector<6x4x
 func.func @transfer_read_unroll_broadcast_permuation(%mem : memref<6x4xf32>) -> vector<4x6xf32> {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
-  %res = vector.transfer_read %mem[%c0, %c0], %cf0 permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
+  %res = vector.transfer_read %mem[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
   return %res : vector<4x6xf32>
 }
 


