[Mlir-commits] [mlir] [mlir] Add direct vectorization lowering for `tensor.pack` ops (PR #78660)

Han-Chung Wang llvmlistbot at llvm.org
Sun Feb 4 21:45:08 PST 2024


================
@@ -1393,6 +1401,169 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Given a tensor::PackOp, return the `dest` shape before any packing
+/// permutations.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
+                                              ArrayRef<int64_t> destShape) {
+  return applyPermutation(destShape,
+                          tensor::getPackInverseDestPermutation(packOp));
+}
+
+/// Create a TransferReadOp from `source` with static shape `readShape`. If the
+/// vector type for the read is not the same as the type of `source`, then a
+/// mask is created on the read.
+static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
+                                    Value source, ArrayRef<int64_t> readShape,
+                                    Value padValue) {
+  assert(llvm::none_of(readShape,
+                       [](int64_t s) { return s == ShapedType::kDynamic; }));
+  auto maskType = VectorType::get(readShape, builder.getI1Type());
+  auto vectorType = VectorType::get(readShape, padValue.getType());
+  int64_t readRank = readShape.size();
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  auto transferReadOp = builder.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/source,
+      /*indices=*/SmallVector<Value>(readRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(readRank, true));
+  auto sourceShape = llvm::dyn_cast<ShapedType>(source.getType()).getShape();
+  if (sourceShape.size() == readShape.size() &&
+      llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
+        return std::get<0>(it) != ShapedType::kDynamic &&
+               std::get<0>(it) == std::get<1>(it);
+      })) {
+    return transferReadOp;
+  }
----------------
hanhanW wrote:

A few nits here:

1. Drop `llvm::` from `dyn_cast`; the unprefixed form is more common in the MLIR codebase.
2. Should `sourceShape.size() == readShape.size()` be an assertion? It looks like a bug to me if the ranks ever differ here.
3. `std::get<0>(it) != ShapedType::kDynamic` is already asserted at the entry point, so we don't need it here. In that case, we can probably update it to `llvm::zip_equal(readShape, sourceShape, std::equal_to<int64_t>()) { ... }`

https://github.com/llvm/llvm-project/pull/78660

