[Mlir-commits] [mlir] [mlir][vector] Support scalable vec in `TransferReadAfterWriteToBroadcast` (PR #79162)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 23 08:19:28 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
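Makes the `TransferReadAfterWriteToBroadcast` canonicalization pattern correctly
propagate scalability flags: the intermediate `vector.broadcast` result type now
carries the scalable dims of the read vector type, permuted to match the
broadcast shape (e.g. `vector<[4]xf32>` is broadcast to `vector<6x[4]xf32>`
rather than losing the scalable dimension).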
---
Full diff: https://github.com/llvm/llvm-project/pull/79162.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-2)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+18)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 791924f92e8ad40..5f9860ef2311e7a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4079,10 +4079,15 @@ struct TransferReadAfterWriteToBroadcast
// final shape we want.
ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
SmallVector<int64_t> broadcastShape(destShape.size());
- for (const auto &pos : llvm::enumerate(permutation))
+ SmallVector<bool> broadcastScalableFlags(destShape.size());
+ for (const auto &pos : llvm::enumerate(permutation)) {
broadcastShape[pos.value()] = destShape[pos.index()];
+ broadcastScalableFlags[pos.value()] =
+ readOp.getVectorType().getScalableDims()[pos.index()];
+ }
VectorType broadcastedType = VectorType::get(
- broadcastShape, defWrite.getVectorType().getElementType());
+ broadcastShape, defWrite.getVectorType().getElementType(),
+ broadcastScalableFlags);
vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index feefb0c174aab4e..e6f045e12e51973 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1302,6 +1302,24 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
// -----
+// CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
+// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
+// CHECK: return %[[B]] : vector<6x[4]xf32>
+func.func @store_to_load_tensor_broadcast_scalable(%arg0 : tensor<?xf32>,
+ %v0 : vector<[4]xf32>) -> vector<6x[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg0[%c0] {in_bounds = [true]} :
+ vector<[4]xf32>, tensor<?xf32>
+ %0 = vector.transfer_read %w0[%c0], %cf0 {in_bounds = [true, true],
+ permutation_map = affine_map<(d0) -> (0, d0)>} :
+ tensor<?xf32>, vector<6x[4]xf32>
+ return %0 : vector<6x[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @store_to_load_tensor_perm_broadcast
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x4x4xf32>, %[[V0:.*]]: vector<4x1xf32>)
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x1xf32> to vector<100x5x4x1xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/79162
More information about the Mlir-commits mailing list