[Mlir-commits] [mlir] c84061f - [mlir][vector] Fix a `target-rank=0` unrolling (#73365)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 30 04:29:16 PST 2023


Author: Rik Huijzer
Date: 2023-11-30T13:29:09+01:00
New Revision: c84061fd343cdd647dd18321aa555c5d358c2d65

URL: https://github.com/llvm/llvm-project/commit/c84061fd343cdd647dd18321aa555c5d358c2d65
DIFF: https://github.com/llvm/llvm-project/commit/c84061fd343cdd647dd18321aa555c5d358c2d65.diff

LOG: [mlir][vector] Fix a `target-rank=0` unrolling (#73365)

Fixes https://github.com/llvm/llvm-project/issues/64269.

With this patch, calling `mlir-opt "-convert-vector-to-scf=full-unroll
target-rank=0"` on
```mlir
func.func @main(%vec : vector<2xi32>) {
  %alloc = memref.alloc() : memref<4xi32>
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
  return
}
```
will result in
```mlir
module {
  func.func @main(%arg0: vector<2xi32>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %alloc = memref.alloc() : memref<4xi32>
    %0 = vector.extract %arg0[0] : i32 from vector<2xi32>
    %1 = vector.broadcast %0 : i32 to vector<i32>
    vector.transfer_write %1, %alloc[%c0] : vector<i32>, memref<4xi32>
    %2 = vector.extract %arg0[1] : i32 from vector<2xi32>
    %3 = vector.broadcast %2 : i32 to vector<i32>
    vector.transfer_write %3, %alloc[%c1] : vector<i32>, memref<4xi32>
    return
  }
}
```

I also tried to proactively find other `target-rank=0` bugs, but
couldn't find any. `options.targetRank` is used only 8 times in the
`mlir` directory, all inside `VectorToSCF.cpp`, and none of the other uses
looks like it could cause a crash. I also tried

```mlir
func.func @main(%vec : vector<2xi32>) -> vector<2xi32> {
  %alloc = memref.alloc() : memref<4xindex>
  %c0 = arith.constant 0 : index
  %out = vector.transfer_read %alloc[%c0], %c0 : memref<4xindex>, vector<2xi32>
  return %out : vector<2xi32>
}
```
with `"-convert-vector-to-scf=full-unroll target-rank=0"`, and that
didn't crash either. (This may be obvious; I have to admit that I'm not
very familiar with these ops.)

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 6ba3a678b2da7eb..33a77d7576ba70b 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/Builders.h"
@@ -1207,23 +1208,25 @@ struct UnrollTransferWriteConversion
   /// accesses, and broadcasts and transposes in permutation maps.
   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                 PatternRewriter &rewriter) const override {
-    if (xferOp.getVectorType().getRank() <= options.targetRank)
+    VectorType inputVectorTy = xferOp.getVectorType();
+
+    if (inputVectorTy.getRank() <= options.targetRank)
       return failure();
+
     if (isTensorOp(xferOp) && !options.lowerTensors)
       return failure();
     // Transfer ops that modify the element type are not supported atm.
-    if (xferOp.getVectorType().getElementType() !=
+    if (inputVectorTy.getElementType() !=
         xferOp.getShapedType().getElementType())
       return failure();
 
     auto vec = getDataVector(xferOp);
-    auto xferVecType = xferOp.getVectorType();
-    if (xferVecType.getScalableDims()[0]) {
+    if (inputVectorTy.getScalableDims()[0]) {
       // Cannot unroll a scalable dimension at compile time.
       return failure();
     }
 
-    int64_t dimSize = xferVecType.getShape()[0];
+    int64_t dimSize = inputVectorTy.getShape()[0];
     Value source = xferOp.getSource(); // memref or tensor to be written to.
     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
 
@@ -1249,8 +1252,18 @@ struct UnrollTransferWriteConversion
             auto extracted =
                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
+            Value xferVec;
+            if (inputVectorTy.getRank() == 1) {
+              // When target-rank=0, unrolling would cause the vector input
+              // argument of `transfer_write` to become a scalar. We solve
+              // this by broadcasting the scalar to a 0-D vector.
+              xferVec = b.create<vector::BroadcastOp>(
+                  loc, VectorType::get({}, extracted.getType()), extracted);
+            } else {
+              xferVec = extracted;
+            }
             auto newXferOp = b.create<vector::TransferWriteOp>(
-                loc, sourceType, extracted, source, xferIndices,
+                loc, sourceType, xferVec, source, xferIndices,
                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                 inBoundsAttr);
 

diff  --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 597cc4f71a63961..ad78f0c945b24d9 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf))" -split-input-file -allow-unregistered-dialect | FileCheck %s
 // RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
+// RUN: mlir-opt %s "-convert-vector-to-scf=full-unroll target-rank=0" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=TARGET-RANK-ZERO
 
 // CHECK-LABEL: func @vector_transfer_ops_0d(
 func.func @vector_transfer_ops_0d(%M: memref<f32>) {
@@ -748,3 +749,20 @@ func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector
   vector.transfer_write %vec, %memref[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
   return
 }
+
+//  -----
+
+// TARGET-RANK-ZERO-LABEL: func @unroll_transfer_write_target_rank_zero
+func.func @unroll_transfer_write_target_rank_zero(%vec : vector<2xi32>) {
+  %alloc = memref.alloc() : memref<4xi32>
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
+  return
+}
+// TARGET-RANK-ZERO: %[[ALLOC:.*]] = memref.alloc() : memref<4xi32>
+// TARGET-RANK-ZERO: %[[EXTRACTED1:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
+// TARGET-RANK-ZERO: %[[BROADCASTED1:.*]] = vector.broadcast %[[EXTRACTED1]] : i32 to vector<i32>
+// TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED1]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>
+// TARGET-RANK-ZERO: %[[EXTRACTED2:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
+// TARGET-RANK-ZERO: %[[BROADCASTED2:.*]] = vector.broadcast %[[EXTRACTED2]] : i32 to vector<i32>
+// TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED2]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>


        


More information about the Mlir-commits mailing list