[Mlir-commits] [mlir] 6dc9725 - [mlir][vector] Fix lowering of permutation maps for transfer_write op
Thomas Raoux
llvmlistbot at llvm.org
Tue Jan 17 09:04:13 PST 2023
Author: Thomas Raoux
Date: 2023-01-17T17:04:04Z
New Revision: 6dc9725471e05fe12bd72406f97daca49a47a0c0
URL: https://github.com/llvm/llvm-project/commit/6dc9725471e05fe12bd72406f97daca49a47a0c0
DIFF: https://github.com/llvm/llvm-project/commit/6dc9725471e05fe12bd72406f97daca49a47a0c0.diff
LOG: [mlir][vector] Fix lowering of permutation maps for transfer_write op
The lowering of transfer write permutation maps didn't match the op definition:
https://github.com/llvm/llvm-project/blob/93ccccb00d9717b58ba93f0942a243ba6dac4ef6/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td#L1476
Fix the lowering and add a case to the integration test in
order to enforce the correct semantics.
Differential Revision: https://reviews.llvm.org/D141801
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index abf4f887f971..e0711a48b300 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1330,6 +1330,18 @@ def Vector_TransferReadOp :
memref<?x?x?xf32>, vector<32x256xf32>
}}}
+ // or equivalently (rewrite with vector.transpose)
+ %f0 = arith.constant 0.0f : f32
+ for %i0 = 0 to %0 {
+ affine.for %i1 = 0 to %1 step 256 {
+ affine.for %i2 = 0 to %2 step 32 {
+ %v0 = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
+ {permutation_map: (d0, d1, d2) -> (d1, d2)} :
+ memref<?x?x?xf32>, vector<256x32xf32>
+ %v = vector.transpose %v0, [1, 0] :
+ vector<256x32xf32> to vector<32x256xf32>
+ }}}
+
// Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into
// vector<128xf32>. The underlying implementation will require a 1-D vector
// broadcast:
@@ -1485,6 +1497,19 @@ def Vector_TransferWriteOp :
vector<16x32x64xf32>, memref<?x?x?x?xf32>
}}}}
+ // or equivalently (rewrite with vector.transpose)
+ for %i0 = 0 to %0 {
+ affine.for %i1 = 0 to %1 step 32 {
+ affine.for %i2 = 0 to %2 step 64 {
+ affine.for %i3 = 0 to %3 step 16 {
+ %val = `ssa-value` : vector<16x32x64xf32>
+ %valt = vector.transpose %val, [1, 2, 0] :
+ vector<16x32x64xf32> -> vector<32x64x16xf32>
+ vector.transfer_write %valt, %A[%i0, %i1, %i2, %i3]
+ {permutation_map: (d0, d1, d2, d3) -> (d1, d2, d3)} :
+ vector<32x64x16xf32>, memref<?x?x?x?xf32>
+ }}}}
+
// write to a memref with vector element type.
vector.transfer_write %4, %arg1[%c3, %c3]
{permutation_map = (d0, d1)->(d0, d1)}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0bb31e970aa6..2bc0c2a0d777 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -456,7 +456,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
Operation *write;
if (vectorType.getRank() > 0) {
- AffineMap writeMap = reindexIndexingMap(opOperandMap);
+ AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 933945233c88..db2cd8279000 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3373,12 +3373,12 @@ void TransferReadOp::print(OpAsmPrinter &p) {
p << " : " << getShapedType() << ", " << getVectorType();
}
-/// Infers the mask type for a transfer read given its vector type and
-/// permutation map. The mask in a transfer read operation applies to the
-/// tensor/buffer reading part of it and its type should match the shape read
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask in a transfer op operation applies to the
+/// tensor/buffer part of it and its type should match the vector shape
/// *before* any permutation or broadcasting.
-static VectorType inferTransferReadMaskType(VectorType vecType,
- AffineMap permMap) {
+static VectorType inferTransferOpMaskType(VectorType vecType,
+ AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
assert(invPermMap && "Inversed permutation map couldn't be computed");
@@ -3436,7 +3436,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
maskInfo.location, "does not support masks with vector element type");
// Instead of adding the mask type as an op type, compute it based on the
// vector type and the permutation map (to keep the type signature small).
- auto maskType = inferTransferReadMaskType(vectorType, permMap);
+ auto maskType = inferTransferOpMaskType(vectorType, permMap);
if (parser.resolveOperand(maskInfo, maskType, result.operands))
return failure();
}
@@ -3455,7 +3455,7 @@ LogicalResult TransferReadOp::verify() {
auto paddingType = getPadding().getType();
auto permutationMap = getPermutationMap();
VectorType inferredMaskType =
- maskType ? inferTransferReadMaskType(vectorType, permutationMap)
+ maskType ? inferTransferOpMaskType(vectorType, permutationMap)
: VectorType();
auto sourceElementType = shapedType.getElementType();
@@ -3495,7 +3495,7 @@ LogicalResult TransferReadOp::verify() {
/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized."
Type TransferReadOp::getExpectedMaskType() {
- return inferTransferReadMaskType(getVectorType(), getPermutationMap());
+ return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
template <typename TransferOp>
@@ -3836,18 +3836,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
-/// Infers the mask type for a transfer write given its vector type and
-/// permutation map. The mask in a transfer read operation applies to the
-/// tensor/buffer writing part of it and its type should match the shape written
-/// *after* any permutation.
-static VectorType inferTransferWriteMaskType(VectorType vecType,
- AffineMap permMap) {
- auto i1Type = IntegerType::get(permMap.getContext(), 1);
- SmallVector<int64_t, 8> maskShape =
- compressUnusedDims(permMap).compose(vecType.getShape());
- return VectorType::get(maskShape, i1Type);
-}
-
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
@@ -3892,7 +3880,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
if (shapedType.getElementType().dyn_cast<VectorType>())
return parser.emitError(
maskInfo.location, "does not support masks with vector element type");
- auto maskType = inferTransferWriteMaskType(vectorType, permMap);
+ auto maskType = inferTransferOpMaskType(vectorType, permMap);
if (parser.resolveOperand(maskInfo, maskType, result.operands))
return failure();
}
@@ -3919,7 +3907,7 @@ LogicalResult TransferWriteOp::verify() {
VectorType maskType = getMaskType();
auto permutationMap = getPermutationMap();
VectorType inferredMaskType =
- maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
+ maskType ? inferTransferOpMaskType(vectorType, permutationMap)
: VectorType();
if (llvm::size(getIndices()) != shapedType.getRank())
@@ -3945,7 +3933,7 @@ LogicalResult TransferWriteOp::verify() {
/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
Type TransferWriteOp::getExpectedMaskType() {
- return inferTransferWriteMaskType(getVectorType(), getPermutationMap());
+ return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
/// Fold:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index d7ec87e95f4e..df8ba7b85534 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -20,15 +20,16 @@
using namespace mlir;
using namespace mlir::vector;
-/// Transpose a vector transfer op's `in_bounds` attribute according to given
-/// indices.
+/// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
+/// permutation based on the given indices.
static ArrayAttr
-transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
- const SmallVector<unsigned> &permutation) {
- SmallVector<bool> newInBoundsValues;
+inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+ const SmallVector<unsigned> &permutation) {
+ SmallVector<bool> newInBoundsValues(permutation.size());
+ size_t index = 0;
for (unsigned pos : permutation)
- newInBoundsValues.push_back(
- attr.getValue()[pos].cast<BoolAttr>().getValue());
+ newInBoundsValues[pos] =
+ attr.getValue()[index++].cast<BoolAttr>().getValue();
return builder.getBoolArrayAttr(newInBoundsValues);
}
@@ -85,7 +86,7 @@ struct TransferReadPermutationLowering
// Transpose in_bounds attribute.
ArrayAttr newInBoundsAttr =
- op.getInBounds() ? transposeInBoundsAttr(
+ op.getInBounds() ? inverseTransposeInBoundsAttr(
rewriter, op.getInBounds().value(), permutation)
: ArrayAttr();
@@ -142,16 +143,17 @@ struct TransferWritePermutationLowering
// E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
// comp = (d0, d1, d2) -> (d2, d0, d1)
auto comp = compressUnusedDims(map);
+ AffineMap permutationMap = inversePermutation(comp);
// Get positions of remaining result dims.
SmallVector<int64_t> indices;
- llvm::transform(comp.getResults(), std::back_inserter(indices),
+ llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
[](AffineExpr expr) {
return expr.dyn_cast<AffineDimExpr>().getPosition();
});
// Transpose in_bounds attribute.
ArrayAttr newInBoundsAttr =
- op.getInBounds() ? transposeInBoundsAttr(
+ op.getInBounds() ? inverseTransposeInBoundsAttr(
rewriter, op.getInBounds().value(), permutation)
: ArrayAttr();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 0da64debdb6a..779b84f96a57 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -337,11 +337,11 @@ func.func @transfer_write_permutations(
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
- // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<8x14x16x7xi1>
- %mask0 = vector.splat %m : vector<8x14x16x7xi1>
+ // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1>
+ %mask0 = vector.splat %m : vector<16x14x7x8xi1>
%0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
- // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
- // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>
+ // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
+ // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [true, false, true, false]} : vector<16x14x7x8xf32>, tensor<?x?x?x?xf32>
vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
// CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
index 91b07587a40c..cee90c74c465 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
@@ -39,6 +39,34 @@ func.func @transfer_read_1d(%A : memref<?xf32>) -> vector<32xf32> {
return %r : vector<32xf32>
}
+func.func @transfer_write_inbounds_3d(%A : memref<4x4x4xf32>) {
+ %c0 = arith.constant 0: index
+ %f = arith.constant 0.0 : f32
+ %v0 = vector.splat %f : vector<2x3x4xf32>
+ %f1 = arith.constant 1.0 : f32
+ %f2 = arith.constant 2.0 : f32
+ %f3 = arith.constant 3.0 : f32
+ %f4 = arith.constant 4.0 : f32
+ %f5 = arith.constant 5.0 : f32
+ %f6 = arith.constant 6.0 : f32
+ %f7 = arith.constant 7.0 : f32
+ %f8 = arith.constant 8.0 : f32
+
+ %v1 = vector.insert %f1, %v0[0, 0, 0] : f32 into vector<2x3x4xf32>
+ %v2 = vector.insert %f2, %v1[0, 0, 3] : f32 into vector<2x3x4xf32>
+ %v3 = vector.insert %f3, %v2[0, 2, 0] : f32 into vector<2x3x4xf32>
+ %v4 = vector.insert %f4, %v3[0, 2, 3] : f32 into vector<2x3x4xf32>
+ %v5 = vector.insert %f5, %v4[1, 0, 0] : f32 into vector<2x3x4xf32>
+ %v6 = vector.insert %f6, %v5[1, 0, 3] : f32 into vector<2x3x4xf32>
+ %v7 = vector.insert %f7, %v6[1, 2, 0] : f32 into vector<2x3x4xf32>
+ %v8 = vector.insert %f8, %v7[1, 2, 3] : f32 into vector<2x3x4xf32>
+ vector.transfer_write %v8, %A[%c0, %c0, %c0]
+ {permutation_map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>,
+ in_bounds = [true, true, true]}
+ : vector<2x3x4xf32>, memref<4x4x4xf32>
+ return
+}
+
func.func @entry() {
%c0 = arith.constant 0: index
%c1 = arith.constant 1: index
@@ -90,6 +118,24 @@ func.func @entry() {
vector.print %6 : vector<32xf32>
memref.dealloc %A : memref<?xf32>
+
+ // 3D case
+ %c4 = arith.constant 4: index
+ %A1 = memref.alloc() {alignment=64} : memref<4x4x4xf32>
+ scf.for %i = %c0 to %c4 step %c1 {
+ scf.for %j = %c0 to %c4 step %c1 {
+ scf.for %k = %c0 to %c4 step %c1 {
+ %f = arith.constant 0.0: f32
+ memref.store %f, %A1[%i, %j, %k] : memref<4x4x4xf32>
+ }
+ }
+ }
+ call @transfer_write_inbounds_3d(%A1) : (memref<4x4x4xf32>) -> ()
+ %f = arith.constant 0.0: f32
+ %r = vector.transfer_read %A1[%c0, %c0, %c0], %f
+ : memref<4x4x4xf32>, vector<4x4x4xf32>
+ vector.print %r : vector<4x4x4xf32>
+
return
}
@@ -100,3 +146,7 @@ func.func @entry() {
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 )
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13 )
+
+// 3D case.
+// CHECK: ( ( ( 1, 5, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 2, 6, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ),
+// CHECK-SAME: ( ( 3, 7, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 4, 8, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ) )
More information about the Mlir-commits
mailing list