[Mlir-commits] [mlir] [mlir]Fix compose subview (PR #80551)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Tue Feb 6 11:58:50 PST 2024
================
@@ -51,67 +50,83 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
}
// Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
- SmallVector<OpFoldResult> offsets, sizes, strides;
-
- // Because we only support input strides of 1, the output stride is also
- // always 1.
- if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
- Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
- return attr && cast<IntegerAttr>(attr).getInt() == 1;
- })) {
- strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
- rewriter.getI64IntegerAttr(1));
- } else {
- return failure();
+ SmallVector<OpFoldResult> offsets, sizes, strides,
+ opStrides = op.getMixedStrides(),
+ sourceStrides = sourceOp.getMixedStrides();
+
+ // The output stride in each dimension is equal to the product of the
+ // dimensions corresponding to source and op.
+ for (auto &&[opStride, sourceStride] :
+ llvm::zip(opStrides, sourceStrides)) {
+ Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
+ Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
+ if (!opStrideAttr || !sourceStrideAttr)
+ return failure();
+ strides.push_back(rewriter.getI64IntegerAttr(
+ cast<IntegerAttr>(opStrideAttr).getInt() *
+ cast<IntegerAttr>(sourceStrideAttr).getInt()));
}
// The rules for calculating the new offsets and sizes are:
// * Multiple subview offsets for a given dimension compose additively.
- // ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+ // ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
+ // m + n * k")
// * Multiple sizes for a given dimension compose by taking the size of the
// final subview and ignoring the rest. ("Take m values" followed by "Take
// n values" == "Take n values") This size must also be the smallest one
// by definition (a subview needs to be the same size as or smaller than
// its source along each dimension; presumably subviews that are larger
// than their sources are disallowed by validation).
- for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
- op.getMixedSizes())) {
- auto opOffset = std::get<0>(it);
- auto sourceOffset = std::get<1>(it);
- auto opSize = std::get<2>(it);
-
+ for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
+ llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
+ sourceOp.getMixedStrides(), op.getMixedSizes())) {
// We only support static sizes.
if (opSize.is<Value>()) {
return failure();
}
-
sizes.push_back(opSize);
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
- llvm::dyn_cast_if_present<Attribute>(sourceOffset);
-
+ llvm::dyn_cast_if_present<Attribute>(sourceOffset),
+ sourceStrideAttr =
+ llvm::dyn_cast_if_present<Attribute>(sourceStride);
if (opOffsetAttr && sourceOffsetAttr) {
+
// If both offsets are static we can simply calculate the combined
// offset statically.
offsets.push_back(rewriter.getI64IntegerAttr(
- cast<IntegerAttr>(opOffsetAttr).getInt() +
+ cast<IntegerAttr>(opOffsetAttr).getInt() *
+ cast<IntegerAttr>(sourceStrideAttr).getInt() +
cast<IntegerAttr>(sourceOffsetAttr).getInt()));
} else {
- // When either offset is dynamic, we must emit an additional affine
- // transformation to add the two offsets together dynamically.
- AffineExpr expr = rewriter.getAffineConstantExpr(0);
+ AffineExpr expr0 = rewriter.getAffineConstantExpr(0);
+ AffineExpr expr1 = rewriter.getAffineConstantExpr(0);
SmallVector<Value> affineApplyOperands;
- for (auto valueOrAttr : {opOffset, sourceOffset}) {
- if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
- expr = expr + cast<IntegerAttr>(attr).getInt();
- } else {
- expr =
- expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
- affineApplyOperands.push_back(valueOrAttr.get<Value>());
- }
+
+ // Make 'expr0' add 'sourceOffset'.
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
+ expr0 = expr0 + cast<IntegerAttr>(attr).getInt();
+ } else {
+ expr0 =
+ expr0 + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+ affineApplyOperands.push_back(sourceOffset.get<Value>());
+ }
----------------
ftynse wrote:
Okay, so now we don't need `expr0 = expr0 + X`; we can just do `expr0 = X` in both branches. A better name than `expr0` would also be beneficial ;)
https://github.com/llvm/llvm-project/pull/80551
More information about the Mlir-commits mailing list