[Mlir-commits] [mlir] [mlir][Vector] Add load, store, etc. to dropleadunitdim (PR #195686)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 4 09:38:30 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Krzysztof Drewniak (krzysz00)
<details>
<summary>Changes</summary>
Discussions about improvements to fold-memref-alias-ops revealed that the patterns that drop leading unit dimensions from vector operations did not handle vector.load, vector.store, and the other "terminal" vector dialect memory operations (maskedload/maskedstore, expandload/compressstore, gather/scatter). This PR adds patterns for those operations.
Assisted-by: Claude 4.7
---
Full diff: https://github.com/llvm/llvm-project/pull/195686.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+105-1)
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+95)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index fa95f96b88177..26a702ef0f512 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -537,6 +537,101 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
return success();
}
};
+} // namespace
+
+// Drops the first `nDropped` dimensions of the vector-typed `operand`.
+//
+// When every dropped dimension is a non-scalable unit dimension this is the
+// cheap, structural rewrite `vector.extract %operand[0, ..., 0]`. Otherwise a
+// `vector.shape_cast` to the trimmed type is emitted as a defensive fallback.
+// Note that a shape_cast cannot discard elements, so that fallback only
+// verifies when the dropped dimensions hold exactly one element in total.
+//
+// `nDropped` must be smaller than the rank of `operand`: dropping every
+// dimension would make `vector.extract` produce a scalar, not a vector.
+static Value dropLeadingOneDimsFromOperand(OpBuilder &b, Location loc,
+                                           Value operand, int64_t nDropped) {
+  auto oldType = cast<VectorType>(operand.getType());
+  assert(nDropped >= 0 && nDropped < oldType.getRank() &&
+         "must leave at least one vector dimension");
+  ArrayRef<int64_t> leadingShape = oldType.getShape().take_front(nDropped);
+  ArrayRef<bool> leadingScalable =
+      oldType.getScalableDims().take_front(nDropped);
+  bool extractable =
+      llvm::all_of(leadingShape, [](int64_t d) { return d == 1; }) &&
+      !llvm::is_contained(leadingScalable, true);
+  if (extractable)
+    return vector::ExtractOp::create(b, loc, operand, splatZero(nDropped));
+
+  VectorType newType = VectorType::get(
+      oldType.getShape().drop_front(nDropped), oldType.getElementType(),
+      oldType.getScalableDims().drop_front(nDropped));
+  return vector::ShapeCastOp::create(b, loc, newType, operand);
+}
+
+namespace {
+
+// Drops leading unit dimensions from load-like memory operations (e.g.
+// vector.load, vector.maskedload, vector.expandload, vector.gather).
+//
+// The op is recreated with the leading unit dimensions removed from its
+// result type, and a vector.broadcast restores the original result type for
+// existing users. Every vector-typed operand (masks, pass-through values,
+// gather indices) has the same number of leading dimensions removed via
+// dropLeadingOneDimsFromOperand. Non-vector operands (the base and scalar
+// indices) are forwarded unchanged.
+//
+// Ops whose base has a vector element type (e.g. `memref<4xvector<1x8xf32>>`)
+// are skipped: vector.load requires its result type to equal that element
+// type, so a trimmed op would not verify.
+template <typename OpTy>
+struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    VectorType oldResultType = op.getVectorType();
+    VectorType newResultType = trimLeadingOneDims(oldResultType);
+    if (newResultType == oldResultType)
+      return failure();
+
+    // Loading from a memref of vectors must produce exactly the element type,
+    // so the result shape cannot be changed.
+    Type baseType = op.getBase().getType();
+    if (auto shapedBase = dyn_cast<ShapedType>(baseType);
+        shapedBase && isa<VectorType>(shapedBase.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "base element type is a vector; cannot change result shape");
+
+    int64_t nDropped = oldResultType.getRank() - newResultType.getRank();
+
+    Location loc = op.getLoc();
+    SmallVector<Value> newOperands;
+    newOperands.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands()) {
+      // Vector operands are expected to carry the result's leading unit dims;
+      // everything else is passed through untouched.
+      if (isa<VectorType>(operand.getType()))
+        newOperands.push_back(
+            dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped));
+      else
+        newOperands.push_back(operand);
+    }
+
+    Operation *newOp =
+        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
+                        TypeRange{newResultType}, op->getAttrs());
+    // Re-add the dropped unit dims so users still see the original type.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, oldResultType,
+                                                     newOp->getResult(0));
+    return success();
+  }
+};
+
+// Drops leading unit dimensions from store-like memory operations (e.g.
+// vector.store, vector.maskedstore, vector.compressstore, vector.scatter).
+//
+// Every vector-typed operand (the stored value, masks, scatter indices) has
+// its leading unit dimensions removed via dropLeadingOneDimsFromOperand.
+// Non-vector operands (the base and scalar indices) are forwarded unchanged,
+// and any results of the op keep their original types.
+//
+// Ops whose base has a vector element type (e.g. `memref<4xvector<1x8xf32>>`)
+// are skipped: vector.store requires the stored value's type to equal that
+// element type, so a trimmed op would not verify.
+template <typename OpTy>
+struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    VectorType oldVecType = op.getVectorType();
+    VectorType newVecType = trimLeadingOneDims(oldVecType);
+    if (newVecType == oldVecType)
+      return failure();
+
+    // Storing into a memref of vectors must store exactly the element type,
+    // so the stored value's shape cannot be changed.
+    Type baseType = op.getBase().getType();
+    if (auto shapedBase = dyn_cast<ShapedType>(baseType);
+        shapedBase && isa<VectorType>(shapedBase.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "base element type is a vector; cannot change value shape");
+
+    int64_t nDropped = oldVecType.getRank() - newVecType.getRank();
+
+    Location loc = op.getLoc();
+    SmallVector<Value> newOperands;
+    newOperands.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands()) {
+      // Vector operands are expected to share the stored value's leading unit
+      // dims; everything else is passed through untouched.
+      if (isa<VectorType>(operand.getType()))
+        newOperands.push_back(
+            dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped));
+      else
+        newOperands.push_back(operand);
+    }
+
+    Operation *newOp =
+        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
+                        op->getResultTypes(), op->getAttrs());
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
// Drops leading 1 dimensions from vector.constant_mask and inserts a
// vector.broadcast back to the original shape.
@@ -578,5 +673,14 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
- CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
+ CastAwayContractionLeadingOneDim,
+ CastAwayLoadLikeLeadingOneDim<vector::LoadOp>,
+ CastAwayLoadLikeLeadingOneDim<vector::MaskedLoadOp>,
+ CastAwayLoadLikeLeadingOneDim<vector::ExpandLoadOp>,
+ CastAwayLoadLikeLeadingOneDim<vector::GatherOp>,
+ CastAwayStoreLikeLeadingOneDim<vector::StoreOp>,
+ CastAwayStoreLikeLeadingOneDim<vector::MaskedStoreOp>,
+ CastAwayStoreLikeLeadingOneDim<vector::CompressStoreOp>,
+ CastAwayStoreLikeLeadingOneDim<vector::ScatterOp>>(
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index aee77ce3da553..bf01c8a8589d9 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -693,3 +693,98 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>,
%sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
return %sel : vector<1x16xi1>
}
+
+// -----
+
+// The leading unit dim is dropped from the load and re-added by a broadcast;
+// base and indices must be forwarded unchanged.
+// CHECK-LABEL: func.func @cast_away_load_leading_one_dims
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9_]+]]: memref<8x16xf32>, %[[I:[a-zA-Z0-9_]+]]: index, %[[J:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[L:.+]] = vector.load %[[BASE]][%[[I]], %[[J]]] : memref<8x16xf32>, vector<4xf32>
+// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
+// CHECK: return %[[B]] : vector<1x4xf32>
+func.func @cast_away_load_leading_one_dims(%base: memref<8x16xf32>, %i: index, %j: index) -> vector<1x4xf32> {
+  %0 = vector.load %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// -----
+
+// Mask and pass-through lose their leading unit dim via vector.extract.
+// CHECK-LABEL: func.func @cast_away_maskedload_leading_one_dims
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9_]+]]: memref<16xf32>, %[[I:[a-zA-Z0-9_]+]]: index, %[[MASK:[a-zA-Z0-9_]+]]: vector<1x4xi1>, %[[PASS:[a-zA-Z0-9_]+]]: vector<1x4xf32>
+// CHECK: %[[M:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[P:.+]] = vector.extract %[[PASS]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: %[[L:.+]] = vector.maskedload %[[BASE]][%[[I]]], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
+// CHECK: return %[[B]] : vector<1x4xf32>
+func.func @cast_away_maskedload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
+  %0 = vector.maskedload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// -----
+
+// Mask and pass-through lose their leading unit dim via vector.extract.
+// CHECK-LABEL: func.func @cast_away_expandload_leading_one_dims
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9_]+]]: memref<16xf32>, %[[I:[a-zA-Z0-9_]+]]: index, %[[MASK:[a-zA-Z0-9_]+]]: vector<1x4xi1>, %[[PASS:[a-zA-Z0-9_]+]]: vector<1x4xf32>
+// CHECK: %[[M:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[P:.+]] = vector.extract %[[PASS]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: %[[L:.+]] = vector.expandload %[[BASE]][%[[I]]], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32>
+// CHECK: return %[[B]] : vector<1x4xf32>
+func.func @cast_away_expandload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
+  %0 = vector.expandload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// -----
+
+// Index vector, mask and pass-through all lose their leading unit dim.
+// CHECK-LABEL: func.func @cast_away_gather_leading_one_dims
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9_]+]]: memref<16xf32>, %[[OFF:[a-zA-Z0-9_]+]]: index, %[[IDX:[a-zA-Z0-9_]+]]: vector<1x4xi32>, %[[MASK:[a-zA-Z0-9_]+]]: vector<1x4xi1>, %[[PASS:[a-zA-Z0-9_]+]]: vector<1x4xf32>
+// CHECK: %[[I:.+]] = vector.extract %[[IDX]][0] : vector<4xi32> from vector<1x4xi32>
+// CHECK: %[[M:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[P:.+]] = vector.extract %[[PASS]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: %[[G:.+]] = vector.gather %[[BASE]][%[[OFF]]] [%[[I]]], %[[M]], %[[P]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: %[[B:.+]] = vector.broadcast %[[G]] : vector<4xf32> to vector<1x4xf32>
+// CHECK: return %[[B]] : vector<1x4xf32>
+func.func @cast_away_gather_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> {
+  %0 = vector.gather %base[%i] [%idx], %mask, %pass : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// -----
+
+// The stored value loses its leading unit dim; base and indices are forwarded.
+// CHECK-LABEL: func.func @cast_away_store_leading_one_dims
+// CHECK-SAME: %[[VAL:[a-zA-Z0-9_]+]]: vector<1x4xf32>, %[[BASE:[a-zA-Z0-9_]+]]: memref<8x16xf32>, %[[I:[a-zA-Z0-9_]+]]: index, %[[J:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[V:.+]] = vector.extract %[[VAL]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: vector.store %[[V]], %[[BASE]][%[[I]], %[[J]]] : memref<8x16xf32>, vector<4xf32>
+func.func @cast_away_store_leading_one_dims(%val: vector<1x4xf32>, %base: memref<8x16xf32>, %i: index, %j: index) {
+  vector.store %val, %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32>
+  return
+}
+
+// -----
+
+// Mask and stored value lose their leading unit dim via vector.extract.
+// CHECK-LABEL: func.func @cast_away_maskedstore_leading_one_dims
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9_]+]]: memref<16xf32>, %[[I:[a-zA-Z0-9_]+]]: index, %[[MASK:[a-zA-Z0-9_]+]]: vector<1x4xi1>, %[[VAL:[a-zA-Z0-9_]+]]: vector<1x4xf32>
+// CHECK: %[[M:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[V:.+]] = vector.extract %[[VAL]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: vector.maskedstore %[[BASE]][%[[I]]], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32>
+func.func @cast_away_maskedstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
+  vector.maskedstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32>
+  return
+}
+
+// -----
+
+// Mask and stored value lose their leading unit dim via vector.extract.
+// CHECK-LABEL: func.func @cast_away_compressstore_leading_one_dims
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9_]+]]: memref<16xf32>, %[[I:[a-zA-Z0-9_]+]]: index, %[[MASK:[a-zA-Z0-9_]+]]: vector<1x4xi1>, %[[VAL:[a-zA-Z0-9_]+]]: vector<1x4xf32>
+// CHECK: %[[M:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[V:.+]] = vector.extract %[[VAL]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: vector.compressstore %[[BASE]][%[[I]]], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32>
+func.func @cast_away_compressstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
+  vector.compressstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32>
+  return
+}
+
+// -----
+
+// Index vector, mask and stored value all lose their leading unit dim.
+// CHECK-LABEL: func.func @cast_away_scatter_leading_one_dims
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9_]+]]: memref<16xf32>, %[[OFF:[a-zA-Z0-9_]+]]: index, %[[IDX:[a-zA-Z0-9_]+]]: vector<1x4xi32>, %[[MASK:[a-zA-Z0-9_]+]]: vector<1x4xi1>, %[[VAL:[a-zA-Z0-9_]+]]: vector<1x4xf32>
+// CHECK: %[[I:.+]] = vector.extract %[[IDX]][0] : vector<4xi32> from vector<1x4xi32>
+// CHECK: %[[M:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<1x4xi1>
+// CHECK: %[[V:.+]] = vector.extract %[[VAL]][0] : vector<4xf32> from vector<1x4xf32>
+// CHECK: vector.scatter %[[BASE]][%[[OFF]]] [%[[I]]], %[[M]], %[[V]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
+func.func @cast_away_scatter_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) {
+  vector.scatter %base[%i] [%idx], %mask, %val : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32>
+  return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/195686
More information about the Mlir-commits
mailing list