[Mlir-commits] [mlir] 6774e5a - [mlir] Fix in_bounds attr handling in TransferReadPermutationLowering
Matthias Springer
llvmlistbot at llvm.org
Sun May 16 23:28:27 PDT 2021
Author: Matthias Springer
Date: 2021-05-17T15:28:16+09:00
New Revision: 6774e5a995fcc7e1f2360bbeaf8628ae88159430
URL: https://github.com/llvm/llvm-project/commit/6774e5a995fcc7e1f2360bbeaf8628ae88159430
DIFF: https://github.com/llvm/llvm-project/commit/6774e5a995fcc7e1f2360bbeaf8628ae88159430.diff
LOG: [mlir] Fix in_bounds attr handling in TransferReadPermutationLowering
The in_bounds attribute should also be transposed.
Differential Revision: https://reviews.llvm.org/D102572
Added:
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1dd04a2e4b213..1effce2f5679c 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3040,6 +3040,18 @@ struct TransferWriteToVectorStoreLowering
}
};
+/// Permute a vector transfer op's `in_bounds` attribute so that it matches the
+/// pre-transpose vector type produced by TransferReadPermutationLowering.
+///
+/// `permutation` is the vector.transpose permutation that maps the new read
+/// back to the original vector type: original dim `i` corresponds to new dim
+/// `permutation[i]`. The flag of original dim `i` therefore moves to position
+/// `permutation[i]`, i.e. this applies the *inverse* of `permutation`. Indexing
+/// with `attr[permutation[i]]` instead would only be correct when the
+/// permutation is an involution (such as a 2-D swap).
+static ArrayAttr
+transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+                      ArrayRef<unsigned> permutation) {
+  assert(attr.getValue().size() == permutation.size() &&
+         "expected one in_bounds entry per vector dimension");
+  SmallVector<bool> newInBoundsValues(permutation.size());
+  for (auto it : llvm::enumerate(permutation))
+    newInBoundsValues[it.value()] =
+        attr.getValue()[it.index()].cast<BoolAttr>().getValue();
+  return builder.getBoolArrayAttr(newInBoundsValues);
+}
+
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identiy +
/// vector.transpose op.
@@ -3084,6 +3096,7 @@ struct TransferReadPermutationLowering
newVectorShape[pos.value()] = originalShape[pos.index()];
}
+ // Transpose mask operand.
Value newMask;
if (op.mask()) {
// Remove unused dims from the permutation map. E.g.:
@@ -3103,12 +3116,20 @@ struct TransferReadPermutationLowering
maskTransposeIndices);
}
+ // Transpose in_bounds attribute.
+ ArrayAttr newInBounds =
+ op.in_bounds() ? transposeInBoundsAttr(
+ rewriter, op.in_bounds().getValue(), permutation)
+ : ArrayAttr();
+
+ // Generate new transfer_read operation.
VectorType newReadType =
VectorType::get(newVectorShape, op.getVectorType().getElementType());
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.source(), op.indices(), newMap,
- op.padding(), newMask, op.in_bounds() ? *op.in_bounds() : ArrayAttr());
+ op.padding(), newMask, newInBounds);
+ // Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
transposePerm);
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 98683de909904..28a267967942c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -231,9 +231,9 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
%m = constant 1 : i1
%mask0 = splat %m : vector<7x14xi1>
- %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+ %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
-// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
+// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
%mask1 = splat %m : vector<14x16xi1>
@@ -243,9 +243,9 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
%mask2 = splat %m : vector<7x14xi1>
- %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, true, false, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
-// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
+// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
// CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
// CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
new file mode 100644
index 0000000000000..ad43b14a71d47
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
@@ -0,0 +1,41 @@
+// Run test with and without test-vector-transfer-lowering-patterns.
+
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
+                                                        [10., 11., 12., 13.],
+                                                        [20., 21., 22., 23.]]>
+
+// Vector load with transpose.
+func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+    memref<?x?xf32>, vector<3x9xf32>
+  vector.print %f: vector<3x9xf32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c2 = constant 2: index
+  %c3 = constant 3: index
+  %0 = memref.get_global @gv : memref<3x4xf32>
+  %A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
+
+  // 1. Read 2D vector from 2D memref with transpose. Vector dim 0 is declared in_bounds and
+  //    walks memref dim 1, so start at column 1: columns 1..3 all lie within the 4 columns.
+  call @transfer_read_2d(%A, %c1, %c1) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 11, 21, -42, -42, -42, -42, -42, -42, -42 ), ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  return
+}
More information about the Mlir-commits mailing list