[Mlir-commits] [mlir] [mlir][vector] Support scalable vec in `TransferReadAfterWriteToBroadcast` (PR #79162)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 23 08:19:28 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Makes `TransferReadAfterWriteToBroadcast` correctly propagate the
scalable-dimension flags of the read's vector type to the type of the
broadcast it creates.


---
Full diff: https://github.com/llvm/llvm-project/pull/79162.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-2) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+18) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 791924f92e8ad40..5f9860ef2311e7a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4079,10 +4079,15 @@ struct TransferReadAfterWriteToBroadcast
     // final shape we want.
     ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
     SmallVector<int64_t> broadcastShape(destShape.size());
-    for (const auto &pos : llvm::enumerate(permutation))
+    SmallVector<bool> broadcastScalableFlags(destShape.size());
+    for (const auto &pos : llvm::enumerate(permutation)) {
       broadcastShape[pos.value()] = destShape[pos.index()];
+      broadcastScalableFlags[pos.value()] =
+          readOp.getVectorType().getScalableDims()[pos.index()];
+    }
     VectorType broadcastedType = VectorType::get(
-        broadcastShape, defWrite.getVectorType().getElementType());
+        broadcastShape, defWrite.getVectorType().getElementType(),
+        broadcastScalableFlags);
     vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
     rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index feefb0c174aab4e..e6f045e12e51973 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1302,6 +1302,24 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
+//  CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
+//       CHECK:   %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
+//       CHECK:   return %[[B]] : vector<6x[4]xf32>
+func.func @store_to_load_tensor_broadcast_scalable(%arg0 : tensor<?xf32>,
+  %v0 : vector<[4]xf32>) -> vector<6x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0] {in_bounds = [true]} :
+    vector<[4]xf32>, tensor<?xf32>
+  %0 = vector.transfer_read %w0[%c0], %cf0 {in_bounds = [true, true],
+  permutation_map = affine_map<(d0) -> (0, d0)>} :
+    tensor<?xf32>, vector<6x[4]xf32>
+  return %0 : vector<6x[4]xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_to_load_tensor_perm_broadcast
 //  CHECK-SAME: (%[[ARG:.*]]: tensor<4x4x4xf32>, %[[V0:.*]]: vector<4x1xf32>)
 //       CHECK:   %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x1xf32> to vector<100x5x4x1xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/79162


More information about the Mlir-commits mailing list