[Mlir-commits] [mlir] [mlir][vector] Fix off-by-one error in `getTransferChunkAccessed` (PR #70292)

Matthias Springer llvmlistbot at llvm.org
Fri Oct 27 01:36:00 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/70292

From 52f0cd36785d881dd03e32292b908afd750add89 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 27 Oct 2023 17:35:00 +0900
Subject: [PATCH] [mlir][vector] Fix off-by-one error in
 `getTransferChunkAccessed`

If a dimension does not appear in the permutation map of a vector transfer op, the size of the accessed slice in that dimension is `1`. Before this fix, `getTransferChunkAccessed` returned `0` for such dimensions, which would mean that `0` elements of the underlying tensor/memref are accessed.
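
For illustration, a minimal sketch of the fixed behavior (not part of the patch; `%t`, `%i`, `%j`, `%k`, and `%pad` are assumed SSA values):
```
// d2 does not appear in the permutation map, so the transfer accesses a
// slice of size 1 along that dimension of the tensor. d1 maps to vector
// dim 0 (size 4) and d0 maps to vector dim 1 (size 2).
%r = vector.transfer_read %t[%i, %j, %k], %pad
    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d0)>} :
    tensor<?x?x?xf32>, vector<4x2xf32>
```
With this fix, `getTransferChunkAccessed` returns `[2, 4, 1]` for such an op; previously it returned `[2, 4, 0]`.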
---
 .../mlir/Interfaces/VectorInterfaces.td       | 14 ++++----
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 33 ++++++++++---------
 mlir/test/Dialect/Vector/canonicalize.mlir    | 24 ++++++++++++++
 3 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 346a409a3f3e0ef..026faf269f368de 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -257,22 +257,22 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-      Return an upper-bound shape accessed by the transfer op within the
-      tensor/memref operand.
+      Return the shape of the hyperrectangular slice within the tensor/memref
+      operand that is accessed by the transfer op.
       For example:
       ```
-        vector.transfer %w0[%i, %j] {
-          permutation_map = affine_map<(d0, d1) -> (d1, d0, 0)>} :
-          tensor<?x?xf32>, vector<4x2x6xf32>
+        vector.transfer %w0[%i, %j, %k] {
+          permutation_map = affine_map<(d0, d1, d2) -> (d1, d0, 0)>} :
+          tensor<?x?x?xf32>, vector<4x2x6xf32>
       ```
-      returns a shape [2, 4].
+      returns a shape [2, 4, 1].
       }],
       /*retTy=*/"SmallVector<int64_t>",
       /*methodName=*/"getTransferChunkAccessed",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        SmallVector<int64_t> dimSizes($_op.getPermutationMap().getNumDims(), 0);
+        SmallVector<int64_t> dimSizes($_op.getPermutationMap().getNumDims(), 1);
         for (auto vecDims : llvm::zip($_op.getPermutationMap().getResults(),
                                       $_op.getVectorType().getShape())) {
           AffineExpr dim = std::get<0>(vecDims);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d77476c10908395..f7b15f98e166543 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4004,35 +4004,36 @@ struct TransferReadAfterWriteToBroadcast
     auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
-
-    SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
-    Value vec;
-    if (readOp.getIndices() == defWrite.getIndices() &&
-        readOp.getMask() == defWrite.getMask()) {
-      SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
-      // TODO: If the writeDim is a superset of the read dims we could do an
-      // extract_strided_slice.
-      if (writeDims == readDims)
-        vec = defWrite.getVector();
-    }
+    // TODO: If the written transfer chunk is a superset of the read transfer
+    // chunk we could do an extract_strided_slice.
+    if (readOp.getTransferChunkAccessed() !=
+        defWrite.getTransferChunkAccessed())
+      return failure();
+    // TODO: Support cases where a dim is explicitly written but implicitly
+    // read (i.e., a unit dim that is rank reduced).
+    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
+        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
+      return failure();
+    if (readOp.getIndices() != defWrite.getIndices() ||
+        readOp.getMask() != defWrite.getMask())
+      return failure();
+    Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
-    if (!vec)
-      return failure();
-    SmallVector<unsigned> permutation;
     AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
     AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
-    // Calculate the permuation to apply to go from the vector stored to the
+    // Calculate the permutation to apply to go from the vector stored to the
     // vector read.
+    SmallVector<unsigned> permutation;
     if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
       return failure();
 
     Location loc = readOp.getLoc();
-    // Calculate the broadcast shape by applying the reverse permuation to the
+    // Calculate the broadcast shape by applying the reverse permutation to the
     // final shape we want.
     ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
     SmallVector<int64_t> broadcastShape(destShape.size());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index dd2c78eb44e9f9e..d866c14fcbf2543 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2400,3 +2400,27 @@ func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1>{
   %2 = vector.shape_cast %1 : vector<4x1x1xi1> to vector<4xi1>
   return %2 : vector<4xi1>
 }
+
+// -----
+
+// TODO: This IR could be canonicalized but the canonicalization pattern is not
+// smart enough. For now, just make sure that we do not crash.
+
+// CHECK-LABEL: func.func @load_store_forwarding_rank_mismatch(
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: tensor<4x4x4xf32>) -> (vector<1x100x4x5xf32>) {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  // d0 is explicitly written.
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
+      {in_bounds = [true, true, true],
+      permutation_map = affine_map<(d0, d1, d2) -> (d2, d1, d0)>} :
+      vector<4x1x1xf32>, tensor<4x4x4xf32>
+  // d0 is implicitly read (rank-reduction of unit dim).
+  %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
+      {in_bounds = [true, true, true, true],
+      permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
+      tensor<4x4x4xf32>, vector<1x100x4x5xf32>
+  return %r : vector<1x100x4x5xf32>
+}
\ No newline at end of file


