[Mlir-commits] [mlir] 847048f - [mlir][Vector] Fix bug in vector xfer op flattening transformation (#81964)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 22 12:37:36 PST 2024


Author: Diego Caballero
Date: 2024-02-22T12:37:32-08:00
New Revision: 847048f497bcdfcfe52f36cba49f07bdbd63cd24

URL: https://github.com/llvm/llvm-project/commit/847048f497bcdfcfe52f36cba49f07bdbd63cd24
DIFF: https://github.com/llvm/llvm-project/commit/847048f497bcdfcfe52f36cba49f07bdbd63cd24.diff

LOG: [mlir][Vector] Fix bug in vector xfer op flattening transformation (#81964)

The affine map generated to compute the index into the collapsed
dimensions used the wrong dim size. For indices `[idx0][idx1]`
we computed the collapsed index as `idx0*size0 + idx1*size1` instead of
`idx0*size1 + idx1` (in general, each index must be scaled by the product
of the sizes of the collapsed dimensions that follow it). This led to
correctness issues in convolution tests when enabling this transformation
internally.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/lib/Dialect/Utils/IndexingUtils.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 2453d841f633e4..9892253df2bff1 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -257,6 +257,9 @@ SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
 std::pair<AffineExpr, SmallVector<OpFoldResult>>
 computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
                    ArrayRef<OpFoldResult> indices);
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
+                   ArrayRef<Value> indices);
 
 //===----------------------------------------------------------------------===//
 // Utilities for decomposing larger shapes

diff  --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index baaa581ab6f225..4c960659d80cb7 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -7,13 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/IndexingUtils.h"
-
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "llvm/ADT/STLExtras.h"
-
 #include <numeric>
 #include <optional>
 
@@ -306,6 +305,14 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
   return {expr, values};
 }
 
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
+                         ArrayRef<Value> indices) {
+  return computeLinearIndex(
+      sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides),
+      getAsOpFoldResult(ValueRange(indices)));
+}
+
 //===----------------------------------------------------------------------===//
 // TileOffsetRange
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 04e5a816dd91e6..0ffef6aabccc18 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -577,7 +578,6 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
 
-    SmallVector<Value> collapsedIndices;
     int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
 
     // 1. Collapse the source memref
@@ -599,12 +599,14 @@ class FlattenContiguousRowMajorTransferReadPattern
     // 2.2 New indices
     // If all the collapsed indices are zero then no extra logic is needed.
     // Otherwise, a new offset/index has to be computed.
+    SmallVector<Value> collapsedIndices;
     if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                 firstDimToCollapse,
                                                 collapsedIndices))) {
-      // Copy all the leading indices
-      collapsedIndices = transferReadOp.getIndices();
-      collapsedIndices.resize(firstDimToCollapse);
+      // Copy all the leading indices.
+      SmallVector<Value> indices = transferReadOp.getIndices();
+      collapsedIndices.append(indices.begin(),
+                              indices.begin() + firstDimToCollapse);
 
       // Compute the remaining trailing index/offset required for reading from
       // the collapsed memref:
@@ -621,24 +623,26 @@ class FlattenContiguousRowMajorTransferReadPattern
       //      memref<1x86xi32>, vector<2xi32>
       // one would get the following offset:
       //    %offset = %arg0 * 43
-      AffineExpr offsetExpr, idxExpr;
-      bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
-
-      int64_t outputRank = transferReadOp.getIndices().size();
-      OpFoldResult offset =
+      OpFoldResult collapsedOffset =
           rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
 
-      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
-        int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
-        offset = affine::makeComposedFoldedAffineApply(
-            rewriter, loc, offsetExpr + dim * idxExpr,
-            {offset, transferReadOp.getIndices()[i]});
-      }
-      if (offset.is<Value>()) {
-        collapsedIndices.push_back(offset.get<Value>());
+      auto sourceShape = sourceType.getShape();
+      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
+          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
+
+      // Compute the collapsed offset.
+      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
+                                        indices.end());
+      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
+          collapsedOffset, collapsedStrides, indicesToCollapse);
+      collapsedOffset = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, collapsedExpr, collapsedVals);
+
+      if (collapsedOffset.is<Value>()) {
+        collapsedIndices.push_back(collapsedOffset.get<Value>());
       } else {
         collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
-            loc, *getConstantIntValue(offset)));
+            loc, *getConstantIntValue(collapsedOffset)));
       }
     }
 
@@ -710,6 +714,7 @@ class FlattenContiguousRowMajorTransferWritePattern
                                                 firstContiguousInnerDim,
                                                 collapsedIndices)))
       return failure();
+
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
     MemRefType collapsedSourceType =

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 1775b5fa4a346a..3b6441d0c9560c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -83,7 +83,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
   return
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
@@ -92,7 +92,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
 // CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
-// CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]]
+// CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
 // CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
 // CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
@@ -459,3 +459,31 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK-128B-LABEL: func @fold_unit_dims_entirely(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
+
+// -----
+
+func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+                                              %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_1 = arith.constant 0.000000e+00 : f32
+  %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+  return %8 : vector<2x2xf32>
+}
+
+//       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL:    func.func @regression_non_contiguous_dim_read(
+//       CHECK:      %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+//       CHECK:     %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+
+// -----
+
+func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+                                                %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+                                                %idx0 : index, %idx1 : index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL:  func.func @unsupported_non_contiguous_dim_write(
+//   CHECK-NOT:    memref.collapse_shape


        


More information about the Mlir-commits mailing list