[Mlir-commits] [mlir] [mlir][linalg] Split GenericPadOpVectorizationPattern into two patterns (PR #111349)

Han-Chung Wang llvmlistbot at llvm.org
Fri Oct 25 16:17:54 PDT 2024


================
@@ -2604,6 +2495,177 @@ struct PadOpVectorizationWithTransferWritePattern
   }
 };
 
+/// Given an ArrayRef of OpFoldResults, return a vector of Values.
+/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
+/// not supported.
+static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
+                                           ArrayRef<OpFoldResult> ofrs) {
+  SmallVector<Value> result;
+  for (auto o : ofrs) {
+    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
+      result.push_back(val);
+    } else {
+      result.push_back(rewriter.create<arith::ConstantIndexOp>(
+          loc, cast<IntegerAttr>(cast<Attribute>(o)).getInt()));
+    }
+  }
+  return result;
+}
+
+/// Returns the effective Pad value for the input op, provided it's a scalar.
+///
+/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
+/// this Op performs padding, retrieve the padding value provided that it's
+/// a scalar and static/fixed for all the padded values. Returns an empty value
+/// otherwise.
+static Value getStaticPadVl(Operation *op) {
+  if (!op)
+    return {};
+
+  // 1. vector.broadcast - return the value that's being broadcast,
+  // provided that it's a scalar.
+  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
+    auto source = bcast.getSource();
+    if (llvm::dyn_cast<VectorType>(source.getType()))
+      return {};
+
+    return source;
+  }
+
+  // 2. linalg.fill - use the scalar input value that is used to fill the
+  // output tensor.
+  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
+    return fill.getInputs()[0];
+  }
+
+  // 3. tensor.generate - can't guarantee the value is fixed without
+  // analysing, bail out.
+  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
+    return {};
+  }
+
+  // 4. vector.transfer_write - inspect the input vector that's written from. If
+  // it contains a single value that has been broadcast (e.g. via
+  // vector.broadcast), extract it, fail otherwise.
+  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
+    return getStaticPadVl(xferWrite.getVector().getDefiningOp());
+
+  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
+  // than the input tensor, then, provided it's constant, we'll extract the
+  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
+  // TODO: Clarify the semantics when the input tensor is larger than the
+  // destination.
+  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
+    return getStaticPadVl(slice.getDest().getDefiningOp());
+
+  return {};
+}
+
+/// Rewrite tensor.insert_slice as a vector.transfer_read +
+/// vector.transfer_write pair. The vector size is inferred from the static
+/// dims in the input and output tensors. If a dim is dynamic in both the input
+/// and output tensors, bails out.
+///
+/// Before:
+///     !t_in_type = tensor<1x2x3xf32>
+///     !t_out_type = tensor<9x8x7x1x2x3xf32>
+///     !v_type = vector<1x2x3xf32>
+///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///     into !t_out_type
+/// After:
+///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
+///
+/// TODO: Support masking
+struct InsertSliceVectorizePattern
----------------
hanhanW wrote:

It can be done in a follow-up. I think we want a unified API and vectorization path, which would help us manage future extensibility (both features and design) better. So it'd be a plus to move the implementation into the `vectorize()` method.

(There was a long, old discussion in https://reviews.llvm.org/D150495. That hasn't happened, but I think the point still holds.)
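
For illustration only, here is a toy model of what a "unified vectorization path" could look like: one entry point that dispatches on the op kind, rather than separately registered rewrite patterns. It does not use MLIR; `ToyOp`, `ToyOpKind`, `vectorizeToy` and the per-op helpers are all hypothetical names, and the returned strings merely stand in for the IR that would be emitted.

```cpp
#include <optional>
#include <string>

// Hypothetical op kinds standing in for real MLIR operations.
enum class ToyOpKind { LinalgGeneric, TensorPad, TensorInsertSlice, Unknown };

// Hypothetical stand-in for an operation; only its kind is modelled.
struct ToyOp {
  ToyOpKind kind;
};

// Hypothetical per-op helpers. Each returns a description of what it would
// emit; a real implementation would build IR instead.
static std::string vectorizeToyLinalg(const ToyOp &) {
  return "vector ops for linalg";
}
static std::string vectorizeToyPad(const ToyOp &) {
  return "transfer_read with padding + transfer_write";
}
static std::string vectorizeToyInsertSlice(const ToyOp &) {
  return "transfer_read + transfer_write";
}

// Single entry point: every supported op kind goes through the same function,
// so shared preconditions and future extensions live in one place.
// Returns std::nullopt for unsupported ops.
std::optional<std::string> vectorizeToy(const ToyOp &op) {
  switch (op.kind) {
  case ToyOpKind::LinalgGeneric:
    return vectorizeToyLinalg(op);
  case ToyOpKind::TensorPad:
    return vectorizeToyPad(op);
  case ToyOpKind::TensorInsertSlice:
    return vectorizeToyInsertSlice(op);
  case ToyOpKind::Unknown:
    return std::nullopt;
  }
  return std::nullopt;
}
```

With this shape, supporting a new op means adding one `case` and one helper, instead of writing and registering another standalone pattern.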

https://github.com/llvm/llvm-project/pull/111349
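
The recursive lookup performed by `getStaticPadVl` in the quoted hunk can be modelled without MLIR. The sketch below is an assumption-laden toy, not the patch itself: `ToyDef`, `ToyKind` and `getStaticPadValueToy` are hypothetical names, a `float` stands in for the scalar `Value`, and `producer` stands in for the defining op of the written vector (transfer_write) or of the destination tensor (insert_slice).

```cpp
#include <optional>

// Hypothetical op kinds mirroring the cases handled in the quoted hunk.
enum class ToyKind { Broadcast, Fill, Generate, TransferWrite, InsertSlice, Other };

// Hypothetical stand-in for an operation.
struct ToyDef {
  ToyKind kind;
  // Scalar operand of Broadcast/Fill; nullopt models a non-scalar source.
  std::optional<float> scalar;
  // Defining op of the written vector (TransferWrite) or of the destination
  // tensor (InsertSlice); nullptr models a value with no defining op.
  const ToyDef *producer;
};

// Returns the scalar pad value if it can be determined statically, and
// std::nullopt otherwise (the toy equivalent of returning an empty Value).
std::optional<float> getStaticPadValueToy(const ToyDef *op) {
  if (!op)
    return std::nullopt;
  switch (op->kind) {
  case ToyKind::Broadcast: // Only a scalar source qualifies.
  case ToyKind::Fill:      // The fill value is the pad value.
    return op->scalar;
  case ToyKind::Generate:  // Cannot prove the value is fixed: bail out.
    return std::nullopt;
  case ToyKind::TransferWrite: // Look through to the written vector.
  case ToyKind::InsertSlice:   // Look through to the destination tensor.
    return getStaticPadValueToy(op->producer);
  case ToyKind::Other:
    return std::nullopt;
  }
  return std::nullopt;
}
```

For example, an insert_slice whose destination was produced by a fill of `0.0f` resolves to `0.0f`, while a chain ending in a generate op, or in a value with no defining op, yields no pad value.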