[Mlir-commits] [mlir] e01ff82 - [mlir][vector] Fix scalability issues in drop innermost unit dims transfer patterns (#92402)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 17 09:14:03 PDT 2024


Author: Benjamin Maxwell
Date: 2024-05-17T17:13:59+01:00
New Revision: e01ff8238cf62c7149de7b8046bccec9adefbe67

URL: https://github.com/llvm/llvm-project/commit/e01ff8238cf62c7149de7b8046bccec9adefbe67
DIFF: https://github.com/llvm/llvm-project/commit/e01ff8238cf62c7149de7b8046bccec9adefbe67.diff

LOG: [mlir][vector] Fix scalability issues in drop innermost unit dims transfer patterns (#92402)

Previously, these rewrites would drop scalable dimensions and treat
`[1]` (a scalable dimension of size one) as a unit dimension. This patch
propagates scalable dimensions and ensures `[1]` is not treated as a unit
dimension.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 69c497264fd1e..f29eba90c3ceb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1237,6 +1237,10 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
     return failure();
 
+  auto isUnitDim = [](VectorType type, int dim) {
+    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
+  };
+
   // According to vector.transfer_read/write semantics, the vector can be a
   // slice. Thus, we have to offset the check index with `rankDiff` in
   // `srcStrides` and source dim sizes.
@@ -1247,8 +1251,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
     // It can be folded only if they are 1 and the stride is 1.
     int dim = vectorType.getRank() - i - 1;
     if (srcStrides[dim + rankDiff] != 1 ||
-        srcType.getDimSize(dim + rankDiff) != 1 ||
-        vectorType.getDimSize(dim) != 1)
+        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
       break;
     result++;
   }
@@ -1292,7 +1295,8 @@ class DropInnerMostUnitDimsTransferRead
 
     auto resultTargetVecType =
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
-                        targetType.getElementType());
+                        targetType.getElementType(),
+                        targetType.getScalableDims().drop_back(dimsToDrop));
 
     auto loc = readOp.getLoc();
     SmallVector<OpFoldResult> sizes =
@@ -1378,7 +1382,8 @@ class DropInnerMostUnitDimsTransferWrite
 
     auto resultTargetVecType =
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
-                        targetType.getElementType());
+                        targetType.getElementType(),
+                        targetType.getScalableDims().drop_back(dimsToDrop));
 
     Location loc = writeOp.getLoc();
     SmallVector<OpFoldResult> sizes =

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 477755b66c020..b4cb640108bae 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -174,3 +174,59 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
 // The inner most unit dims can not be dropped if the strides are not ones.
 // CHECK:     func.func @non_unit_strides
 // CHECK-NOT:   memref.subview
+
+// -----
+
+func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>) -> vector<[4]x1xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
+  return %0 : vector<[4]x1xf32>
+}
+// CHECK:      func.func @leading_scalable_dimension_transfer_read
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
+// CHECK:        %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32>
+// CHECK:        return %[[CAST]]
+
+// -----
+
+// Negative test: [1] (scalable 1) is _not_ a unit dimension.
+func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>) -> vector<4x[1]xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
+  return %0 : vector<4x[1]xf32>
+}
+// CHECK:      func.func @trailing_scalable_one_dim_transfer_read
+// CHECK-NOT:    vector.shape_cast
+// CHECK:        vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32>
+// CHECK-NOT:    vector.shape_cast
+
+// -----
+
+func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>,  memref<24x1xf32>
+  return
+}
+// CHECK:      func.func @leading_scalable_dimension_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
+
+// -----
+
+// Negative test: [1] (scalable 1) is _not_ a unit dimension.
+func.func @trailing_scalable_one_dim_transfer_write(%dest : memref<24x1xf32>, %vec: vector<4x[1]xf32>, %index: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<4x[1]xf32>,  memref<24x1xf32>
+  return
+}
+// CHECK:      func.func @trailing_scalable_one_dim_transfer_write
+// CHECK-NOT:    vector.shape_cast
+// CHECK:        vector.transfer_write {{.*}} : vector<4x[1]xf32>,  memref<24x1xf32>
+// CHECK-NOT:    vector.shape_cast


        


More information about the Mlir-commits mailing list