[Mlir-commits] [mlir] a7a5641 - [mlir][vector] Fix bug in `TransferWriteNonPermutationLowering`

Matthias Springer llvmlistbot at llvm.org
Mon Jul 10 08:25:46 PDT 2023


Author: Matthias Springer
Date: 2023-07-10T17:21:03+02:00
New Revision: a7a5641bdcfa92e95771ccfcc0a14d42611ac2f8

URL: https://github.com/llvm/llvm-project/commit/a7a5641bdcfa92e95771ccfcc0a14d42611ac2f8
DIFF: https://github.com/llvm/llvm-project/commit/a7a5641bdcfa92e95771ccfcc0a14d42611ac2f8.diff

LOG: [mlir][vector] Fix bug in `TransferWriteNonPermutationLowering`

This pattern expands the rank of the vector, but it did not expand the rank of the mask to match.

Differential Revision: https://reviews.llvm.org/D154849

Added: 
    mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4f68526ac401ea..af591730ad963e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -46,6 +46,21 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
   return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
 }
 
+/// Extend the rank of a mask Value by `addedRank` by adding inner (trailing)
+/// unit dims (broadcast adds them in front; a transpose moves them to the end).
+static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
+                            int64_t addedRank) {
+  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
+  SmallVector<int64_t> permutation;
+  for (int64_t i = addedRank,
+               e = broadcasted.getType().cast<VectorType>().getRank();
+       i < e; ++i)
+    permutation.push_back(i);
+  for (int64_t i = 0; i < addedRank; ++i)
+    permutation.push_back(i);
+  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
+}
+
 //===----------------------------------------------------------------------===//
 // populateVectorTransferPermutationMapLoweringPatterns
 //===----------------------------------------------------------------------===//
@@ -246,9 +261,14 @@ struct TransferWriteNonPermutationLowering
       missingInnerDim.push_back(i);
       exprs.push_back(rewriter.getAffineDimExpr(i));
     }
-    // Add unit dims at the beginning of the shape.
+    // Vector: add unit dims at the beginning of the shape.
     Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                     missingInnerDim.size());
+    // Mask: add unit dims at the end of the shape.
+    Value newMask;
+    if (op.getMask())
+      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
+                               missingInnerDim.size());
     exprs.append(map.getResults().begin(), map.getResults().end());
     AffineMap newMap =
         AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
@@ -263,7 +283,7 @@ struct TransferWriteNonPermutationLowering
     }
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
-        op.getMask(), newInBoundsAttr);
+        newMask, newInBoundsAttr);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
new file mode 100644
index 00000000000000..6ea53aa3f41b07
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @lower_permutation_with_mask(
+//       CHECK:   %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
+//       CHECK:   %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
+//       CHECK:   %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
+//       CHECK:   %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
+//       CHECK:   vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
+func.func @lower_permutation_with_mask(%A : memref<?x?xf32>, %base1 : index,
+                                       %base2 : index) {
+  %fn1 = arith.constant -2.0 : f32
+  %vf0 = vector.splat %fn1 : vector<7xf32>
+  %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>  // 1-D mask; CHECKs expect it widened to vector<7x1xi1>
+  vector.transfer_write %vf0, %A[%base1, %base2], %mask
+    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}  // d1 is absent from the map, so the pattern adds a unit dim for it
+    : vector<7xf32>, memref<?x?xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op  // every func.func nested in the module
+    : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.transfer_permutation_patterns  // presumably the populateVectorTransferPermutationMapLoweringPatterns set; TODO confirm
+  } : !transform.any_op
+}


        


More information about the Mlir-commits mailing list