[Mlir-commits] [mlir] 6774e5a - [mlir] Fix in_bounds attr handling in TransferReadPermutationLowering

Matthias Springer llvmlistbot at llvm.org
Sun May 16 23:28:27 PDT 2021


Author: Matthias Springer
Date: 2021-05-17T15:28:16+09:00
New Revision: 6774e5a995fcc7e1f2360bbeaf8628ae88159430

URL: https://github.com/llvm/llvm-project/commit/6774e5a995fcc7e1f2360bbeaf8628ae88159430
DIFF: https://github.com/llvm/llvm-project/commit/6774e5a995fcc7e1f2360bbeaf8628ae88159430.diff

LOG: [mlir] Fix in_bounds attr handling in TransferReadPermutationLowering

The in_bounds attribute should also be transposed.

Differential Revision: https://reviews.llvm.org/D102572

Added: 
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1dd04a2e4b213..1effce2f5679c 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3040,6 +3040,18 @@ struct TransferWriteToVectorStoreLowering
   }
 };
 
+/// Transpose a vector transfer op's `in_bounds` attribute: entry `i` of the
+/// result is entry `permutation[i]` of `attr` (each is cast to BoolAttr).
+static ArrayAttr
+transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+                      const SmallVector<unsigned> &permutation) {
+  SmallVector<bool> newInBoundsValues;
+  for (unsigned pos : permutation)
+    newInBoundsValues.push_back(
+        attr.getValue()[pos].cast<BoolAttr>().getValue());
+  return builder.getBoolArrayAttr(newInBoundsValues);
+}
+
 /// Lower transfer_read op with permutation into a transfer_read with a
 /// permutation map composed of leading zeros followed by a minor identiy +
 /// vector.transpose op.
@@ -3084,6 +3096,7 @@ struct TransferReadPermutationLowering
       newVectorShape[pos.value()] = originalShape[pos.index()];
     }
 
+    // Transpose mask operand.
     Value newMask;
     if (op.mask()) {
       // Remove unused dims from the permutation map. E.g.:
@@ -3103,12 +3116,20 @@ struct TransferReadPermutationLowering
                                                      maskTransposeIndices);
     }
 
+    // Transpose in_bounds attribute.
+    ArrayAttr newInBounds =
+        op.in_bounds() ? transposeInBoundsAttr(
+                             rewriter, op.in_bounds().getValue(), permutation)
+                       : ArrayAttr();
+
+    // Generate new transfer_read operation.
     VectorType newReadType =
         VectorType::get(newVectorShape, op.getVectorType().getElementType());
     Value newRead = rewriter.create<vector::TransferReadOp>(
         op.getLoc(), newReadType, op.source(), op.indices(), newMap,
-        op.padding(), newMask, op.in_bounds() ? *op.in_bounds() : ArrayAttr());
+        op.padding(), newMask, newInBounds);
 
+    // Transpose result of transfer_read.
     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
     rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
                                                      transposePerm);

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 98683de909904..28a267967942c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -231,9 +231,9 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
   %m = constant 1 : i1
 
   %mask0 = splat %m : vector<7x14xi1>
-  %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+  %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
-// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
+// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
 // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
 
   %mask1 = splat %m : vector<14x16xi1>
@@ -243,9 +243,9 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
   %mask2 = splat %m : vector<7x14xi1>
-  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, true, false, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
-// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
+// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
 // CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
 // CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>
 

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
new file mode 100644
index 0000000000000..ad43b14a71d47
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
@@ -0,0 +1,41 @@
+// Run test with and without test-vector-transfer-lowering-patterns.
+
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+
+memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
+                                                       [10., 11., 12., 13.],
+                                                       [20., 21., 22., 23.]]>
+
+// Transposed vector load: reads and prints a 3x9 vector at [%base1, %base2].
+func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+    memref<?x?xf32>, vector<3x9xf32>
+  vector.print %f: vector<3x9xf32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c2 = constant 2: index
+  %c3 = constant 3: index
+  %0 = memref.get_global @gv : memref<3x4xf32>
+  %A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
+
+  // 1. Read a 3x9 vector from the 3x4 memref at base [1, 2], transposed.
+  call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( 20, 0, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  return
+}


        


More information about the Mlir-commits mailing list