[Mlir-commits] [mlir] [mlir][vector] Fix a `target-rank=0` unrolling (PR #73365)
llvmlistbot at llvm.org
Fri Nov 24 12:18:39 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Rik Huijzer (rikhuijzer)
<details>
<summary>Changes</summary>
Fixes https://github.com/llvm/llvm-project/issues/64269.
I also looked for other `target-rank=0` bugs but couldn't find any. `options.targetRank` is used only 8 times in the `mlir` folder, all inside `VectorToSCF.cpp`, and none of the other uses looks like it could cause a crash. I've also tried
```mlir
func.func @main(%vec : vector<2xi32>) -> vector<2xi32> {
%alloc = memref.alloc() : memref<4xindex>
%c0 = arith.constant 0 : index
%out = vector.transfer_read %alloc[%c0], %c0 : memref<4xindex>, vector<2xi32>
return %out : vector<2xi32>
}
```
with `"--convert-vector-to-scf=full-unroll target-rank=0"` and that also didn't crash. (Maybe obvious. I have to admit that I'm not very familiar with these ops.)
---
Full diff: https://github.com/llvm/llvm-project/pull/73365.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+12-5)
- (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+14)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 6ba3a678b2da7eb..1483c88cc614a86 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1207,23 +1207,30 @@ struct UnrollTransferWriteConversion
/// accesses, and broadcasts and transposes in permutation maps.
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
- if (xferOp.getVectorType().getRank() <= options.targetRank)
+ VectorType inputVectorTy = xferOp.getVectorType();
+
+ if (inputVectorTy.getRank() <= options.targetRank)
return failure();
+
+ // When target-rank=0, unrolling would cause the vector input argument
+ // into `transfer_write` to become a scalar.
+ if (inputVectorTy.getRank() == 1)
+ return failure();
+
if (isTensorOp(xferOp) && !options.lowerTensors)
return failure();
// Transfer ops that modify the element type are not supported atm.
- if (xferOp.getVectorType().getElementType() !=
+ if (inputVectorTy.getElementType() !=
xferOp.getShapedType().getElementType())
return failure();
auto vec = getDataVector(xferOp);
- auto xferVecType = xferOp.getVectorType();
- if (xferVecType.getScalableDims()[0]) {
+ if (inputVectorTy.getScalableDims()[0]) {
// Cannot unroll a scalable dimension at compile time.
return failure();
}
- int64_t dimSize = xferVecType.getShape()[0];
+ int64_t dimSize = inputVectorTy.getShape()[0];
Value source = xferOp.getSource(); // memref or tensor to be written to.
auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 597cc4f71a63961..d5b02006dbd12fb 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf))" -split-input-file -allow-unregistered-dialect | FileCheck %s
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
+// RUN: mlir-opt %s "-convert-vector-to-scf=full-unroll target-rank=0" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=TARGET-RANK-ZERO
// CHECK-LABEL: func @vector_transfer_ops_0d(
func.func @vector_transfer_ops_0d(%M: memref<f32>) {
@@ -748,3 +749,16 @@ func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector
vector.transfer_write %vec, %memref[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// TARGET-RANK-ZERO-LABEL: func @cannot_further_unroll_transfer_write
+func.func @cannot_further_unroll_transfer_write(%vec : vector<2xi32>) {
+ // TARGET-RANK-ZERO-NOT: vector.extract
+ // TARGET-RANK-ZERO: vector.transfer_write
+ // TARGET-RANK-ZERO-NOT: vector.extract
+ %alloc = memref.alloc() : memref<4xi32>
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
+ return
+}
``````````
</details>
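The two rank checks at the top of the patched `matchAndRewrite` can be read as a small predicate. The following is a minimal standalone sketch of that logic only. The function name `wouldUnrollTransferWrite` is hypothetical and not part of MLIR, and the sketch deliberately ignores the tensor, element-type, and scalable-dimension checks that follow in the real pattern.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the rank checks in the patched
// UnrollTransferWriteConversion::matchAndRewrite. Not an MLIR API.
// Returns true when the pattern would proceed past the rank checks.
bool wouldUnrollTransferWrite(int64_t vectorRank, int64_t targetRank) {
  assert(vectorRank >= 0 && "vector rank cannot be negative");

  // Existing check: the vector is already at or below the requested rank.
  if (vectorRank <= targetRank)
    return false;

  // Check added by this PR: unrolling a rank-1 vector would turn the
  // vector operand of transfer_write into a scalar, so bail out.
  if (vectorRank == 1)
    return false;

  return true;
}
```

In this model the second check only changes the outcome when `targetRank` is 0: for any `targetRank >= 1`, a rank-1 vector is already rejected by the first check. That matches the new test, where a `vector<2xi32>` write is left as a `vector.transfer_write` under `target-rank=0`.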
https://github.com/llvm/llvm-project/pull/73365