[Mlir-commits] [mlir] a9733b8 - [MLIR] Adopt `DenseI64ArrayAttr` in tensor, memref and linalg transform
Lorenzo Chelini
llvmlistbot at llvm.org
Fri Nov 25 00:43:36 PST 2022
Author: Lorenzo Chelini
Date: 2022-11-25T09:43:30+01:00
New Revision: a9733b8a5eed441d6378d0fc88630233e00a6395
URL: https://github.com/llvm/llvm-project/commit/a9733b8a5eed441d6378d0fc88630233e00a6395
DIFF: https://github.com/llvm/llvm-project/commit/a9733b8a5eed441d6378d0fc88630233e00a6395.diff
LOG: [MLIR] Adopt `DenseI64ArrayAttr` in tensor, memref and linalg transform
This commit is a first step toward removing inconsistencies between dynamic
and static attributes (i64 v. index) by dropping `I64ArrayAttr` and
using `DenseI64ArrayAttr` in Tensor, Memref and Linalg Transform ops.
In Linalg Transform ops only `TileToScfForOp`, `TileOp` and `TileToForeachThreadOp` have been updated.
See related discussion: https://discourse.llvm.org/t/rfc-inconsistency-between-dynamic-and-static-attributes-i64-v-index/66612/1
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D138567
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b7c02a205f5e6..dbb803bcb1e1e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -839,8 +839,8 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
let arguments = (ins PDL_Operation:$target,
Variadic<PDL_Operation>:$dynamic_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$static_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
let results = (outs PDL_Operation:$tiled_linalg_op,
Variadic<PDL_Operation>:$loops);
@@ -917,8 +917,8 @@ def TileToForeachThreadOp :
let arguments = (ins PDL_Operation:$target,
Variadic<PDL_Operation>:$num_threads,
Variadic<PDL_Operation>:$tile_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$static_tile_sizes,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
let results = (outs PDL_Operation:$foreach_thread_op,
PDL_Operation:$tiled_op);
@@ -1009,8 +1009,8 @@ def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
let arguments = (ins PDL_Operation:$target,
Variadic<PDL_Operation>:$dynamic_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$static_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
let results = (outs PDL_Operation:$tiled_linalg_op,
Variadic<PDL_Operation>:$loops);
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index ccf01be858680..4a567b40e2e5e 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1260,9 +1260,9 @@ def MemRef_ReinterpretCastOp
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
- I64ArrayAttr:$static_offsets,
- I64ArrayAttr:$static_sizes,
- I64ArrayAttr:$static_strides);
+ DenseI64ArrayAttr:$static_offsets,
+ DenseI64ArrayAttr:$static_sizes,
+ DenseI64ArrayAttr:$static_strides);
let results = (outs AnyMemRef:$result);
let assemblyFormat = [{
@@ -1476,7 +1476,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
or copies.
A reassociation is defined as a grouping of dimensions and is represented
- with an array of I64ArrayAttr attributes.
+ with an array of DenseI64ArrayAttr attributes.
Example:
@@ -1563,7 +1563,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
type.
A reassociation is defined as a continuous grouping of dimensions and is
- represented with an array of I64ArrayAttr attribute.
+ represented with an array of DenseI64ArrayAttr attribute.
Note: Only the dimensions within a reassociation group must be contiguous.
The remaining dimensions may be non-contiguous.
@@ -1855,9 +1855,9 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
- I64ArrayAttr:$static_offsets,
- I64ArrayAttr:$static_sizes,
- I64ArrayAttr:$static_strides);
+ DenseI64ArrayAttr:$static_offsets,
+ DenseI64ArrayAttr:$static_sizes,
+ DenseI64ArrayAttr:$static_strides);
let results = (outs AnyMemRef:$result);
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 14060075b2340..0af5811638a85 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -326,9 +326,9 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
- I64ArrayAttr:$static_offsets,
- I64ArrayAttr:$static_sizes,
- I64ArrayAttr:$static_strides
+ DenseI64ArrayAttr:$static_offsets,
+ DenseI64ArrayAttr:$static_sizes,
+ DenseI64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);
@@ -807,9 +807,9 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
- I64ArrayAttr:$static_offsets,
- I64ArrayAttr:$static_sizes,
- I64ArrayAttr:$static_strides
+ DenseI64ArrayAttr:$static_offsets,
+ DenseI64ArrayAttr:$static_sizes,
+ DenseI64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);
@@ -1013,7 +1013,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
rank whose sizes are a reassociation of the original `src`.
A reassociation is defined as a continuous grouping of dimensions and is
- represented with an array of I64ArrayAttr attribute.
+ represented with an array of DenseI64ArrayAttr attribute.
The verification rule is that the reassociation maps are applied to the
result tensor with the higher rank to obtain the operand tensor with the
@@ -1065,7 +1065,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
rank whose sizes are a reassociation of the original `src`.
A reassociation is defined as a continuous grouping of dimensions and is
- represented with an array of I64ArrayAttr attribute.
+ represented with an array of DenseI64ArrayAttr attribute.
The verification rule is that the reassociation maps are applied to the
operand tensor with the higher rank to obtain the result tensor with the
@@ -1206,8 +1206,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [
AnyTensor:$source,
Variadic<Index>:$low,
Variadic<Index>:$high,
- I64ArrayAttr:$static_low,
- I64ArrayAttr:$static_high,
+ DenseI64ArrayAttr:$static_low,
+ DenseI64ArrayAttr:$static_high,
UnitAttr:$nofold);
let regions = (region SizedRegion<1>:$region);
@@ -1254,16 +1254,17 @@ def Tensor_PadOp : Tensor_Op<"pad", [
// Return a vector of all the static or dynamic values (low/high padding) of
// the op.
- inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayAttr staticAttrs,
+ inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayRef<int64_t> staticAttrs,
ValueRange values) {
+ Builder builder(*this);
SmallVector<OpFoldResult> res;
unsigned numDynamic = 0;
unsigned count = staticAttrs.size();
for (unsigned idx = 0; idx < count; ++idx) {
- if (ShapedType::isDynamic(staticAttrs[idx].cast<IntegerAttr>().getInt()))
+ if (ShapedType::isDynamic(staticAttrs[idx]))
res.push_back(values[numDynamic++]);
else
- res.push_back(staticAttrs[idx]);
+ res.push_back(builder.getI64IntegerAttr(staticAttrs[idx]));
}
return res;
}
@@ -1400,9 +1401,9 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
- I64ArrayAttr:$static_offsets,
- I64ArrayAttr:$static_sizes,
- I64ArrayAttr:$static_strides
+ DenseI64ArrayAttr:$static_offsets,
+ DenseI64ArrayAttr:$static_sizes,
+ DenseI64ArrayAttr:$static_strides
);
let assemblyFormat = [{
$source `into` $dest ``
@@ -1748,7 +1749,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
- I64ArrayAttr:$static_inner_tiles);
+ DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source
@@ -1803,7 +1804,7 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
- I64ArrayAttr:$static_inner_tiles);
+ DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$source
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index f09cf88afaab3..e72f7095b6da0 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -87,6 +87,18 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> valueOrAttrVec);
+/// Return a vector of OpFoldResults with the same size as staticValues, but
+/// all elements for which ShapedType::isDynamic is true will be replaced by
+/// dynamicValues; `b` presumably builds the attrs for static entries.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues, Builder &b);
+
+/// Decompose mixed static/dynamic values into a pair of (static values as an
+/// ArrayAttr, dynamic Values). This is the inverse of `getMixedValues`.
+std::pair<ArrayAttr, SmallVector<Value>>
+decomposeMixedValues(Builder &b,
+                     const SmallVectorImpl<OpFoldResult> &mixedValues);
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 700546d082e6f..f950933b23c7a 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -21,18 +21,6 @@
namespace mlir {
-/// Return a vector of OpFoldResults with the same size a staticValues, but all
-/// elements for which ShapedType::isDynamic is true, will be replaced by
-/// dynamicValues.
-SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
- ValueRange dynamicValues);
-
-/// Decompose a vector of mixed static or dynamic values into the corresponding
-/// pair of arrays. This is the inverse function of `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
- const SmallVectorImpl<OpFoldResult> &mixedValues);
-
class OffsetSizeAndStrideOpInterface;
namespace detail {
@@ -61,7 +49,7 @@ namespace mlir {
/// idiomatic printing of mixed value and integer attributes in a list. E.g.
/// `[%arg0, 7, 42, %arg42]`.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
- OperandRange values, ArrayAttr integers);
+ OperandRange values, ArrayRef<int64_t> integers);
/// Pasrer hook for custom directive in assemblyFormat.
///
@@ -79,13 +67,14 @@ void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
ParseResult
parseDynamicIndexList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- ArrayAttr &integers);
+ DenseI64ArrayAttr &integers);
/// Verify that a the `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
unsigned expectedNumElements,
- ArrayAttr attr, ValueRange values);
+ ArrayRef<int64_t> attr,
+ ValueRange values);
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index aca01262134c4..b5870af8c7936 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -124,7 +124,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*desc=*/[{
Return the static offset attributes.
}],
- /*retTy=*/"::mlir::ArrayAttr",
+ /*retTy=*/"::llvm::ArrayRef<int64_t>",
/*methodName=*/"static_offsets",
/*args=*/(ins),
/*methodBody=*/"",
@@ -136,7 +136,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*desc=*/[{
Return the static size attributes.
}],
- /*retTy=*/"::mlir::ArrayAttr",
+ /*retTy=*/"::llvm::ArrayRef<int64_t>",
/*methodName=*/"static_sizes",
/*args=*/(ins),
/*methodBody=*/"",
@@ -148,7 +148,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*desc=*/[{
Return the dynamic stride attributes.
}],
- /*retTy=*/"::mlir::ArrayAttr",
+ /*retTy=*/"::llvm::ArrayRef<int64_t>",
/*methodName=*/"static_strides",
/*args=*/(ins),
/*methodBody=*/"",
@@ -165,8 +165,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
+ Builder b($_op->getContext());
return ::mlir::getMixedValues($_op.getStaticOffsets(),
- $_op.getOffsets());
+ $_op.getOffsets(), b);
}]
>,
InterfaceMethod<
@@ -178,7 +179,8 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes());
+ Builder b($_op->getContext());
+ return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes(), b);
}]
>,
InterfaceMethod<
@@ -190,8 +192,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
+ Builder b($_op->getContext());
return ::mlir::getMixedValues($_op.getStaticStrides(),
- $_op.getStrides());
+ $_op.getStrides(), b);
}]
>,
@@ -202,9 +205,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins "unsigned":$idx),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- ::llvm::APInt v = *(static_offsets()
- .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
- return ::mlir::ShapedType::isDynamic(v.getSExtValue());
+ return ::mlir::ShapedType::isDynamic(static_offsets()[idx]);
}]
>,
InterfaceMethod<
@@ -214,9 +215,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins "unsigned":$idx),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- ::llvm::APInt v = *(static_sizes()
- .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
- return ::mlir::ShapedType::isDynamic(v.getSExtValue());
+ return ::mlir::ShapedType::isDynamic(static_sizes()[idx]);
}]
>,
InterfaceMethod<
@@ -226,9 +225,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*args=*/(ins "unsigned":$idx),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- ::llvm::APInt v = *(static_strides()
- .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
- return ::mlir::ShapedType::isDynamic(v.getSExtValue());
+ return ::mlir::ShapedType::isDynamic(static_strides()[idx]);
}]
>,
InterfaceMethod<
@@ -241,9 +238,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(!$_op.isDynamicOffset(idx) && "expected static offset");
- ::llvm::APInt v = *(static_offsets().
- template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
- return v.getSExtValue();
+ return static_offsets()[idx];
}]
>,
InterfaceMethod<
@@ -256,9 +251,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(!$_op.isDynamicSize(idx) && "expected static size");
- ::llvm::APInt v = *(static_sizes().
- template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
- return v.getSExtValue();
+ return static_sizes()[idx];
}]
>,
InterfaceMethod<
@@ -271,9 +264,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(!$_op.isDynamicStride(idx) && "expected static stride");
- ::llvm::APInt v = *(static_strides().
- template getAsValueRange<::mlir::IntegerAttr>().begin() + idx);
- return v.getSExtValue();
+ return static_strides()[idx];
}]
>,
@@ -289,7 +280,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*defaultImplementation=*/[{
assert($_op.isDynamicOffset(idx) && "expected dynamic offset");
auto numDynamic = getNumDynamicEntriesUpToIdx(
- static_offsets().template cast<::mlir::ArrayAttr>(),
+ static_offsets(),
::mlir::ShapedType::isDynamic,
idx);
return $_op.getOffsetSizeAndStrideStartOperandIndex() + numDynamic;
@@ -307,7 +298,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*defaultImplementation=*/[{
assert($_op.isDynamicSize(idx) && "expected dynamic size");
auto numDynamic = getNumDynamicEntriesUpToIdx(
- static_sizes().template cast<::mlir::ArrayAttr>(), ::mlir::ShapedType::isDynamic, idx);
+ static_sizes(), ::mlir::ShapedType::isDynamic, idx);
return $_op.getOffsetSizeAndStrideStartOperandIndex() +
offsets().size() + numDynamic;
}]
@@ -324,7 +315,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*defaultImplementation=*/[{
assert($_op.isDynamicStride(idx) && "expected dynamic stride");
auto numDynamic = getNumDynamicEntriesUpToIdx(
- static_strides().template cast<::mlir::ArrayAttr>(),
+ static_strides(),
::mlir::ShapedType::isDynamic,
idx);
return $_op.getOffsetSizeAndStrideStartOperandIndex() +
@@ -333,20 +324,20 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
>,
InterfaceMethod<
/*desc=*/[{
- Helper method to compute the number of dynamic entries of `attr`, up to
+ Helper method to compute the number of dynamic entries of `staticVals`, up to
`idx` using `isDynamic` to determine whether an entry is dynamic.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumDynamicEntriesUpToIdx",
- /*args=*/(ins "::mlir::ArrayAttr":$attr,
+ /*args=*/(ins "::llvm::ArrayRef<int64_t>":$staticVals,
"::llvm::function_ref<bool(int64_t)>":$isDynamic,
"unsigned":$idx),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return std::count_if(
- attr.getValue().begin(), attr.getValue().begin() + idx,
- [&](::mlir::Attribute attr) {
- return isDynamic(attr.cast<::mlir::IntegerAttr>().getInt());
+ staticVals.begin(), staticVals.begin() + idx,
+ [&](int64_t val) {
+ return isDynamic(val);
});
}]
>,
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 42d2d9a1b3097..7da3c3693bb69 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1705,10 +1705,8 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
auto viewMemRefType = subViewOp.getType();
auto inferredType =
memref::SubViewOp::inferResultType(
- subViewOp.getSourceType(),
- extractFromI64ArrayAttr(subViewOp.getStaticOffsets()),
- extractFromI64ArrayAttr(subViewOp.getStaticSizes()),
- extractFromI64ArrayAttr(subViewOp.getStaticStrides()))
+ subViewOp.getSourceType(), subViewOp.getStaticOffsets(),
+ subViewOp.getStaticSizes(), subViewOp.getStaticStrides())
.cast<MemRefType>();
auto targetElementTy =
typeConverter->convertType(viewMemRefType.getElementType());
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 92bb30eefa5de..cb2eea2960e3d 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -30,8 +30,8 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
PatternRewriter &rewriter) const final {
Location loc = sliceOp.getLoc();
Value input = sliceOp.getInput();
- SmallVector<int64_t> strides, sizes;
- auto starts = sliceOp.getStart();
+ SmallVector<int64_t> strides, sizes, starts;
+ starts = extractFromI64ArrayAttr(sliceOp.getStart());
strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
SmallVector<Value> dynSizes;
@@ -44,15 +44,15 @@ class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
auto offset = rewriter.create<arith::ConstantOp>(
- loc,
- rewriter.getIndexAttr(starts[index].cast<IntegerAttr>().getInt()));
+ loc, rewriter.getIndexAttr(starts[index]));
dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
}
auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
- ValueRange({}), starts, rewriter.getI64ArrayAttr(sizes),
- rewriter.getI64ArrayAttr(strides));
+ ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
+ rewriter.getDenseI64ArrayAttr(sizes),
+ rewriter.getDenseI64ArrayAttr(strides));
rewriter.replaceOp(sliceOp, newSliceOp.getResult());
return success();
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f02ccfae68934..e6123a4f17749 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -40,16 +40,6 @@ static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
return result;
}
-/// Extracts a vector of int64_t from an array attribute. Asserts if the
-/// attribute contains values other than integers.
-static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
- SmallVector<int64_t> result;
- result.reserve(attr.size());
- for (APInt value : attr.getAsValueRange<IntegerAttr>())
- result.push_back(value.getSExtValue());
- return result;
-}
-
namespace {
/// A simple pattern rewriter that implements no special logic.
class SimpleRewriter : public PatternRewriter {
@@ -1205,7 +1195,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
DiagnosedSilenceableFailure
transform::TileOp::apply(TransformResults &transformResults,
TransformState &state) {
- SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
+ ArrayRef<int64_t> tileSizes = getStaticSizes();
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
@@ -1270,7 +1260,7 @@ transform::TileOp::apply(TransformResults &transformResults,
});
}
- tilingOptions.setInterchange(extractI64Array(getInterchange()));
+ tilingOptions.setInterchange(getInterchange());
SimpleRewriter rewriter(linalgOp.getContext());
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
@@ -1298,7 +1288,7 @@ transform::TileOp::apply(TransformResults &transformResults,
SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
ValueRange dynamic = getDynamicSizes();
- SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
+ ArrayRef<int64_t> tileSizes = getStaticSizes();
SmallVector<OpFoldResult> results;
results.reserve(tileSizes.size());
unsigned dynamicPos = 0;
@@ -1313,22 +1303,51 @@ SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
return results;
}
+/// Parses an optional `{interchange = [...]}` clause into `result`. The
+/// `DenseI64ArrayAttr` uses its short form (no `array` prefix) to stay
+/// consistent in the IR with `parseDynamicIndexList`.
+static ParseResult parseOptionalInterchange(OpAsmParser &parser,
+                                            OperationState &result) {
+  if (failed(parser.parseOptionalLBrace()))
+    return success();
+  // Each parser hook already reports its own error on failure.
+  if (parser.parseKeyword("interchange") || parser.parseEqual())
+    return failure();
+  Attribute interchange = DenseI64ArrayAttr::parse(parser, Type{});
+  if (!interchange) // Parse error: never store a null attribute.
+    return failure();
+  result.addAttribute("interchange", interchange);
+  return parser.parseRBrace();
+}
+
+/// Inverse of parseOptionalInterchange; prints nothing for an empty list.
+static void printOptionalInterchange(OpAsmPrinter &p,
+                                     ArrayRef<int64_t> interchangeVals) {
+  if (interchangeVals.empty())
+    return;
+  p << " {interchange = [";
+  llvm::interleaveComma(interchangeVals, p, [&](int64_t v) { p << v; });
+  p << "]}";
+}
+
ParseResult transform::TileOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand target;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
- ArrayAttr staticSizes;
+ DenseI64ArrayAttr staticSizes;
auto pdlOperationType = pdl::OperationType::get(parser.getContext());
if (parser.parseOperand(target) ||
parser.resolveOperand(target, pdlOperationType, result.operands) ||
parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
- parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
- parser.parseOptionalAttrDict(result.attributes))
+ parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
return ParseResult::failure();
+ // Parse optional interchange.
+ if (failed(parseOptionalInterchange(parser, result)))
+ return ParseResult::failure();
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
size_t numExpectedLoops =
- staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
+ staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
return success();
}
@@ -1336,7 +1355,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
void TileOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
- p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
+ printOptionalInterchange(p, getInterchange());
}
void transform::TileOp::getEffects(
@@ -1379,13 +1398,13 @@ void transform::TileToForeachThreadOp::build(
// bugs ensue.
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
- auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
+ auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
build(builder, result,
/*resultTypes=*/TypeRange{operationType, operationType},
/*target=*/target,
/*num_threads=*/ValueRange{},
/*tile_sizes=*/dynamicTileSizes,
- /*static_num_threads=*/builder.getI64ArrayAttr({}),
+ /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
/*static_tile_sizes=*/staticTileSizesAttr,
/*mapping=*/mapping);
}
@@ -1414,14 +1433,14 @@ void transform::TileToForeachThreadOp::build(
// bugs ensue.
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
- auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
+ auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
build(builder, result,
/*resultTypes=*/TypeRange{operationType, operationType},
/*target=*/target,
/*num_threads=*/dynamicNumThreads,
/*tile_sizes=*/ValueRange{},
/*static_num_threads=*/staticNumThreadsAttr,
- /*static_tile_sizes=*/builder.getI64ArrayAttr({}),
+ /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
/*mapping=*/mapping);
}
@@ -1547,11 +1566,13 @@ void transform::TileToForeachThreadOp::getEffects(
}
SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
- return getMixedValues(getStaticNumThreads(), getNumThreads());
+ Builder b(getContext());
+ return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
}
SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
- return getMixedValues(getStaticTileSizes(), getTileSizes());
+ Builder b(getContext());
+ return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
}
LogicalResult TileToForeachThreadOp::verify() {
@@ -1567,7 +1588,7 @@ LogicalResult TileToForeachThreadOp::verify() {
DiagnosedSilenceableFailure
transform::TileToScfForOp::apply(TransformResults &transformResults,
TransformState &state) {
- SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
+ ArrayRef<int64_t> tileSizes = getStaticSizes();
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
@@ -1632,7 +1653,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
});
}
- tilingOptions.setInterchange(extractI64Array(getInterchange()));
+ tilingOptions.setInterchange(getInterchange());
SimpleRewriter rewriter(tilingInterfaceOp.getContext());
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
@@ -1655,7 +1676,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
SmallVector<OpFoldResult> transform::TileToScfForOp::getMixedSizes() {
ValueRange dynamic = getDynamicSizes();
- SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
+ ArrayRef<int64_t> tileSizes = getStaticSizes();
SmallVector<OpFoldResult> results;
results.reserve(tileSizes.size());
unsigned dynamicPos = 0;
@@ -1674,18 +1695,20 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand target;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
- ArrayAttr staticSizes;
+ DenseI64ArrayAttr staticSizes;
auto pdlOperationType = pdl::OperationType::get(parser.getContext());
if (parser.parseOperand(target) ||
parser.resolveOperand(target, pdlOperationType, result.operands) ||
parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
- parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
- parser.parseOptionalAttrDict(result.attributes))
+ parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
return ParseResult::failure();
+ // Parse optional interchange.
+ if (failed(parseOptionalInterchange(parser, result)))
+ return ParseResult::failure();
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
size_t numExpectedLoops =
- staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
+ staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
return success();
}
@@ -1693,7 +1716,7 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
void TileToScfForOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
- p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
+ printOptionalInterchange(p, getInterchange());
}
void transform::TileToScfForOp::getEffects(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2b84860c5e735..91e0ae41f3914 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -348,7 +348,7 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
SmallVector<AffineExpr, 4> outputExprs;
for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
- padOp.getStaticLow()[i].cast<IntegerAttr>().getInt());
+ padOp.getStaticLow()[i]);
}
SmallVector<AffineMap, 2> transferMaps = {
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 503c8aed2709d..4ba4050bb45e6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1776,8 +1776,9 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamic);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
- dynamicStrides, b.getI64ArrayAttr(staticOffsets),
- b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+ dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+ b.getDenseI64ArrayAttr(staticSizes),
+ b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -1823,8 +1824,8 @@ LogicalResult ReinterpretCastOp::verify() {
<< srcType << " and result memref type " << resultType;
// Match sizes in result memref type and in static_sizes attribute.
- for (auto &en : llvm::enumerate(llvm::zip(
- resultType.getShape(), extractFromI64ArrayAttr(getStaticSizes())))) {
+ for (auto &en :
+ llvm::enumerate(llvm::zip(resultType.getShape(), getStaticSizes()))) {
int64_t resultSize = std::get<0>(en.value());
int64_t expectedSize = std::get<1>(en.value());
if (!ShapedType::isDynamic(resultSize) &&
@@ -1844,7 +1845,7 @@ LogicalResult ReinterpretCastOp::verify() {
<< resultType;
// Match offset in result memref type and in static_offsets attribute.
- int64_t expectedOffset = extractFromI64ArrayAttr(getStaticOffsets()).front();
+ int64_t expectedOffset = getStaticOffsets().front();
if (!ShapedType::isDynamic(resultOffset) &&
!ShapedType::isDynamic(expectedOffset) &&
resultOffset != expectedOffset)
@@ -1852,8 +1853,8 @@ LogicalResult ReinterpretCastOp::verify() {
<< resultOffset << " instead of " << expectedOffset;
// Match strides in result memref type and in static_strides attribute.
- for (auto &en : llvm::enumerate(llvm::zip(
- resultStrides, extractFromI64ArrayAttr(getStaticStrides())))) {
+ for (auto &en :
+ llvm::enumerate(llvm::zip(resultStrides, getStaticStrides()))) {
int64_t resultStride = std::get<0>(en.value());
int64_t expectedStride = std::get<1>(en.value());
if (!ShapedType::isDynamic(resultStride) &&
@@ -2665,8 +2666,9 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
.cast<MemRefType>();
}
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
- dynamicStrides, b.getI64ArrayAttr(staticOffsets),
- b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+ dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+ b.getDenseI64ArrayAttr(staticSizes),
+ b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -2831,9 +2833,7 @@ LogicalResult SubViewOp::verify() {
// Verify result type against inferred type.
auto expectedType = SubViewOp::inferResultType(
- baseType, extractFromI64ArrayAttr(getStaticOffsets()),
- extractFromI64ArrayAttr(getStaticSizes()),
- extractFromI64ArrayAttr(getStaticStrides()));
+ baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
subViewType, getMixedSizes());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index faa00e2c97811..fae68a0a349e8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -45,9 +45,8 @@ static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
builder.setInsertionPoint(subviewUse);
Type newType = memref::SubViewOp::inferRankReducedResultType(
subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
- extractFromI64ArrayAttr(subviewUse.getStaticOffsets()),
- extractFromI64ArrayAttr(subviewUse.getStaticSizes()),
- extractFromI64ArrayAttr(subviewUse.getStaticStrides()));
+ subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
+ subviewUse.getStaticStrides());
Value newSubview = builder.create<memref::SubViewOp>(
subviewUse->getLoc(), newType.cast<MemRefType>(), val,
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 23af46c6d7912..f279876d19541 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -337,8 +337,7 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
auto dimMask = computeRankReductionMask(
- extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
- extractOperand.getType().getShape());
+ extractOperand.getStaticSizes(), extractOperand.getType().getShape());
size_t dimIndex = 0;
for (size_t i = 0, e = sizes.size(); i < e; i++) {
if (dimMask && dimMask->count(i))
@@ -1713,8 +1712,9 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
.cast<RankedTensorType>();
}
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
- dynamicStrides, b.getI64ArrayAttr(staticOffsets),
- b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+ dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+ b.getDenseI64ArrayAttr(staticSizes),
+ b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -1949,13 +1949,13 @@ class ConstantOpExtractSliceFolder final
return failure();
// Check if there are any dynamic parts, which are not supported.
- auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
+ auto offsets = op.getStaticOffsets();
if (llvm::is_contained(offsets, ShapedType::kDynamic))
return failure();
- auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
+ auto sizes = op.getStaticSizes();
if (llvm::is_contained(sizes, ShapedType::kDynamic))
return failure();
- auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
+ auto strides = op.getStaticStrides();
if (llvm::is_contained(strides, ShapedType::kDynamic))
return failure();
@@ -2124,8 +2124,9 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamic);
build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
- dynamicStrides, b.getI64ArrayAttr(staticOffsets),
- b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+ dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+ b.getDenseI64ArrayAttr(staticSizes),
+ b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -2153,17 +2154,14 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
/// Rank-reducing type verification for both InsertSliceOp and
/// ParallelInsertSliceOp.
-static SliceVerificationResult
-verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
- ArrayAttr staticOffsets, ArrayAttr staticSizes,
- ArrayAttr staticStrides,
- ShapedType *expectedType = nullptr) {
+static SliceVerificationResult verifyInsertSliceOp(
+ ShapedType srcType, ShapedType dstType, ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
+ ShapedType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type
// inference.
RankedTensorType expected = ExtractSliceOp::inferResultType(
- dstType, extractFromI64ArrayAttr(staticOffsets),
- extractFromI64ArrayAttr(staticSizes),
- extractFromI64ArrayAttr(staticStrides));
+ dstType, staticOffsets, staticSizes, staticStrides);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
@@ -2482,9 +2480,8 @@ ParseResult parseInferType(OpAsmParser &parser,
LogicalResult PadOp::verify() {
auto sourceType = getSource().getType().cast<RankedTensorType>();
auto resultType = getResult().getType().cast<RankedTensorType>();
- auto expectedType = PadOp::inferResultType(
- sourceType, extractFromI64ArrayAttr(getStaticLow()),
- extractFromI64ArrayAttr(getStaticHigh()));
+ auto expectedType =
+ PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
continue;
@@ -2556,8 +2553,9 @@ void PadOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<NamedAttribute> attrs) {
auto sourceType = source.getType().cast<RankedTensorType>();
auto resultType = inferResultType(sourceType, staticLow, staticHigh);
- build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
- b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
+ build(b, result, resultType, source, low, high,
+ b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
+ nofold ? b.getUnitAttr() : UnitAttr());
result.addAttributes(attrs);
}
@@ -2591,7 +2589,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
}
assert(resultType.isa<RankedTensorType>());
build(b, result, resultType, source, dynamicLow, dynamicHigh,
- b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
+ b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
nofold ? b.getUnitAttr() : UnitAttr());
result.addAttributes(attrs);
}
@@ -2658,8 +2656,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
auto newResultType = PadOp::inferResultType(
castOp.getSource().getType().cast<RankedTensorType>(),
- extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
- extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
+ padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
@@ -2940,8 +2937,9 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamic);
build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
- dynamicStrides, b.getI64ArrayAttr(staticOffsets),
- b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+ dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
+ b.getDenseI64ArrayAttr(staticSizes),
+ b.getDenseI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
@@ -3086,12 +3084,12 @@ template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
+ Builder builder(op);
SmallVector<OpFoldResult> mixedInnerTiles;
unsigned dynamicValIndex = 0;
- for (Attribute attr : op.getStaticInnerTiles()) {
- auto tileAttr = attr.cast<IntegerAttr>();
- if (!ShapedType::isDynamic(tileAttr.getInt()))
- mixedInnerTiles.push_back(tileAttr);
+ for (int64_t staticTile : op.getStaticInnerTiles()) {
+ if (!ShapedType::isDynamic(staticTile))
+ mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
else
mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
}
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5694cfeb5130f..432e75618917c 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -137,4 +137,41 @@ SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
return getValueOrCreateConstantIndexOp(b, loc, value);
}));
}
+
+/// Return a vector of OpFoldResults with the same size as staticValues, where
+/// every element for which ShapedType::isDynamic is true is replaced by the
+/// next value from dynamicValues, in order.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+ ValueRange dynamicValues, Builder &b) {
+ SmallVector<OpFoldResult> res;
+ res.reserve(staticValues.size());
+ unsigned numDynamic = 0;
+ unsigned count = static_cast<unsigned>(staticValues.size());
+ for (unsigned idx = 0; idx < count; ++idx) {
+ int64_t value = staticValues[idx];
+ res.push_back(ShapedType::isDynamic(value)
+ ? OpFoldResult{dynamicValues[numDynamic++]}
+ : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+ }
+ return res;
+}
+
+/// Decompose a vector of mixed static or dynamic values into the corresponding
+/// pair of arrays. This is the inverse function of `getMixedValues`.
+std::pair<ArrayAttr, SmallVector<Value>>
+decomposeMixedValues(Builder &b,
+ const SmallVectorImpl<OpFoldResult> &mixedValues) {
+ SmallVector<int64_t> staticValues;
+ SmallVector<Value> dynamicValues;
+ for (const auto &it : mixedValues) {
+ if (it.is<Attribute>()) {
+ staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
+ } else {
+ staticValues.push_back(ShapedType::kDynamic);
+ dynamicValues.push_back(it.get<Value>());
+ }
+ }
+ return {b.getI64ArrayAttr(staticValues), dynamicValues};
+}
+
} // namespace mlir
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 775d26a6d1590..9a39a217bf442 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -20,15 +20,15 @@ using namespace mlir;
LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
StringRef name,
unsigned numElements,
- ArrayAttr attr,
+ ArrayRef<int64_t> staticVals,
ValueRange values) {
- /// Check static and dynamic offsets/sizes/strides does not overflow type.
- if (attr.size() != numElements)
+ // Check that static and dynamic offsets/sizes/strides do not overflow type.
+ if (staticVals.size() != numElements)
return op->emitError("expected ")
<< numElements << " " << name << " values";
unsigned expectedNumDynamicEntries =
- llvm::count_if(attr.getValue(), [&](Attribute attr) {
- return ShapedType::isDynamic(attr.cast<IntegerAttr>().getInt());
+ llvm::count_if(staticVals, [&](int64_t staticVal) {
+ return ShapedType::isDynamic(staticVal);
});
if (values.size() != expectedNumDynamicEntries)
return op->emitError("expected ")
@@ -70,19 +70,19 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
}
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
- OperandRange values, ArrayAttr integers) {
+ OperandRange values,
+ ArrayRef<int64_t> integers) {
printer << '[';
if (integers.empty()) {
printer << "]";
return;
}
unsigned idx = 0;
- llvm::interleaveComma(integers, printer, [&](Attribute a) {
- int64_t val = a.cast<IntegerAttr>().getInt();
- if (ShapedType::isDynamic(val))
+ llvm::interleaveComma(integers, printer, [&](int64_t integer) {
+ if (ShapedType::isDynamic(integer))
printer << values[idx++];
else
- printer << val;
+ printer << integer;
});
printer << ']';
}
@@ -90,28 +90,28 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
ParseResult mlir::parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- ArrayAttr &integers) {
+ DenseI64ArrayAttr &integers) {
if (failed(parser.parseLSquare()))
return failure();
// 0-D.
if (succeeded(parser.parseOptionalRSquare())) {
- integers = parser.getBuilder().getArrayAttr({});
+ integers = parser.getBuilder().getDenseI64ArrayAttr({});
return success();
}
- SmallVector<int64_t, 4> attrVals;
+ SmallVector<int64_t, 4> integerVals;
while (true) {
OpAsmParser::UnresolvedOperand operand;
auto res = parser.parseOptionalOperand(operand);
if (res.has_value() && succeeded(res.value())) {
values.push_back(operand);
- attrVals.push_back(ShapedType::kDynamic);
+ integerVals.push_back(ShapedType::kDynamic);
} else {
- IntegerAttr attr;
- if (failed(parser.parseAttribute<IntegerAttr>(attr)))
+ int64_t integer;
+ if (failed(parser.parseInteger(integer)))
return parser.emitError(parser.getNameLoc())
<< "expected SSA value or integer";
- attrVals.push_back(attr.getInt());
+ integerVals.push_back(integer);
}
if (succeeded(parser.parseOptionalComma()))
@@ -120,7 +120,7 @@ ParseResult mlir::parseDynamicIndexList(
return failure();
break;
}
- integers = parser.getBuilder().getI64ArrayAttr(attrVals);
+ integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
return success();
}
@@ -144,34 +144,3 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
return false;
return true;
}
-
-SmallVector<OpFoldResult, 4> mlir::getMixedValues(ArrayAttr staticValues,
- ValueRange dynamicValues) {
- SmallVector<OpFoldResult, 4> res;
- res.reserve(staticValues.size());
- unsigned numDynamic = 0;
- unsigned count = static_cast<unsigned>(staticValues.size());
- for (unsigned idx = 0; idx < count; ++idx) {
- APInt value = staticValues[idx].cast<IntegerAttr>().getValue();
- res.push_back(ShapedType::isDynamic(value.getSExtValue())
- ? OpFoldResult{dynamicValues[numDynamic++]}
- : OpFoldResult{staticValues[idx]});
- }
- return res;
-}
-
-std::pair<ArrayAttr, SmallVector<Value>>
-mlir::decomposeMixedValues(Builder &b,
- const SmallVectorImpl<OpFoldResult> &mixedValues) {
- SmallVector<int64_t> staticValues;
- SmallVector<Value> dynamicValues;
- for (const auto &it : mixedValues) {
- if (it.is<Attribute>()) {
- staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
- } else {
- staticValues.push_back(ShapedType::kDynamic);
- dynamicValues.push_back(it.get<Value>());
- }
- }
- return {b.getI64ArrayAttr(staticValues), dynamicValues};
-}
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 527a8656f7e33..5fd5cfe1073ad 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -49,6 +49,15 @@ def _get_int_array_attr(
return ArrayAttr.get([_get_int64_attr(v) for v in values])
+def _get_dense_int64_array_attr(
+ values: Sequence[int]) -> DenseI64ArrayAttr:
+ """Creates a dense integer array from a sequence of integers.
+ Expects the thread-local MLIR context to have been set by the context
+ manager.
+ """
+ if values is None:
+ return DenseI64ArrayAttr.get([])
+ return DenseI64ArrayAttr.get(values)
def _get_int_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
@@ -250,14 +259,11 @@ def __init__(self,
else:
for size in sizes:
if isinstance(size, int):
- static_sizes.append(IntegerAttr.get(i64_type, size))
- elif isinstance(size, IntegerAttr):
static_sizes.append(size)
else:
- static_sizes.append(
- IntegerAttr.get(i64_type, ShapedType.get_dynamic_size()))
+ static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(_get_op_result_or_value(size))
- sizes_attr = ArrayAttr.get(static_sizes)
+ sizes_attr = DenseI64ArrayAttr.get(static_sizes)
num_loops = sum(
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
@@ -266,14 +272,14 @@ def __init__(self,
_get_op_result_or_value(target),
dynamic_sizes=dynamic_sizes,
static_sizes=sizes_attr,
- interchange=_get_int_array_attr(interchange) if interchange else None,
+ interchange=_get_dense_int64_array_attr(interchange) if interchange else None,
loc=loc,
ip=ip)
- def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
+ def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
if not attr:
return []
- return [IntegerAttr(element).value for element in attr]
+ return [element for element in attr]
class VectorizeOp:
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 06c52f50e0fa2..482cbc786d485 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -138,7 +138,7 @@ func.func @permute_generic(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]}
+ transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]}
}
// CHECK-LABEL: func @permute_generic
@@ -191,8 +191,8 @@ func.func @matmul_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
- %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange=[1, 2, 0]}
- %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange=[1, 0, 2]}
+ %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]}
+ %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]}
%3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40]
}
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index f52c4b6d63b33..34c86a317920b 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -108,7 +108,6 @@ def testSplit():
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
-
@run
def testTileCompact():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
@@ -120,14 +119,11 @@ def testTileCompact():
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
# CHECK: interchange = [0, 1]
-
@run
def testTileAttributes():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
- attr = ArrayAttr.get(
- [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
- ichange = ArrayAttr.get(
- [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
+ attr = DenseI64ArrayAttr.get([4, 8])
+ ichange = DenseI64ArrayAttr.get([0, 1])
with InsertionPoint(sequence.body):
structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
transform.YieldOp()
@@ -136,7 +132,6 @@ def testTileAttributes():
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
# CHECK: interchange = [0, 1]
-
@run
def testTileZero():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
@@ -149,7 +144,6 @@ def testTileZero():
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
# CHECK: interchange = [0, 1, 2, 3]
-
@run
def testTileDynamic():
with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get())
More information about the Mlir-commits
mailing list