[llvm] [mlir] [MLIR] Generalize expand_shape to take shape as explicit input (PR #69267)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 17 16:36:54 PDT 2024
================
@@ -30,6 +30,28 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName() { return "reassociation"; }
+// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
+// to do so.
+//
+// Note: This should *only* be used to implement
+// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+// If you need to infer the output shape you should use the static method of
+// `ExpandShapeOp` instead of calling this.
+//
+// `inputShape` is the shape of the tensor or memref being expanded as a
+// sequence of SSA values or constants. `expandedType` is the output shape of
+// the expand_shape operation. `reassociation` is the reassociation denoting
+// the output dims each input dim is mapped to.
+//
+// Returns the output shape in `outputShape` and `staticOutputShape`, following
+// the conventions for the output_shape and static_output_shape inputs to the
+// expand_shape ops.
+LogicalResult inferExpandShapeOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
----------------
MaheshRavishankar wrote:
I see later that you are using the same method for `MemRefType` and `RankedTensorType`. Maybe make `expandedType` a `ShapedType`; that way you don't need to create a tensor type just to call this method.
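To illustrate why the suggestion works, here is a standalone model of the inference. It does not use MLIR. `ShapeDesc`, `kDynamic`, `Reassociation`, and `inferExpandedShape` are hypothetical names, not MLIR APIs. The model assumes every input size is already a concrete integer, whereas the real helper works with SSA values and emits arithmetic ops. It also merges `outputShape` and `staticOutputShape` into one resolved vector. The inference only reads the expanded type's dimension sizes, so a generic shape description (the role `ShapedType` would play) can serve memrefs and tensors equally.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical sentinel marking a dynamic (unknown) output dimension.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for a shaped type. Only the sizes matter here, so
// the same struct describes a memref or a tensor result.
struct ShapeDesc {
  std::vector<int64_t> dims;
};

// reassociation[i] lists the output dims that input dim i expands into.
using Reassociation = std::vector<std::vector<std::size_t>>;

// Resolve every output dim of an expand_shape. Within each reassociation
// group at most one output dim may be dynamic. Its size is the input size
// divided by the product of the group's static sizes. Returns std::nullopt
// when inference is impossible.
std::optional<std::vector<int64_t>>
inferExpandedShape(const std::vector<int64_t> &inputShape,
                   const ShapeDesc &expanded,
                   const Reassociation &reassociation) {
  if (inputShape.size() != reassociation.size())
    return std::nullopt;
  std::vector<int64_t> result = expanded.dims;
  for (std::size_t i = 0; i < reassociation.size(); ++i) {
    int64_t staticProduct = 1;
    std::optional<std::size_t> dynamicDim;
    for (std::size_t outDim : reassociation[i]) {
      if (outDim >= result.size())
        return std::nullopt; // malformed reassociation
      if (result[outDim] == kDynamic) {
        if (dynamicDim)
          return std::nullopt; // two unknowns in one group: ambiguous
        dynamicDim = outDim;
      } else {
        staticProduct *= result[outDim];
      }
    }
    if (dynamicDim) {
      // The dynamic size must divide the input size exactly.
      if (staticProduct == 0 || inputShape[i] % staticProduct != 0)
        return std::nullopt;
      result[*dynamicDim] = inputShape[i] / staticProduct;
    } else if (staticProduct != inputShape[i]) {
      return std::nullopt; // a fully static group must match the input
    }
  }
  return result;
}
```

For example, expanding a `12x10` input to `?x3x2x5` with reassociation `[[0, 1], [2, 3]]` resolves the dynamic dimension to 12 / 3 = 4. A group containing two dynamic dims is rejected as ambiguous.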
https://github.com/llvm/llvm-project/pull/69267