[Mlir-commits] [mlir] [mlir]Fix compose subview (PR #80551)
lonely eagle
llvmlistbot at llvm.org
Mon Feb 5 17:20:16 PST 2024
================
@@ -52,66 +52,81 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
SmallVector<OpFoldResult> offsets, sizes, strides;
-
- // Because we only support input strides of 1, the output stride is also
- // always 1.
- if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
- Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
- return attr && cast<IntegerAttr>(attr).getInt() == 1;
- })) {
- strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
- rewriter.getI64IntegerAttr(1));
- } else {
- return failure();
+ auto opStrides = op.getMixedStrides();
+ auto sourceStrides = sourceOp.getMixedStrides();
+
+ // The output stride in each dimension is equal to the product of the
+ // dimensions corresponding to source and op.
+ for (auto [opStride, sourceStride] : llvm::zip(opStrides, sourceStrides)) {
+ Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
+ Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
+ if (!opStrideAttr || !sourceStrideAttr)
+ return failure();
+ strides.push_back(rewriter.getI64IntegerAttr(
+ cast<IntegerAttr>(opStrideAttr).getInt() *
+ cast<IntegerAttr>(sourceStrideAttr).getInt()));
}
// The rules for calculating the new offsets and sizes are:
// * Multiple subview offsets for a given dimension compose additively.
- // ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+ // ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
+ // m + n * k")
// * Multiple sizes for a given dimension compose by taking the size of the
// final subview and ignoring the rest. ("Take m values" followed by "Take
// n values" == "Take n values") This size must also be the smallest one
// by definition (a subview needs to be the same size as or smaller than
// its source along each dimension; presumably subviews that are larger
// than their sources are disallowed by validation).
- for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
- op.getMixedSizes())) {
- auto opOffset = std::get<0>(it);
- auto sourceOffset = std::get<1>(it);
- auto opSize = std::get<2>(it);
-
+ for (auto [opOffset, sourceOffset, sourceStride, opSize] :
+ llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
+ sourceOp.getMixedStrides(), op.getMixedSizes())) {
// We only support static sizes.
if (opSize.is<Value>()) {
return failure();
}
-
sizes.push_back(opSize);
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
- llvm::dyn_cast_if_present<Attribute>(sourceOffset);
-
+ llvm::dyn_cast_if_present<Attribute>(sourceOffset),
+ sourceStrideAttr =
+ llvm::dyn_cast_if_present<Attribute>(sourceStride);
if (opOffsetAttr && sourceOffsetAttr) {
+
// If both offsets are static we can simply calculate the combined
// offset statically.
offsets.push_back(rewriter.getI64IntegerAttr(
- cast<IntegerAttr>(opOffsetAttr).getInt() +
+ cast<IntegerAttr>(opOffsetAttr).getInt() *
+ cast<IntegerAttr>(sourceStrideAttr).getInt() +
cast<IntegerAttr>(sourceOffsetAttr).getInt()));
} else {
- // When either offset is dynamic, we must emit an additional affine
- // transformation to add the two offsets together dynamically.
- AffineExpr expr = rewriter.getAffineConstantExpr(0);
+ AffineExpr expr0 = rewriter.getAffineConstantExpr(0);
+ AffineExpr expr1 = rewriter.getAffineConstantExpr(0);
SmallVector<Value> affineApplyOperands;
- for (auto valueOrAttr : {opOffset, sourceOffset}) {
- if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
- expr = expr + cast<IntegerAttr>(attr).getInt();
+ SmallVector<OpFoldResult> opOffsets{sourceOffset, opOffset};
+ for (auto [idx, offset] : llvm::enumerate(opOffsets)) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(offset)) {
+ if (idx == 0) {
+ expr0 = expr0 + cast<IntegerAttr>(attr).getInt();
+ } else if (idx == 1) {
+ expr1 = expr1 + cast<IntegerAttr>(attr).getInt();
+ expr1 = expr1 * cast<IntegerAttr>(sourceStrideAttr).getInt();
+ expr0 = expr0 + expr1;
+ }
----------------
linuxlonelyeagle wrote:
The loop iterates twice in order to visit sourceOffset and opOffset in turn. Because each of sourceOffset and opOffset may be either an attribute or a value, the first level of branching distinguishes those two cases; the second level of branching distinguishes between processing sourceOffset and processing opOffset. For sourceOffset (idx = 0), we just need to add it to expr0. For opOffset (idx = 1), it must additionally be multiplied by sourceStride. The formula for the total offset in a dimension is: offset = sourceOffset + opOffset * sourceStride.
https://github.com/llvm/llvm-project/pull/80551
More information about the Mlir-commits
mailing list