[Mlir-commits] [mlir] [mlir][vector] Fix a `target-rank=0` unrolling (PR #73365)
Rik Huijzer
llvmlistbot at llvm.org
Thu Nov 30 04:17:11 PST 2023
https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/73365
>From 74c6b927c77b50042817dc7f46dccbc1995c721c Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Fri, 24 Nov 2023 20:55:10 +0100
Subject: [PATCH 1/4] [mlir][vector] Fix a `target-rank=0` unrolling
---
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 17 ++++++++++++-----
.../Conversion/VectorToSCF/vector-to-scf.mlir | 14 ++++++++++++++
2 files changed, 26 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 6ba3a678b2da7eb..1483c88cc614a86 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1207,23 +1207,30 @@ struct UnrollTransferWriteConversion
/// accesses, and broadcasts and transposes in permutation maps.
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
- if (xferOp.getVectorType().getRank() <= options.targetRank)
+ VectorType inputVectorTy = xferOp.getVectorType();
+
+ if (inputVectorTy.getRank() <= options.targetRank)
return failure();
+
+ // When target-rank=0, unrolling would cause the vector input argument
+ // into `transfer_write` to become a scalar.
+ if (inputVectorTy.getRank() == 1)
+ return failure();
+
if (isTensorOp(xferOp) && !options.lowerTensors)
return failure();
// Transfer ops that modify the element type are not supported atm.
- if (xferOp.getVectorType().getElementType() !=
+ if (inputVectorTy.getElementType() !=
xferOp.getShapedType().getElementType())
return failure();
auto vec = getDataVector(xferOp);
- auto xferVecType = xferOp.getVectorType();
- if (xferVecType.getScalableDims()[0]) {
+ if (inputVectorTy.getScalableDims()[0]) {
// Cannot unroll a scalable dimension at compile time.
return failure();
}
- int64_t dimSize = xferVecType.getShape()[0];
+ int64_t dimSize = inputVectorTy.getShape()[0];
Value source = xferOp.getSource(); // memref or tensor to be written to.
auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 597cc4f71a63961..318c6583c15af56 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf))" -split-input-file -allow-unregistered-dialect | FileCheck %s
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
+// RUN: mlir-opt %s "-convert-vector-to-scf=full-unroll target-rank=0" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=TARGET-RANK-ZERO
// CHECK-LABEL: func @vector_transfer_ops_0d(
func.func @vector_transfer_ops_0d(%M: memref<f32>) {
@@ -748,3 +749,16 @@ func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector
vector.transfer_write %vec, %memref[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// TARGET-RANK-ZERO-LABEL: func @cannot_further_unroll_transfer_write
+func.func @cannot_further_unroll_transfer_write(%vec : vector<2xi32>) {
+ // FULL-UNROLL-NOT: vector.extract
+ // FULL-UNROLL: vector.transfer_write
+ // FULL-UNROLL-NOT: vector.extract
+ %alloc = memref.alloc() : memref<4xi32>
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
+ return
+}
>From 7dba76e5bb8800bd08922e2b54bac14e808046c5 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Fri, 24 Nov 2023 21:17:32 +0100
Subject: [PATCH 2/4] Fix label in test
---
mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 318c6583c15af56..d5b02006dbd12fb 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -754,9 +754,9 @@ func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector
// TARGET-RANK-ZERO-LABEL: func @cannot_further_unroll_transfer_write
func.func @cannot_further_unroll_transfer_write(%vec : vector<2xi32>) {
- // FULL-UNROLL-NOT: vector.extract
- // FULL-UNROLL: vector.transfer_write
- // FULL-UNROLL-NOT: vector.extract
+ // TARGET-RANK-ZERO-NOT: vector.extract
+ // TARGET-RANK-ZERO: vector.transfer_write
+ // TARGET-RANK-ZERO-NOT: vector.extract
%alloc = memref.alloc() : memref<4xi32>
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
>From 769f314209109353835c9fe252ab4a998f4b0753 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Wed, 29 Nov 2023 19:43:58 +0100
Subject: [PATCH 3/4] Broadcast the scalar to a 0D vector
---
.../lib/Conversion/VectorToSCF/VectorToSCF.cpp | 18 ++++++++++++------
.../Conversion/VectorToSCF/vector-to-scf.mlir | 10 +++++-----
2 files changed, 17 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 1483c88cc614a86..33a77d7576ba70b 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
@@ -1212,11 +1213,6 @@ struct UnrollTransferWriteConversion
if (inputVectorTy.getRank() <= options.targetRank)
return failure();
- // When target-rank=0, unrolling would cause the vector input argument
- // into `transfer_write` to become a scalar.
- if (inputVectorTy.getRank() == 1)
- return failure();
-
if (isTensorOp(xferOp) && !options.lowerTensors)
return failure();
// Transfer ops that modify the element type are not supported atm.
@@ -1256,8 +1252,18 @@ struct UnrollTransferWriteConversion
auto extracted =
b.create<vector::ExtractOp>(loc, vec, extractionIndices);
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
+ Value xferVec;
+ if (inputVectorTy.getRank() == 1) {
+ // When target-rank=0, unrolling would cause the vector input
+ // argument into `transfer_write` to become a scalar. We solve
+ // this by broadcasting the scalar to a 0D vector.
+ xferVec = b.create<vector::BroadcastOp>(
+ loc, VectorType::get({}, extracted.getType()), extracted);
+ } else {
+ xferVec = extracted;
+ }
auto newXferOp = b.create<vector::TransferWriteOp>(
- loc, sourceType, extracted, source, xferIndices,
+ loc, sourceType, xferVec, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
inBoundsAttr);
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index d5b02006dbd12fb..7ed85080581348f 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -752,11 +752,11 @@ func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector
// -----
-// TARGET-RANK-ZERO-LABEL: func @cannot_further_unroll_transfer_write
-func.func @cannot_further_unroll_transfer_write(%vec : vector<2xi32>) {
- // TARGET-RANK-ZERO-NOT: vector.extract
- // TARGET-RANK-ZERO: vector.transfer_write
- // TARGET-RANK-ZERO-NOT: vector.extract
+// TARGET-RANK-ZERO-LABEL: func @unroll_transfer_write_target_rank_zero
+func.func @unroll_transfer_write_target_rank_zero(%vec : vector<2xi32>) {
+ // TARGET-RANK-ZERO: %[[EXTRACTED:.*]] = vector.extract
+ // TARGET-RANK-ZERO: %[[BROADCASTED:.*]] = vector.broadcast %[[EXTRACTED]]
+ // TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED]]
%alloc = memref.alloc() : memref<4xi32>
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
>From 16eeac7db3df9908d30ba4f33f342277bbee6c29 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Thu, 30 Nov 2023 13:16:53 +0100
Subject: [PATCH 4/4] Check the types and both extract/broadcast/transfer_write
sequences
---
mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 7ed85080581348f..ad78f0c945b24d9 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -754,11 +754,15 @@ func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector
// TARGET-RANK-ZERO-LABEL: func @unroll_transfer_write_target_rank_zero
func.func @unroll_transfer_write_target_rank_zero(%vec : vector<2xi32>) {
- // TARGET-RANK-ZERO: %[[EXTRACTED:.*]] = vector.extract
- // TARGET-RANK-ZERO: %[[BROADCASTED:.*]] = vector.broadcast %[[EXTRACTED]]
- // TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED]]
%alloc = memref.alloc() : memref<4xi32>
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
return
}
+// TARGET-RANK-ZERO: %[[ALLOC:.*]] = memref.alloc() : memref<4xi32>
+// TARGET-RANK-ZERO: %[[EXTRACTED1:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
+// TARGET-RANK-ZERO: %[[BROADCASTED1:.*]] = vector.broadcast %[[EXTRACTED1]] : i32 to vector<i32>
+// TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED1]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>
+// TARGET-RANK-ZERO: %[[EXTRACTED2:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
+// TARGET-RANK-ZERO: %[[BROADCASTED2:.*]] = vector.broadcast %[[EXTRACTED2]] : i32 to vector<i32>
+// TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED2]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>
More information about the Mlir-commits
mailing list