[Mlir-commits] [llvm] [mlir] [MLIR] Generalize expand_shape to take shape as explicit input (PR #69267)
llvmlistbot at llvm.org
Wed Apr 17 16:36:55 PDT 2024
================
@@ -1102,50 +1097,103 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
rank than the operand `src` whose dimension sizes are a reassociation of
`src`.
- A reassociation is defined as a continuous grouping of dimensions. It is
- represented with an array of DenseI64ArrayAttr attribute. Entries in the
- array are referred to as reassociation maps.
+ A reassociation is defined as a continuous grouping of dimensions and is
+ represented with an array of DenseI64ArrayAttr attribute. The reassociation
+ maps applied to the result tensor with the higher rank must result in the
+ operand tensor with the smaller rank.
- The reassociation maps are applied to the result shape to obtain the operand
- shape.
+ The representation for the output shape supports a partially-static
+ specification via attributes specified through the `static_output_shape`
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
+ corresponding entry has a dynamic value. There must be exactly as many SSA
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+ `static_output_shape`.
Example:
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
- %b = tensor.expand_shape %a [[0, 1], [2]]
- : tensor<?x?xf32> into tensor<?x?x?xf32>
+ %b = tensor.expand_shape %a [[0, 1], [2]] [%sz0, %sz1, 32]
+ : tensor<?x32xf32> into tensor<?x?x32xf32>
```
}];
+
+ let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+ Variadic<Index>:$output_shape,
+ DenseI64ArrayAttr:$static_output_shape);
+
+ let assemblyFormat = [{
+ $src $reassociation `output_shape`
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+ type($src) `into` type($result)
+ }];
+
let builders = [
// Builders using ReassociationIndices.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation),
+ [{
+ SmallVector<OpFoldResult> inputShape =
+ getMixedSizes($_builder, $_state.location, src);
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
+ auto status =
+ inferOutputShape($_builder, $_state.location,
+ resultType.cast<RankedTensorType>(),
+ reassociation, inputShape, outputShape);
+ (void) status;
+ assert(succeeded(status) && "unable to infer output shape");
+ build($_builder, $_state, resultType.cast<RankedTensorType>(), src,
+ getReassociationIndicesAttribute($_builder, reassociation),
+ outputShape.second, outputShape.first);
+ }]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
- build($_builder, $_state, resultType, src, attrs);
- $_state.addAttribute("reassociation",
- getReassociationIndicesAttribute($_builder, reassociation));
+ auto [staticOutputShape, dynamicOutputShape] =
----------------
MaheshRavishankar wrote:
While here, maybe move these builders to the C++ file as well.
https://github.com/llvm/llvm-project/pull/69267
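
For readers following the hunk, the partially-static encoding described for `static_output_shape` can be illustrated without any MLIR dependency. This is a minimal standalone sketch, not the MLIR implementation: `kDynamic`, `Dim`, and `splitMixedShape` are hypothetical stand-ins (in MLIR the sentinel is `ShapedType::kDynamic` and the dynamic entries are SSA values rather than strings).

```cpp
#include <cstdint>
#include <limits>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-in for ShapedType::kDynamic; the actual value is an
// MLIR implementation detail and is only assumed here.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// A dimension is either a compile-time constant or the name of a runtime
// value (standing in for an SSA value such as %sz0).
using Dim = std::variant<int64_t, std::string>;

// Split a mixed shape into a fully-sized static list, with kDynamic marking
// each runtime-sized entry, plus the runtime values in order of appearance.
std::pair<std::vector<int64_t>, std::vector<std::string>>
splitMixedShape(const std::vector<Dim> &mixed) {
  std::vector<int64_t> staticShape;
  std::vector<std::string> dynamicValues;
  for (const Dim &d : mixed) {
    if (const auto *constant = std::get_if<int64_t>(&d)) {
      staticShape.push_back(*constant);
    } else {
      staticShape.push_back(kDynamic);
      dynamicValues.push_back(std::get<std::string>(d));
    }
  }
  return {staticShape, dynamicValues};
}
```

Applied to the shape `[%sz0, %sz1, 32]` from the example in the hunk, this sketch yields a static list with two sentinel entries followed by `32`, and two dynamic values, which matches the stated rule that `output_shape` has exactly as many inputs as there are dynamic entries in `static_output_shape`.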