[Mlir-commits] [mlir] 6dc9725 - [mlir][vector] Fix lowering of permutation maps for transfer_write op

Thomas Raoux llvmlistbot at llvm.org
Tue Jan 17 09:04:13 PST 2023


Author: Thomas Raoux
Date: 2023-01-17T17:04:04Z
New Revision: 6dc9725471e05fe12bd72406f97daca49a47a0c0

URL: https://github.com/llvm/llvm-project/commit/6dc9725471e05fe12bd72406f97daca49a47a0c0
DIFF: https://github.com/llvm/llvm-project/commit/6dc9725471e05fe12bd72406f97daca49a47a0c0.diff

LOG: [mlir][vector] Fix lowering of permutation maps for transfer_write op

The lowering of transfer write permutation maps didn't match the op definition:
https://github.com/llvm/llvm-project/blob/93ccccb00d9717b58ba93f0942a243ba6dac4ef6/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td#L1476

Fix the lowering and add a case to the integration test to enforce the
correct semantics.

Differential Revision: https://reviews.llvm.org/D141801
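
The core of the fix: per the transfer_write definition, the k-th result of the
permutation map names the memref dimension that the k-th vector dimension is
written along, so lowering to a minor-identity write has to transpose the
vector by the *inverse* of the (compressed) permutation map, not by its
results. For the map (d0, d1, d2, d3) -> (d2, d1, d3, d0) exercised in the
updated lit test, the old lowering transposed by [2, 1, 3, 0] (giving
vector<8x14x16x7xf32>), whereas the inverse permutation is [3, 1, 0, 2]
(giving vector<16x14x7x8xf32>). A minimal standalone sketch of that arithmetic
in plain C++ (illustrative only, not the MLIR helpers):

  #include <cstdio>
  #include <vector>

  // Invert a permutation: inv[perm[i]] = i.
  static std::vector<unsigned> invert(const std::vector<unsigned> &perm) {
    std::vector<unsigned> inv(perm.size());
    for (unsigned i = 0, e = perm.size(); i < e; ++i)
      inv[perm[i]] = i;
    return inv;
  }

  int main() {
    // Results of the map (d0, d1, d2, d3) -> (d2, d1, d3, d0).
    std::vector<unsigned> mapResults = {2, 1, 3, 0};
    std::vector<unsigned> inv = invert(mapResults); // {3, 1, 0, 2}
    // Transposing vector<7x14x8x16xf32> by the inverse permutation.
    std::vector<unsigned> vecShape = {7, 14, 8, 16};
    for (unsigned i = 0, e = inv.size(); i < e; ++i)
      printf("%u ", vecShape[inv[i]]); // 16 14 7 8
    printf("\n");
    return 0;
  }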

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
    mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index abf4f887f971..e0711a48b300 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1330,6 +1330,18 @@ def Vector_TransferReadOp :
                memref<?x?x?xf32>, vector<32x256xf32>
     }}}
 
+    // or equivalently (rewrite with vector.transpose)
+    %f0 = arith.constant 0.0f : f32
+    for %i0 = 0 to %0 {
+      affine.for %i1 = 0 to %1 step 256 {
+        affine.for %i2 = 0 to %2 step 32 {
+          %v0 = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
+               {permutation_map: (d0, d1, d2) -> (d1, d2)} :
+               memref<?x?x?xf32>, vector<256x32xf32>
+          %v = vector.transpose %v0, [1, 0] :
+              vector<256x32xf32> to vector<32x256xf32>
+    }}}
+
     // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into
     // vector<128xf32>. The underlying implementation will require a 1-D vector
     // broadcast:
@@ -1485,6 +1497,19 @@ def Vector_TransferWriteOp :
               vector<16x32x64xf32>, memref<?x?x?x?xf32>
     }}}}
 
+    // or equivalently (rewrite with vector.transpose)
+    for %i0 = 0 to %0 {
+      affine.for %i1 = 0 to %1 step 32 {
+        affine.for %i2 = 0 to %2 step 64 {
+          affine.for %i3 = 0 to %3 step 16 {
+            %val = `ssa-value` : vector<16x32x64xf32>
+            %valt = vector.transpose %val, [1, 2, 0] :
+                  vector<16x32x64xf32> to vector<32x64x16xf32>
+            vector.transfer_write %valt, %A[%i0, %i1, %i2, %i3]
+              {permutation_map: (d0, d1, d2, d3) -> (d1, d2, d3)} :
+              vector<32x64x16xf32>, memref<?x?x?x?xf32>
+    }}}}
+
     // write to a memref with vector element type.
     vector.transfer_write %4, %arg1[%c3, %c3]
       {permutation_map = (d0, d1)->(d0, d1)}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0bb31e970aa6..2bc0c2a0d777 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -456,7 +456,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
 
   Operation *write;
   if (vectorType.getRank() > 0) {
-    AffineMap writeMap = reindexIndexingMap(opOperandMap);
+    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
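
The one-line change above wraps the result of reindexIndexingMap in
inversePermutation so that the permutation map attached to the generated
transfer_write follows the op definition referenced in the log message. As a
rough sketch of what inversePermutation returns for a permutation map, using
the public AffineMap helpers (needs linking against the MLIR IR library; the
printed form is approximate):

  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/MLIRContext.h"
  #include "llvm/ADT/SmallVector.h"

  int main() {
    mlir::MLIRContext ctx;
    // (d0, d1, d2) -> (d1, d2, d0)
    llvm::SmallVector<unsigned> perm = {1, 2, 0};
    mlir::AffineMap map = mlir::AffineMap::getPermutationMap(perm, &ctx);
    // Expected inverse: (d0, d1, d2) -> (d2, d0, d1)
    mlir::AffineMap inv = mlir::inversePermutation(map);
    map.dump();
    inv.dump();
    return 0;
  }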

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 933945233c88..db2cd8279000 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3373,12 +3373,12 @@ void TransferReadOp::print(OpAsmPrinter &p) {
   p << " : " << getShapedType() << ", " << getVectorType();
 }
 
-/// Infers the mask type for a transfer read given its vector type and
-/// permutation map. The mask in a transfer read operation applies to the
-/// tensor/buffer reading part of it and its type should match the shape read
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask in a transfer operation applies to the
+/// tensor/buffer part of it and its type should match the vector shape
 /// *before* any permutation or broadcasting.
-static VectorType inferTransferReadMaskType(VectorType vecType,
-                                            AffineMap permMap) {
+static VectorType inferTransferOpMaskType(VectorType vecType,
+                                          AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
   AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
   assert(invPermMap && "Inversed permutation map couldn't be computed");
@@ -3436,7 +3436,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
           maskInfo.location, "does not support masks with vector element type");
     // Instead of adding the mask type as an op type, compute it based on the
     // vector type and the permutation map (to keep the type signature small).
-    auto maskType = inferTransferReadMaskType(vectorType, permMap);
+    auto maskType = inferTransferOpMaskType(vectorType, permMap);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -3455,7 +3455,7 @@ LogicalResult TransferReadOp::verify() {
   auto paddingType = getPadding().getType();
   auto permutationMap = getPermutationMap();
   VectorType inferredMaskType =
-      maskType ? inferTransferReadMaskType(vectorType, permutationMap)
+      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
                : VectorType();
   auto sourceElementType = shapedType.getElementType();
 
@@ -3495,7 +3495,7 @@ LogicalResult TransferReadOp::verify() {
 /// Returns the mask type expected by this operation. Mostly used for
 /// verification purposes. It requires the operation to be vectorized."
 Type TransferReadOp::getExpectedMaskType() {
-  return inferTransferReadMaskType(getVectorType(), getPermutationMap());
+  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
 }
 
 template <typename TransferOp>
@@ -3836,18 +3836,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, vector, dest, indices, permutationMap, inBounds);
 }
 
-/// Infers the mask type for a transfer write given its vector type and
-/// permutation map. The mask in a transfer read operation applies to the
-/// tensor/buffer writing part of it and its type should match the shape written
-/// *after* any permutation.
-static VectorType inferTransferWriteMaskType(VectorType vecType,
-                                             AffineMap permMap) {
-  auto i1Type = IntegerType::get(permMap.getContext(), 1);
-  SmallVector<int64_t, 8> maskShape =
-      compressUnusedDims(permMap).compose(vecType.getShape());
-  return VectorType::get(maskShape, i1Type);
-}
-
 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
   auto &builder = parser.getBuilder();
@@ -3892,7 +3880,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
     if (shapedType.getElementType().dyn_cast<VectorType>())
       return parser.emitError(
           maskInfo.location, "does not support masks with vector element type");
-    auto maskType = inferTransferWriteMaskType(vectorType, permMap);
+    auto maskType = inferTransferOpMaskType(vectorType, permMap);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -3919,7 +3907,7 @@ LogicalResult TransferWriteOp::verify() {
   VectorType maskType = getMaskType();
   auto permutationMap = getPermutationMap();
   VectorType inferredMaskType =
-      maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
+      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
                : VectorType();
 
   if (llvm::size(getIndices()) != shapedType.getRank())
@@ -3945,7 +3933,7 @@ LogicalResult TransferWriteOp::verify() {
 /// Returns the mask type expected by this operation. Mostly used for
 /// verification purposes.
 Type TransferWriteOp::getExpectedMaskType() {
-  return inferTransferWriteMaskType(getVectorType(), getPermutationMap());
+  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
 }
 
 /// Fold:
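
With both ops now sharing inferTransferOpMaskType, the mask type for reads and
writes alike matches the vector shape *before* permutation or broadcasting,
i.e. the vector shape pushed through the inverse of the (compressed)
permutation map. That is why the mask in the lit test further down changes
from vector<8x14x16x7xi1> to vector<16x14x7x8xi1> for a vector<7x14x8x16xf32>
write with map (d0, d1, d2, d3) -> (d2, d1, d3, d0). The shape arithmetic, as
a small plain C++ sketch (not the MLIR API):

  #include <cstdint>
  #include <cstdio>
  #include <vector>

  int main() {
    // Results of the map (d0, d1, d2, d3) -> (d2, d1, d3, d0); vector shape 7x14x8x16.
    std::vector<unsigned> mapResults = {2, 1, 3, 0};
    std::vector<int64_t> vecShape = {7, 14, 8, 16};
    // maskShape[mapResults[i]] = vecShape[i]  =>  16x14x7x8
    std::vector<int64_t> maskShape(vecShape.size());
    for (unsigned i = 0, e = mapResults.size(); i < e; ++i)
      maskShape[mapResults[i]] = vecShape[i];
    for (int64_t d : maskShape)
      printf("%lld ", static_cast<long long>(d));
    printf("\n");
    return 0;
  }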

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index d7ec87e95f4e..df8ba7b85534 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -20,15 +20,16 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-/// Transpose a vector transfer op's `in_bounds` attribute according to given
-/// indices.
+/// Transpose a vector transfer op's `in_bounds` attribute by applying the
+/// inverse of the given permutation.
 static ArrayAttr
-transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
-                      const SmallVector<unsigned> &permutation) {
-  SmallVector<bool> newInBoundsValues;
+inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+                             const SmallVector<unsigned> &permutation) {
+  SmallVector<bool> newInBoundsValues(permutation.size());
+  size_t index = 0;
   for (unsigned pos : permutation)
-    newInBoundsValues.push_back(
-        attr.getValue()[pos].cast<BoolAttr>().getValue());
+    newInBoundsValues[pos] =
+        attr.getValue()[index++].cast<BoolAttr>().getValue();
   return builder.getBoolArrayAttr(newInBoundsValues);
 }
 
@@ -85,7 +86,7 @@ struct TransferReadPermutationLowering
 
     // Transpose in_bounds attribute.
     ArrayAttr newInBoundsAttr =
-        op.getInBounds() ? transposeInBoundsAttr(
+        op.getInBounds() ? inverseTransposeInBoundsAttr(
                                rewriter, op.getInBounds().value(), permutation)
                          : ArrayAttr();
 
@@ -142,16 +143,17 @@ struct TransferWritePermutationLowering
     // E.g.:  (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
     // comp = (d0, d1, d2) -> (d2, d0, d1)
     auto comp = compressUnusedDims(map);
+    AffineMap permutationMap = inversePermutation(comp);
     // Get positions of remaining result dims.
     SmallVector<int64_t> indices;
-    llvm::transform(comp.getResults(), std::back_inserter(indices),
+    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                     [](AffineExpr expr) {
                       return expr.dyn_cast<AffineDimExpr>().getPosition();
                     });
 
     // Transpose in_bounds attribute.
     ArrayAttr newInBoundsAttr =
-        op.getInBounds() ? transposeInBoundsAttr(
+        op.getInBounds() ? inverseTransposeInBoundsAttr(
                                rewriter, op.getInBounds().value(), permutation)
                          : ArrayAttr();
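
The renamed helper now scatters instead of gathers: newInBounds[pos] takes the
old value at the running index, which amounts to permuting the in_bounds
array by the inverse of the given permutation. For example, scattering
[true, false, false, true] through the permutation [2, 1, 3, 0] yields
[true, false, true, false], which lines up with the in_bounds change in the
lit test below (the permutation vector itself is computed earlier in the
pattern and is not part of this hunk). A standalone sketch of that loop in
plain C++ (std::vector<bool> standing in for the ArrayAttr of BoolAttr):

  #include <cstdio>
  #include <vector>

  int main() {
    std::vector<unsigned> permutation = {2, 1, 3, 0};
    std::vector<bool> oldInBounds = {true, false, false, true};
    // Mirror inverseTransposeInBoundsAttr: new[pos] = old[index++].
    std::vector<bool> newInBounds(permutation.size());
    unsigned index = 0;
    for (unsigned pos : permutation)
      newInBounds[pos] = oldInBounds[index++];
    for (bool b : newInBounds)
      printf("%s ", b ? "true" : "false"); // true false true false
    printf("\n");
    return 0;
  }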
 

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 0da64debdb6a..779b84f96a57 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -337,11 +337,11 @@ func.func @transfer_write_permutations(
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
 
-  // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<8x14x16x7xi1>
-  %mask0 = vector.splat %m : vector<8x14x16x7xi1>
+  // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1>
+  %mask0 = vector.splat %m : vector<16x14x7x8xi1>
   %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
-  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [true, false, true, false]} : vector<16x14x7x8xf32>, tensor<?x?x?x?xf32>
 
   vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
   // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
index 91b07587a40c..cee90c74c465 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
 // RUN:   -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
@@ -39,6 +39,34 @@ func.func @transfer_read_1d(%A : memref<?xf32>) -> vector<32xf32> {
   return %r : vector<32xf32>
 }
 
+func.func @transfer_write_inbounds_3d(%A : memref<4x4x4xf32>) {
+  %c0 = arith.constant 0: index
+  %f = arith.constant 0.0 : f32
+  %v0 = vector.splat %f : vector<2x3x4xf32>
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f3 = arith.constant 3.0 : f32
+  %f4 = arith.constant 4.0 : f32
+  %f5 = arith.constant 5.0 : f32
+  %f6 = arith.constant 6.0 : f32
+  %f7 = arith.constant 7.0 : f32
+  %f8 = arith.constant 8.0 : f32
+
+  %v1 = vector.insert %f1, %v0[0, 0, 0] : f32 into vector<2x3x4xf32>
+  %v2 = vector.insert %f2, %v1[0, 0, 3] : f32 into vector<2x3x4xf32>
+  %v3 = vector.insert %f3, %v2[0, 2, 0] : f32 into vector<2x3x4xf32>
+  %v4 = vector.insert %f4, %v3[0, 2, 3] : f32 into vector<2x3x4xf32>
+  %v5 = vector.insert %f5, %v4[1, 0, 0] : f32 into vector<2x3x4xf32>
+  %v6 = vector.insert %f6, %v5[1, 0, 3] : f32 into vector<2x3x4xf32>
+  %v7 = vector.insert %f7, %v6[1, 2, 0] : f32 into vector<2x3x4xf32>
+  %v8 = vector.insert %f8, %v7[1, 2, 3] : f32 into vector<2x3x4xf32>
+  vector.transfer_write %v8, %A[%c0, %c0, %c0]
+    {permutation_map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>,
+    in_bounds = [true, true, true]}
+    : vector<2x3x4xf32>, memref<4x4x4xf32>
+  return
+}
+
 func.func @entry() {
   %c0 = arith.constant 0: index
   %c1 = arith.constant 1: index
@@ -90,6 +118,24 @@ func.func @entry() {
   vector.print %6 : vector<32xf32>
 
   memref.dealloc %A : memref<?xf32>
+
+  // 3D case
+  %c4 = arith.constant 4: index
+  %A1 = memref.alloc() {alignment=64} : memref<4x4x4xf32>
+  scf.for %i = %c0 to %c4 step %c1 {
+    scf.for %j = %c0 to %c4 step %c1 {
+      scf.for %k = %c0 to %c4 step %c1 {
+        %f = arith.constant 0.0: f32
+        memref.store %f, %A1[%i, %j, %k] : memref<4x4x4xf32>
+      }
+    }
+  }
+  call @transfer_write_inbounds_3d(%A1) : (memref<4x4x4xf32>) -> ()
+  %f = arith.constant 0.0: f32
+  %r = vector.transfer_read %A1[%c0, %c0, %c0], %f
+    : memref<4x4x4xf32>, vector<4x4x4xf32>
+  vector.print %r : vector<4x4x4xf32>
+
   return
 }
 
@@ -100,3 +146,7 @@ func.func @entry() {
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13 )
+
+// 3D case.
+// CHECK: ( ( ( 1, 5, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 2, 6, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ),
+// CHECK-SAME: ( ( 3, 7, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 4, 8, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ) )
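
To see why those values are expected: with permutation_map
(d0, d1, d2) -> (d2, d0, d1) and all-zero base indices, the element at vector
position (i, j, k) is stored to A[j][k][i] (vector dim 0 walks memref dim d2,
dim 1 walks d0, dim 2 walks d1). For instance %f2, inserted at (0, 0, 3),
lands at A[0][3][0], and %f7, inserted at (1, 2, 0), lands at A[2][0][1]. A
short C++ sketch that reproduces the buffer contents checked above:

  #include <cstdio>

  int main() {
    // The eight values inserted into the vector<2x3x4xf32> by the test.
    float vec[2][3][4] = {};
    vec[0][0][0] = 1; vec[0][0][3] = 2; vec[0][2][0] = 3; vec[0][2][3] = 4;
    vec[1][0][0] = 5; vec[1][0][3] = 6; vec[1][2][0] = 7; vec[1][2][3] = 8;
    // permutation_map (d0, d1, d2) -> (d2, d0, d1): vector (i, j, k) -> A[j][k][i].
    float A[4][4][4] = {};
    for (int i = 0; i < 2; ++i)
      for (int j = 0; j < 3; ++j)
        for (int k = 0; k < 4; ++k)
          A[j][k][i] = vec[i][j][k];
    // Print in the same order as vector.print on the read-back vector<4x4x4xf32>.
    for (int a = 0; a < 4; ++a) {
      for (int b = 0; b < 4; ++b) {
        for (int c = 0; c < 4; ++c)
          printf("%g ", A[a][b][c]);
        printf("  ");
      }
      printf("\n");
    }
    return 0;
  }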


        


More information about the Mlir-commits mailing list