[Mlir-commits] [mlir] e87aa0c - [mlir][vector] Sink vector.extract/splat into load/store ops (#134389)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 22 07:18:58 PDT 2025


Author: Ivan Butygin
Date: 2025-04-22T17:18:54+03:00
New Revision: e87aa0c6ab7b9d1abbf86e8df84053cd4de92656

URL: https://github.com/llvm/llvm-project/commit/e87aa0c6ab7b9d1abbf86e8df84053cd4de92656
DIFF: https://github.com/llvm/llvm-project/commit/e87aa0c6ab7b9d1abbf86e8df84053cd4de92656.diff

LOG: [mlir][vector] Sink vector.extract/splat into load/store ops (#134389)

```
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
vector.extract %0[1] : f32 from vector<4xf32>
```
Gets converted to:
```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```

```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```
Gets converted to:
```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/Utils/Utils.h
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Arith/Utils/Utils.cpp
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-sink-transform.mlir
    mlir/test/Dialect/Vector/vector-sink.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index d759299cbf762..c0b286494996b 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -101,7 +101,10 @@ Type getType(OpFoldResult ofr);
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
-  ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
+  ArithBuilder(
+      OpBuilder &b, Location loc,
+      arith::IntegerOverflowFlags ovf = arith::IntegerOverflowFlags::none)
+      : b(b), loc(loc), ovf(ovf) {}
 
   Value _and(Value lhs, Value rhs);
   Value add(Value lhs, Value rhs);
@@ -114,6 +117,15 @@ struct ArithBuilder {
 private:
   OpBuilder &b;
   Location loc;
+  arith::IntegerOverflowFlags ovf;
+};
+
+/// ArithBuilder specialized for tensor/memref indexing calculations. Those
+/// calculations should generally never overflow as signed values and always
+/// use signed integers, so we can set the overflow flags accordingly.
+struct ArithIndexingBuilder : public ArithBuilder {
+  ArithIndexingBuilder(OpBuilder &b, Location loc)
+      : ArithBuilder(b, loc, arith::IntegerOverflowFlags::nsw) {}
 };
 
 namespace arith {

diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f46aa0428f12f..14cbbac99d9ae 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -458,7 +458,9 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
     Patterns that remove redundant Vector Ops by re-ordering them with
-    e.g. elementwise Ops:
+    e.g. elementwise Ops.
+
+    Example:
     ```
     %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
     %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
@@ -469,8 +471,32 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     %0 = arith.addf %a, %b : vector<4x2xf32>
     %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
     ```
-    At the moment, these patterns are limited to vector.broadcast and
-    vector.transpose.
+    At the moment, these patterns are limited to vector.broadcast,
+    vector.transpose and vector.extract.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.sink_mem_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Patterns that fold redundant Vector Ops into an adjacent
+    `vector.load`/`vector.store`, producing either `vector.load`/`vector.store`
+    or `memref.load`/`memref.store`. Stores are limited to 1-element vectors.
+
+    Example:
+    ```
+    vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+    vector.extract %0[1] : f32 from vector<4xf32>
+    ```
+    Gets converted to:
+    ```
+    %c1 = arith.constant 1 : index
+    %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+    %1 = memref.load %arg0[%0] : memref<?xf32>
+    ```
   }];
 
   let assemblyFormat = "attr-dict";

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ce97847172197..7a079dcc6affc 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -161,6 +161,20 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                    PatternBenefit benefit = 1);
 
+/// Patterns that merge redundant Vector Ops into adjacent load/store ops.
+/// ```
+/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
 /// cheaper than vector reduction.

diff  --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 8dde9866b22b3..6b1074e454bd5 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -315,17 +315,17 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
 Value ArithBuilder::add(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::AddFOp>(loc, lhs, rhs);
-  return b.create<arith::AddIOp>(loc, lhs, rhs);
+  return b.create<arith::AddIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::sub(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::SubFOp>(loc, lhs, rhs);
-  return b.create<arith::SubIOp>(loc, lhs, rhs);
+  return b.create<arith::SubIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::mul(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))
     return b.create<arith::MulFOp>(loc, lhs, rhs);
-  return b.create<arith::MulIOp>(loc, lhs, rhs);
+  return b.create<arith::MulIOp>(loc, lhs, rhs, ovf);
 }
 Value ArithBuilder::sgt(Value lhs, Value rhs) {
   if (isa<FloatType>(lhs.getType()))

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 12dcf768dd928..a888d745be443 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
   vector::populateSinkVectorOpsPatterns(patterns);
 }
 
+void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateSinkVectorMemOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 89839d0440d3c..b94c5fce64f83 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 };
 
 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+///
+/// Example:
 /// ```
 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
 /// This may result in cleaner code when extracting a single value
 /// from multi-element vector and also to help canonicalize 1-element vectors to
 /// scalars.
+///
+/// Example:
 /// ```
 ///  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
 ///  %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1043,6 +1047,150 @@ class ExtractOpFromElementwise final
   }
 };
 
+/// Check if the element type is suitable for vector.load/store sinking.
+/// Element type must be index or byte-aligned integer or floating-point type.
+static bool isSupportedMemSinkElementType(Type type) {
+  if (isa<IndexType>(type))
+    return true;
+
+  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
+}
+
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// Only index and byte-aligned integer and floating-point element types are
+/// supported for now.
+///
+/// Example:
+/// ```
+///  vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+///  vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "expected a load op");
+
+    // Checking for single use so we won't duplicate load ops.
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType loadVecType = loadOp.getVectorType();
+    if (loadVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+
+    // Non-byte-aligned types are tricky and may require special handling,
+    // ignore them for now.
+    if (!isSupportedMemSinkElementType(memType.getElementType()))
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+
+    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (extractVecType)
+      finalRank = extractVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    // There may be memory stores between the load and the extract op, so we
+    // need to make sure that the new load op is inserted at the same place as
+    // the original load op.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    ArithIndexingBuilder idxBuilderf(rewriter, loc);
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+      indices[i] = idxBuilderf.add(indices[i], offset);
+    }
+
+    Value base = loadOp.getBase();
+    if (extractVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    // We checked for single use so we can safely erase the load op.
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+///
+/// Example:
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreOpFromSplatOrBroadcast final
+    : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element vectors are supported");
+
+    Operation *splat = op.getValueToStore().getDefiningOp();
+    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+
+    // Checking for single use so we can remove splat.
+    if (!splat->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    Value source = splat->getOperand(0);
+    Value base = op.getBase();
+    ValueRange indices = op.getIndices();
+
+    if (isa<VectorType>(source.getType())) {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+    }
+    rewriter.eraseOp(splat);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2109,6 +2257,13 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  // TODO: Consider converting these patterns to canonicalizations.
+  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
+      patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);

diff  --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
index ef17b69b2444c..4d04276742164 100644
--- a/mlir/test/Dialect/Vector/vector-sink-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.sink_ops
+      transform.apply_patterns.vector.sink_mem_ops
     } : !transform.any_op
     transform.yield
   }

diff  --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 8c8f1797aaab6..900ad99bb4a4c 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -513,3 +513,203 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
   %1 = vector.extract %0[1] : f32 from vector<4xf32>
   return %1 : f32
 }
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_index
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xindex>, %[[ARG1:.*]]: index)
+func.func @extract_load_index(%arg0: memref<?xindex>, %arg1: index) -> index {
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xindex>
+// CHECK:   return %[[RES]] : index
+  %0 = vector.load %arg0[%arg1] : memref<?xindex>, vector<4xindex>
+  %1 = vector.extract %0[0] : index from vector<4xindex>
+  return %1 : index
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_non_zero_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec_non_zero_off(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off_2d_src_memref
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_non_zero_off_2d_src_memref(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_vec
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalar_from_memref_of_vec(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_scalar_from_memref_of_i1
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalar_from_memref_of_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
+// Subbyte types are tricky, ignore them for now.
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
+// CHECK:   return %[[EXT]] : i1
+  %0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
+  %1 = vector.extract %0[0] : i1 from vector<8xi1>
+  return %1 : i1
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_load_scalable
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+  return %1 : f32
+}
+
+//-----------------------------------------------------------------------------
+// [Pattern: StoreOpFromSplatOrBroadcast]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @store_splat
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast_1d_to_2d
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
+func.func @store_broadcast_1d_to_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+// CHECK:   vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
+  %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
+  vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_scalable
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+  %0 = vector.splat %arg2 : vector<[1]xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_memref_of_vec
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_memref_of_vec(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_more_than_one_element
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_more_than_one_element(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+  %0 = vector.splat %arg2 : vector<4xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
+// CHECK:   return %[[RES]] : vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return %0 : vector<1xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..03f907e46c2c6 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -395,6 +395,7 @@ struct TestVectorSinkPatterns
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateSinkVectorOpsPatterns(patterns);
+    populateSinkVectorMemOpsPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };


        


More information about the Mlir-commits mailing list