[Mlir-commits] [mlir] [mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs (PR #72142)

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 15 06:11:28 PST 2023


================
@@ -335,23 +348,50 @@ class TransferReadDropUnitDimsPattern
       return failure();
     // Check if the reduced vector shape matches the reduced source shape.
     // Otherwise, this case is not supported yet.
-    int vectorReducedRank = getReducedRank(vectorType.getShape());
-    if (reducedRank != vectorReducedRank)
+    auto reducedVectorType = trimUnitDims(vectorType);
+    if (reducedRank != reducedVectorType.getRank())
       return failure();
     if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
           return getConstantIntValue(v) != static_cast<int64_t>(0);
         }))
       return failure();
+
+    auto maskOp = transferReadOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return failure();
+      auto maskType = maskOp.getType();
+      auto reducedMaskType = trimUnitDims(maskType);
+      if (reducedMaskType.getRank() == maskType.getRank())
+        return failure();
+      SmallVector<Value> maskOperands;
+      for (auto [dim, dimIsScalable, maskOperand] :
+           llvm::zip(maskType.getShape(), maskType.getScalableDims(),
+                     createMaskOp.getOperands())) {
+        if (dim == 1 && !dimIsScalable) {
+          // If the mask for the unit dim is not a constant of 1, do nothing.
+          auto constant = maskOperand.getDefiningOp<arith::ConstantIndexOp>();
+          if (!constant || (constant.value() != 1))
+            return failure();
+          continue;
----------------
nicolasvasilache wrote:

Can we hoist this logic into a helper with a good name? The nesting here is deep enough already.
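
For illustration, a rough standalone sketch of what such a helper could look like. This is not the PR's code: it models the mask operands with plain C++ types instead of `mlir::Value`, and the names `MaskOperand` and `collectReducedMaskOperands` are hypothetical. The hunk above is cut off after `continue;`, so the step that keeps operands for non-unit (or scalable) dims is an assumption about what follows.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an mlir::Value that feeds vector.create_mask.
// `isConstant` models "defined by arith::ConstantIndexOp"; `value` is only
// meaningful when `isConstant` is true.
struct MaskOperand {
  int id;
  bool isConstant;
  int64_t value;
};

// Hypothetical helper: collects the create_mask operands that survive
// dropping static unit dims. Returns false (mirroring `return failure()`)
// when a static unit dim is not masked by a constant 1. `reduced` is only
// written on success.
bool collectReducedMaskOperands(const std::vector<int64_t> &shape,
                                const std::vector<bool> &scalableDims,
                                const std::vector<MaskOperand> &operands,
                                std::vector<MaskOperand> &reduced) {
  assert(shape.size() == scalableDims.size() &&
         shape.size() == operands.size() &&
         "expected one scalable flag and one operand per mask dim");
  std::vector<MaskOperand> kept;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    const bool isStaticUnitDim = shape[i] == 1 && !scalableDims[i];
    if (!isStaticUnitDim) {
      // Assumption: operands for all other dims are carried over unchanged.
      kept.push_back(operands[i]);
      continue;
    }
    // A static unit dim can only be dropped if its mask bound is constant 1.
    if (!operands[i].isConstant || operands[i].value != 1)
      return false;
  }
  reduced = kept;
  return true;
}
```

With something along these lines, the pattern body would reduce to a single call plus an early `return failure()`, which keeps the nesting in `matchAndRewrite` flat.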

https://github.com/llvm/llvm-project/pull/72142

